/**
 * @module generation/beam_search
 */

/**
 * Stores the `num_beams` best completed hypotheses for a single batch element.
 *
 * Hypotheses are ranked by their length-normalized score
 * `sum_logprobs / (length ** length_penalty)`; once the container is full,
 * a new hypothesis only displaces the current worst one if it scores higher.
 */
export class BeamHypotheses {
    /**
     * @param {number} num_beams Number of beams.
     * @param {number} length_penalty Exponential penalty to the length. Values > 1.0 favor longer sequences; values < 1.0 favor shorter ones.
     * @param {boolean|"never"} early_stopping Whether to stop when enough hypotheses are finished. `true` stops as soon as `num_beams` hypotheses are complete; `false` uses a heuristic; `"never"` only stops when no candidate can possibly improve.
     */
    constructor(num_beams, length_penalty = 1.0, early_stopping = false) {
        this.num_beams = num_beams;
        this.length_penalty = length_penalty;
        this.early_stopping = early_stopping;

        /** @type {{score: number, tokens: bigint[]}[]} */
        this.beams = [];
        // Sentinel: any real score beats this until the container is full.
        this.worst_score = 1e9;
    }

    /**
     * @returns {number} Number of stored hypotheses.
     */
    get length() {
        return this.beams.length;
    }

    /**
     * Add a new hypothesis to the list.
     * @param {number} sum_logprobs Sum of log probabilities of the hypothesis.
     * @param {bigint[]} tokens The token ids of the hypothesis.
     */
    add(sum_logprobs, tokens) {
        const score = sum_logprobs / (tokens.length ** this.length_penalty);
        if (this.beams.length < this.num_beams || score > this.worst_score) {
            this.beams.push({ score, tokens });
            if (this.beams.length > this.num_beams) {
                // Evict the single worst hypothesis to keep at most `num_beams`.
                let worst_idx = 0;
                for (let i = 1; i < this.beams.length; ++i) {
                    if (this.beams[i].score < this.beams[worst_idx].score) {
                        worst_idx = i;
                    }
                }
                this.beams.splice(worst_idx, 1);
            }
            this.worst_score = this.beams.length === this.num_beams
                ? Math.min(...this.beams.map(b => b.score))
                : -1e9;
        }
    }

    /**
     * Check whether generation can stop for this batch element, i.e. whether
     * no active beam can still beat the worst completed hypothesis.
     * @param {number} best_sum_logprobs Best sum of log probs among active beams.
     * @param {number} cur_len Current length of generated tokens.
     * @returns {boolean}
     */
    is_done(best_sum_logprobs, cur_len) {
        if (this.beams.length < this.num_beams) return false;

        if (this.early_stopping === true) {
            return true;
        } else if (this.early_stopping === 'never') {
            return false;
        } else {
            // Heuristic (transformers' default): assume the best active beam's
            // raw score will not improve, normalize it at the current length,
            // and compare against the worst completed hypothesis.
            const highest_attainable_score = best_sum_logprobs / (cur_len ** this.length_penalty);
            return this.worst_score >= highest_attainable_score;
        }
    }
}

/**
 * Implements beam search scoring and beam management.
 *
 * Candidates arrive in groups of `2 * num_beams` per batch element, sorted by
 * score descending; `process` routes EOS-terminated candidates into
 * {@link BeamHypotheses} and selects the top `num_beams` continuing beams.
 */
export class BeamSearchScorer {
    /**
     * @param {number} batch_size
     * @param {number} num_beams
     * @param {Object} options
     * @param {number} [options.length_penalty]
     * @param {boolean|"never"} [options.early_stopping]
     * @param {number} [options.num_return_sequences]
     * @param {number|number[]|null} [options.eos_token_id]
     * @param {number|null} [options.pad_token_id]
     * @throws {Error} If `num_return_sequences > num_beams`.
     */
    constructor(batch_size, num_beams, {
        length_penalty = 1.0,
        early_stopping = false,
        num_return_sequences = 1,
        eos_token_id = null,
        pad_token_id = null,
    } = {}) {
        this.batch_size = batch_size;
        this.num_beams = num_beams;
        this.length_penalty = length_penalty;
        this.early_stopping = early_stopping;
        this.num_return_sequences = num_return_sequences;
        // Normalize to an array so `process` can test membership uniformly.
        this.eos_token_ids = eos_token_id === null ? []
            : (Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]);
        this.pad_token_id = pad_token_id ?? 0;

        if (num_return_sequences > num_beams) {
            throw new Error(
                `num_return_sequences (${num_return_sequences}) must be <= num_beams (${num_beams}).`,
            );
        }

        /** @type {BeamHypotheses[]} */
        this._beam_hyps = Array.from(
            { length: batch_size },
            () => new BeamHypotheses(num_beams, length_penalty, early_stopping),
        );
        /** @type {boolean[]} */
        this._done = new Array(batch_size).fill(false);
    }

    /**
     * @returns {boolean} Whether every batch element has finished.
     */
    get is_done() {
        return this._done.every(Boolean);
    }

    /**
     * Process the current beam candidates and select the next set of beams.
     *
     * @param {bigint[][]} all_input_ids All sequences so far, shape [batch_size * num_beams, seq_len].
     * @param {number[]} beam_scores Cumulative scores, shape [batch_size * num_beams].
     * @param {bigint[]} next_tokens Candidate tokens, shape [batch_size * 2 * num_beams], sorted by score descending within each batch group.
     * @param {number[]} next_indices Which beam each candidate came from (relative to batch group), shape [batch_size * 2 * num_beams].
     * @param {number[]} next_scores New cumulative scores for candidates, shape [batch_size * 2 * num_beams].
     * @returns {{ next_beam_scores: number[], next_beam_tokens: bigint[], next_beam_indices: number[] }}
     */
    process(all_input_ids, beam_scores, next_tokens, next_indices, next_scores) {
        const cur_len = all_input_ids[0].length;
        const total_beams = this.batch_size * this.num_beams;

        const next_beam_scores = new Array(total_beams).fill(0);
        const next_beam_tokens = new Array(total_beams).fill(0n);
        const next_beam_indices = new Array(total_beams).fill(0);

        for (let batch_idx = 0; batch_idx < this.batch_size; ++batch_idx) {
            if (this._done[batch_idx]) {
                // Pad finished batches so tensor shapes stay rectangular.
                for (let beam_idx = 0; beam_idx < this.num_beams; ++beam_idx) {
                    const flat_idx = batch_idx * this.num_beams + beam_idx;
                    next_beam_scores[flat_idx] = 0;
                    next_beam_tokens[flat_idx] = BigInt(this.pad_token_id);
                    next_beam_indices[flat_idx] = batch_idx * this.num_beams;
                }
                continue;
            }

            let beam_idx = 0;
            const num_candidates = 2 * this.num_beams;
            for (let j = 0; j < num_candidates; ++j) {
                const cand_idx = batch_idx * num_candidates + j;
                const beam_token = next_tokens[cand_idx];
                const beam_score = next_scores[cand_idx];
                const beam_source = next_indices[cand_idx]; // relative to batch
                const abs_beam_source = batch_idx * this.num_beams + beam_source;

                const is_eos = this.eos_token_ids.some(id => BigInt(id) === beam_token);

                if (is_eos) {
                    // An EOS candidate ranked beyond the top `num_beams` can
                    // never belong to the final best hypotheses, so it must be
                    // skipped rather than stored; otherwise low-quality,
                    // prematurely terminated sequences pollute the pool.
                    if (j >= this.num_beams) continue;
                    // Add completed hypothesis
                    const hypothesis = [...all_input_ids[abs_beam_source], beam_token];
                    this._beam_hyps[batch_idx].add(beam_score, hypothesis);
                } else {
                    // Add to next active beams
                    const out_idx = batch_idx * this.num_beams + beam_idx;
                    next_beam_scores[out_idx] = beam_score;
                    next_beam_tokens[out_idx] = beam_token;
                    next_beam_indices[out_idx] = abs_beam_source;
                    beam_idx++;
                }

                if (beam_idx === this.num_beams) break;
            }

            // If we couldn't fill all beams (too many EOS), pad with last valid
            if (beam_idx < this.num_beams) {
                const last_valid = batch_idx * this.num_beams + Math.max(0, beam_idx - 1);
                for (; beam_idx < this.num_beams; ++beam_idx) {
                    const out_idx = batch_idx * this.num_beams + beam_idx;
                    next_beam_scores[out_idx] = next_beam_scores[last_valid] ?? 0;
                    next_beam_tokens[out_idx] = next_beam_tokens[last_valid] ?? 0n;
                    next_beam_indices[out_idx] = next_beam_indices[last_valid] ?? (batch_idx * this.num_beams);
                }
            }

            // Check if done for this batch using next-step scores
            const start = batch_idx * this.num_beams;
            const end = start + this.num_beams;
            const best_sum_logprobs = Math.max(...next_beam_scores.slice(start, end));
            this._done[batch_idx] = this._beam_hyps[batch_idx].is_done(best_sum_logprobs, cur_len + 1);
        }

        return { next_beam_scores, next_beam_tokens, next_beam_indices };
    }

    /**
     * Finalize: select best hypotheses.
     * @param {bigint[][]} all_input_ids Final sequences, shape [batch_size * num_beams, seq_len].
     * @param {number[]} beam_scores Final cumulative scores.
     * @returns {bigint[][]} Best sequences, shape [batch_size * num_return_sequences, seq_len].
     */
    finalize(all_input_ids, beam_scores) {
        return this.finalize_with_scores(all_input_ids, beam_scores).map((x) => x.tokens);
    }

    /**
     * Finalize: select best hypotheses and return scores.
     * @param {bigint[][]} all_input_ids Final sequences, shape [batch_size * num_beams, seq_len].
     * @param {number[]} beam_scores Final cumulative scores.
     * @returns {{tokens: bigint[], score: number}[]} Best sequences with scores.
     */
    finalize_with_scores(all_input_ids, beam_scores) {
        // For each batch, ensure we have enough hypotheses: batches that never
        // produced an EOS contribute their still-active beams directly.
        for (let batch_idx = 0; batch_idx < this.batch_size; ++batch_idx) {
            const hyps = this._beam_hyps[batch_idx];
            if (hyps.length < this.num_beams) {
                for (let beam_idx = 0; beam_idx < this.num_beams; ++beam_idx) {
                    const flat_idx = batch_idx * this.num_beams + beam_idx;
                    hyps.add(beam_scores[flat_idx], all_input_ids[flat_idx]);
                }
            }
        }

        // Select top num_return_sequences per batch
        const results = [];
        for (let batch_idx = 0; batch_idx < this.batch_size; ++batch_idx) {
            const sorted = [...this._beam_hyps[batch_idx].beams]
                .sort((a, b) => b.score - a.score);
            for (let i = 0; i < this.num_return_sequences; ++i) {
                results.push(sorted[i]);
            }
        }
        return results;
    }
}
'../utils/maths.js'; import { ModelOutput } from './modeling_outputs.js'; import { logger } from '../utils/logger.js'; @@ -1080,27 +1082,114 @@ export class PreTrainedModel extends Callable { // } const numInputs = model_inputs[model_input_name].dims.at(0); + const num_beams = generation_config.num_beams; + const num_beam_groups = generation_config.num_beam_groups; + const is_beam_search = num_beams > 1; + const is_group_beam_search = is_beam_search && num_beam_groups > 1; + + let beam_scorer = null; + let beam_scorers = null; + /** @type {number[]} */ + let beam_scores; + + if (is_beam_search) { + // Validate beam search configuration + if (generation_config.num_return_sequences > num_beams) { + throw new Error( + `num_return_sequences (${generation_config.num_return_sequences}) must be <= num_beams (${num_beams}).`, + ); + } + if (num_beam_groups < 1) { + throw new Error(`num_beam_groups must be >= 1, but got ${num_beam_groups}.`); + } + if (num_beam_groups > num_beams) { + throw new Error( + `num_beam_groups (${num_beam_groups}) must be <= num_beams (${num_beams}).`, + ); + } + if (num_beams % num_beam_groups !== 0) { + throw new Error( + `num_beams (${num_beams}) must be divisible by num_beam_groups (${num_beam_groups}).`, + ); + } + if (is_group_beam_search && generation_config.do_sample) { + throw new Error('Diverse beam sampling (num_beam_groups > 1 with do_sample = true) is not yet supported.'); + } + if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { + throw new Error('Classifier-free guidance (guidance_scale > 1) is not supported with beam search.'); + } + if (streamer) { + throw new Error('Streaming is not supported with beam search.'); + } - // TODO: - // done is a list of booleans to keep track of which inputs are done - // const done = new Array(numInputs).fill(false); - // For efficiency purposes, we remove completed rows from model_inputs - // when the beam is complete, and we keep track of the row index - // 
const rowIndexToBatchIndex = new Map(); + if (is_group_beam_search) { + const group_size = num_beams / num_beam_groups; + beam_scorers = Array.from({ length: num_beam_groups }, () => new BeamSearchScorer(numInputs, group_size, { + length_penalty: generation_config.length_penalty, + early_stopping: generation_config.early_stopping, + num_return_sequences: group_size, + eos_token_id: generation_config.eos_token_id, + pad_token_id: generation_config.pad_token_id, + })); + } else { + beam_scorer = new BeamSearchScorer(numInputs, num_beams, { + length_penalty: generation_config.length_penalty, + early_stopping: generation_config.early_stopping, + num_return_sequences: generation_config.num_return_sequences, + eos_token_id: generation_config.eos_token_id, + pad_token_id: generation_config.pad_token_id, + }); + } + } const sampler = LogitsSampler.getSampler(generation_config); - // TODO make > numInputs - const scores = new Array(numInputs).fill(0); + const scores = new Array(is_beam_search ? numInputs * num_beams : numInputs).fill(0); /** @type {bigint[][]} */ const all_input_ids = input_ids.tolist(); + + if (is_beam_search) { + // Expand all_input_ids: [A, B] -> [A, A, A, B, B, B] for num_beams=3 + const expanded_ids = []; + for (const ids of all_input_ids) { + for (let b = 0; b < num_beams; ++b) { + expanded_ids.push([...ids]); + } + } + all_input_ids.length = 0; + all_input_ids.push(...expanded_ids); + + // Expand model_inputs tensors by repeating each row num_beams times + const expanded_indices = []; + for (let i = 0; i < numInputs; ++i) { + for (let b = 0; b < num_beams; ++b) { + expanded_indices.push(i); + } + } + for (const key of [model_input_name, 'attention_mask', 'decoder_attention_mask', 'encoder_outputs']) { + const tensor = model_inputs[key]; + if (tensor instanceof Tensor) { + if (tensor.location === 'cpu' || tensor.location === 'cpu-pinned') { + model_inputs[key] = index_select(tensor, expanded_indices); + } else { + model_inputs[key] = await 
index_select_async(tensor, expanded_indices); + } + } + } + + // Initialize beam_scores: first beam of each input = 0, rest = -1e9 + beam_scores = new Array(numInputs * num_beams).fill(-1e9); + for (let i = 0; i < numInputs; ++i) { + const group_size = num_beams / num_beam_groups; + for (let g = 0; g < num_beam_groups; ++g) { + beam_scores[i * num_beams + g * group_size] = 0; + } + } + } + if (streamer) { streamer.put(all_input_ids); } - // const all_generated_input_ids = Array.from({ length: numInputs }, () => []); - - // NOTE: For now, we don't support spawning new beams - // TODO: when we do, we simply copy past key values and accumulate into single large tensor //////////////////////////////////////////////////// // Generic search which handles 4 generation modes: @@ -1145,39 +1234,287 @@ export class PreTrainedModel extends Callable { /** @type {[bigint][]} */ const generated_input_ids = []; - // const new_kv_cache = [];// NOTE: Only used for beam search when concatenating new kv - // Loop over each batch - for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) { - const logs = next_tokens_scores[batch_idx]; - - const sampledTokens = await sampler(logs); - for (const [newTokenId, logProb] of sampledTokens) { - const bigint = BigInt(newTokenId); - // TODO: If branching, use previous beam as a starting point - // update generated ids, model inputs, and length for next step - scores[batch_idx] += logProb; - all_input_ids[batch_idx].push(bigint); - generated_input_ids.push([bigint]); - - // TODO: Support beam search - break; + + if (is_beam_search) { + // Beam search: score all candidates across beams, select top 2*num_beams per batch + const vocab_size = next_tokens_scores.dims.at(-1); + const total_beams = numInputs * num_beams; + + /** @type {number[]} */ + let next_beam_scores; + /** @type {bigint[]} */ + let next_beam_tokens; + /** @type {number[]} */ + let next_beam_indices; + + if (is_group_beam_search) { + const group_size = 
num_beams / num_beam_groups; + next_beam_scores = new Array(total_beams).fill(0); + next_beam_tokens = new Array(total_beams).fill(0n); + next_beam_indices = new Array(total_beams).fill(0); + + const diversity_penalty = generation_config.diversity_penalty ?? 0; + /** @type {Map[]} */ + const prev_group_tokens = Array.from({ length: numInputs }, () => new Map()); + + for (let group_idx = 0; group_idx < num_beam_groups; ++group_idx) { + const group_offset = group_idx * group_size; + + const group_all_input_ids = new Array(numInputs * group_size); + const group_beam_scores = new Array(numInputs * group_size); + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const local = batch_idx * group_size + beam_idx; + group_all_input_ids[local] = all_input_ids[global]; + group_beam_scores[local] = beam_scores[global]; + } + } + + const group_next_tokens = []; + const group_next_indices = []; + const group_next_scores = []; + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + // Collect candidates across all beams for this batch item + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const candidates = []; + + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const flat_beam = batch_idx * num_beams + group_offset + beam_idx; + const logits_data = /** @type {Float32Array} */ (next_tokens_scores[flat_beam].data); + const log_probs = log_softmax(logits_data); + const beam_score = beam_scores[flat_beam]; + + // Find top 2*group_size candidates per beam for efficiency + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const beam_candidates = []; + for (let v = 0; v < vocab_size; ++v) { + const token = BigInt(v); + let score = beam_score + log_probs[v]; + const penalty_count = prev_group_tokens[batch_idx].get(token) ?? 
0; + if (penalty_count > 0) { + score -= diversity_penalty * penalty_count; + } + beam_candidates.push({ + score, + token, + beam_idx, + }); + } + beam_candidates.sort((a, b) => b.score - a.score); + candidates.push(...beam_candidates.slice(0, 2 * group_size)); + } + + // Select top 2*group_size candidates globally for this batch item + candidates.sort((a, b) => b.score - a.score); + const top_candidates = candidates.slice(0, 2 * group_size); + + for (const c of top_candidates) { + group_next_tokens.push(c.token); + group_next_indices.push(c.beam_idx); + group_next_scores.push(c.score); + } + } + + // Let beam_scorer process: route EOS to hypotheses, select continuing beams + const { + next_beam_scores: group_next_beam_scores, + next_beam_tokens: group_next_beam_tokens, + next_beam_indices: group_next_beam_indices, + } = beam_scorers[group_idx].process( + group_all_input_ids, + group_beam_scores, + group_next_tokens, + group_next_indices, + group_next_scores, + ); + + // Map group outputs into global beam arrays + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const local = batch_idx * group_size + beam_idx; + next_beam_scores[global] = group_next_beam_scores[local]; + next_beam_tokens[global] = group_next_beam_tokens[local]; + const local_source = group_next_beam_indices[local]; + const source_in_batch = local_source - batch_idx * group_size; + next_beam_indices[global] = batch_idx * num_beams + group_offset + source_in_batch; + } + } + + // Track tokens selected by this group for diversity penalty + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + const token_counts = prev_group_tokens[batch_idx]; + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const token = next_beam_tokens[global]; + token_counts.set(token, 
(token_counts.get(token) ?? 0) + 1); + } + } + } + } else { + const all_next_tokens = []; + const all_next_indices = []; + const all_next_scores = []; + const do_beam_sample = generation_config.do_sample; + const top_k = generation_config.top_k; + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + // Collect candidates across all beams for this batch item + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const candidates = []; + + for (let beam_idx = 0; beam_idx < num_beams; ++beam_idx) { + const flat_beam = batch_idx * num_beams + beam_idx; + const logits_data = /** @type {Float32Array} */ (next_tokens_scores[flat_beam].data); + if (do_beam_sample) { + const beam_score = beam_scores[flat_beam]; + const beam_vocab = logits_data.length; + let k = top_k > 0 ? Math.min(top_k, beam_vocab) : beam_vocab; + const indices = Array.from({ length: beam_vocab }, (_, i) => i); + indices.sort((a, b) => logits_data[b] - logits_data[a]); + if (k < indices.length) { + indices.length = k; + } + const maxLogit = logits_data[indices[0]]; + const probs = new Array(indices.length); + let sum = 0; + for (let i = 0; i < indices.length; ++i) { + const val = Math.exp(logits_data[indices[i]] - maxLogit); + probs[i] = val; + sum += val; + } + for (let i = 0; i < probs.length; ++i) { + probs[i] /= sum; + } + for (let s = 0; s < num_beams; ++s) { + const sampledIndex = sampler.randomSelect(probs); + const tokenId = indices[sampledIndex]; + candidates.push({ + score: beam_score + Math.log(probs[sampledIndex]), + token: BigInt(tokenId), + beam_idx, + }); + } + } else { + const log_probs = log_softmax(logits_data); + + // Find top 2*num_beams candidates per beam for efficiency + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const beam_candidates = []; + for (let v = 0; v < vocab_size; ++v) { + beam_candidates.push({ + score: beam_scores[flat_beam] + log_probs[v], + token: BigInt(v), + beam_idx, + }); + } + beam_candidates.sort((a, b) => 
b.score - a.score); + candidates.push(...beam_candidates.slice(0, 2 * num_beams)); + } + } + + // Select top 2*num_beams candidates globally for this batch item + candidates.sort((a, b) => b.score - a.score); + const top_candidates = candidates.slice(0, 2 * num_beams); + + for (const c of top_candidates) { + all_next_tokens.push(c.token); + all_next_indices.push(c.beam_idx); + all_next_scores.push(c.score); + } + } + + // Let beam_scorer process: route EOS to hypotheses, select continuing beams + ({ + next_beam_scores, + next_beam_tokens, + next_beam_indices, + } = beam_scorer.process( + all_input_ids, + beam_scores, + all_next_tokens, + all_next_indices, + all_next_scores, + )); + } + + // Reorder all_input_ids based on beam indices and append new tokens + const new_all_input_ids = []; + for (let i = 0; i < total_beams; ++i) { + new_all_input_ids.push([...all_input_ids[next_beam_indices[i]], next_beam_tokens[i]]); + generated_input_ids.push([next_beam_tokens[i]]); + } + all_input_ids.length = 0; + all_input_ids.push(...new_all_input_ids); + + // Update beam scores + for (let i = 0; i < total_beams; ++i) { + beam_scores[i] = next_beam_scores[i]; + } + + // Reorder KV cache to match beam reordering + model_inputs['past_key_values'] = await this._reorder_cache( + this.getPastKeyValues(outputs, model_inputs.past_key_values), + next_beam_indices, + ); + } else { + // Greedy / sample: existing behavior + for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) { + const logs = next_tokens_scores[batch_idx]; + + const sampledTokens = await sampler(logs); + for (const [newTokenId, logProb] of sampledTokens) { + const bigint = BigInt(newTokenId); + scores[batch_idx] += logProb; + all_input_ids[batch_idx].push(bigint); + generated_input_ids.push([bigint]); + break; + } } } + if (streamer) { streamer.put(generated_input_ids); } + // Check stopping conditions + if (is_beam_search) { + const done = is_group_beam_search + ? 
beam_scorers.every((scorer) => scorer.is_done) + : beam_scorer.is_done; + if (done) { + break; + } + } const stop = prepared_stopping_criteria(all_input_ids); if (stop.every((x) => x)) { break; } - model_inputs = this._update_model_kwargs_for_generation({ - generated_input_ids, - outputs, - model_inputs, - is_encoder_decoder, - }); + if (is_beam_search) { + // For beam search, we already updated past_key_values during reordering. + // Now update the remaining model inputs. + model_inputs['input_ids'] = new Tensor('int64', generated_input_ids.flat(), [generated_input_ids.length, 1]); + + if (!is_encoder_decoder) { + model_inputs.attention_mask = cat( + [model_inputs.attention_mask, ones([model_inputs.attention_mask.dims[0], 1])], + 1, + ); + } + + // Force recreate position_ids in next iteration + model_inputs['position_ids'] = null; + } else { + model_inputs = this._update_model_kwargs_for_generation({ + generated_input_ids, + outputs, + model_inputs, + is_encoder_decoder, + }); + } } if (streamer) { @@ -1187,18 +1524,74 @@ export class PreTrainedModel extends Callable { // Retrieve and dispose all final past key values (including encoder attentions) const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true); - // TODO: ensure all_input_ids is padded correctly... 
- const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]); + let sequences; + if (is_beam_search) { + /** @type {bigint[][]} */ + let best_sequences; + if (is_group_beam_search) { + const group_size = num_beams / num_beam_groups; + /** @type {{tokens: bigint[], score: number}[][]} */ + const grouped = Array.from({ length: numInputs }, () => []); + + for (let group_idx = 0; group_idx < num_beam_groups; ++group_idx) { + const group_offset = group_idx * group_size; + const group_all_input_ids = new Array(numInputs * group_size); + const group_beam_scores = new Array(numInputs * group_size); + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const local = batch_idx * group_size + beam_idx; + group_all_input_ids[local] = all_input_ids[global]; + group_beam_scores[local] = beam_scores[global]; + } + } + + const group_results = beam_scorers[group_idx].finalize_with_scores( + group_all_input_ids, + group_beam_scores, + ); + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + const offset = batch_idx * group_size; + for (let i = 0; i < group_size; ++i) { + grouped[batch_idx].push(group_results[offset + i]); + } + } + } + + best_sequences = []; + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + const sorted = grouped[batch_idx].sort((a, b) => b.score - a.score); + for (let i = 0; i < generation_config.num_return_sequences; ++i) { + best_sequences.push(sorted[i].tokens); + } + } + } else { + // Finalize beam search: select best hypotheses + best_sequences = beam_scorer.finalize(all_input_ids, beam_scores); + } + + // Pad sequences to equal length for tensor creation + const max_len = Math.max(...best_sequences.map(s => s.length)); + const pad_id = BigInt(generation_config.pad_token_id ?? 
0); + const padded = best_sequences.map(s => { + const p = [...s]; + while (p.length < max_len) p.push(pad_id); + return p; + }); + sequences = new Tensor('int64', padded.flat(), [padded.length, max_len]); + } else { + sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]); + } if (generation_config.return_dict_in_generate) { + const past_key_values_for_return = is_beam_search ? null : past_key_values; return { sequences, - past_key_values, + past_key_values: past_key_values_for_return, ...attentions, ...return_dict_items, - // TODO: - // scores, - // logits, }; } else { // Dispose all remaining tensors @@ -1254,6 +1647,36 @@ export class PreTrainedModel extends Callable { return pkvs; } + /** + * Reorder the past key values cache to match beam reordering. + * @param {Object} past_key_values The past key values object. + * @param {number[]} beam_indices Indices indicating which beam each new position came from. + * @returns {Promise} Reordered past key values. + */ + async _reorder_cache(past_key_values, beam_indices) { + if (!past_key_values) return past_key_values; + const reordered = Object.create(null); + for (const key in past_key_values) { + const tensor = past_key_values[key]; + const location = tensor.location; + if (key.includes('encoder')) { + // Encoder PKVs are shared across beams; pass through + reordered[key] = tensor; + } else { + if (tensor.location === 'cpu' || tensor.location === 'cpu-pinned') { + reordered[key] = index_select(tensor, beam_indices); + } else { + reordered[key] = await index_select_async(tensor, beam_indices); + } + // Dispose old tensor if it owns GPU resources + if (location === 'gpu-buffer' || location === 'texture' || location === 'ml-tensor') { + tensor.dispose(); + } + } + } + return reordered; + } + /** * Returns an object containing attentions from the given model output object. 
/**
 * Core implementation shared by `index_select` and `index_select_async`:
 * copies the selected rows of `data` (the flat backing buffer of `tensor`)
 * into a new CPU-backed tensor.
 * @param {Tensor} tensor The source tensor (used for its dims/type metadata).
 * @param {number[]} indices Array of row indices to select along dimension 0.
 * @param {*} data Flat backing buffer (TypedArray or plain Array).
 * @returns {Tensor} A new tensor of shape `[indices.length, ...tensor.dims.slice(1)]`.
 * @throws {RangeError} If any index is not a valid row index.
 */
function _index_select_from_data(tensor, indices, data) {
    const [batchSize, ...restDims] = tensor.dims;
    const rowSize = restDims.length > 0 ? restDims.reduce((a, b) => a * b, 1) : 1;
    const newBatchSize = indices.length;

    // Validate up front: TypedArray#subarray silently clamps out-of-range
    // offsets, which would otherwise produce silently-corrupted output.
    for (const idx of indices) {
        if (!Number.isInteger(idx) || idx < 0 || idx >= batchSize) {
            throw new RangeError(`index_select: index ${idx} is out of bounds for dimension 0 with size ${batchSize}.`);
        }
    }

    let out;
    if (ArrayBuffer.isView(data) && typeof data.subarray === 'function' && typeof data.set === 'function') {
        // TypedArray fast path: block-copy each selected row.
        // @ts-ignore
        out = new data.constructor(newBatchSize * rowSize);
        for (let i = 0; i < newBatchSize; ++i) {
            const srcOffset = indices[i] * rowSize;
            const dstOffset = i * rowSize;
            out.set(data.subarray(srcOffset, srcOffset + rowSize), dstOffset);
        }
    } else {
        // Plain-array fallback (element-wise copy).
        out = new Array(newBatchSize * rowSize);
        for (let i = 0; i < newBatchSize; ++i) {
            const srcOffset = indices[i] * rowSize;
            const dstOffset = i * rowSize;
            for (let j = 0; j < rowSize; ++j) {
                out[dstOffset + j] = data[srcOffset + j];
            }
        }
    }

    return new Tensor(tensor.type, out, [newBatchSize, ...restDims]);
}

/**
 * Select entries from a CPU tensor along dimension 0 using an index array.
 * Equivalent to `torch.index_select(tensor, 0, indices)`.
 * @param {Tensor} tensor The source tensor.
 * @param {number[]} indices Array of row indices to select.
 * @returns {Tensor} A new tensor with selected rows.
 * @throws {Error} If the tensor is not CPU-backed.
 * @throws {RangeError} If any index is out of bounds.
 */
export function index_select(tensor, indices) {
    if (tensor.location !== 'cpu' && tensor.location !== 'cpu-pinned') {
        throw new Error(
            `index_select only supports CPU tensors. Got location: ${tensor.location}. ` +
            'Use index_select_async to handle GPU tensors.',
        );
    }

    return _index_select_from_data(tensor, indices, tensor.data);
}

/**
 * Async version of index_select that supports GPU-backed tensors by downloading data to CPU.
 * @param {Tensor} tensor The source tensor.
 * @param {number[]} indices Array of row indices to select.
 * @returns {Promise<Tensor>} A new tensor with selected rows (CPU-backed).
 * @throws {Error} If a non-CPU tensor does not expose `getData()`.
 * @throws {RangeError} If any index is out of bounds.
 */
export async function index_select_async(tensor, indices) {
    if (tensor.location === 'cpu' || tensor.location === 'cpu-pinned') {
        return index_select(tensor, indices);
    }

    // Download GPU/ML data to CPU. This will materialize CPU data and may release GPU resources.
    const ort_tensor = tensor.ort_tensor;
    if (!ort_tensor?.getData) {
        throw new Error(`Tensor does not support getData() for location: ${tensor.location}.`);
    }

    const data = await ort_tensor.getData(true);
    return _index_select_from_data(tensor, indices, data);
}
num_beam_groups: 2, + diversity_penalty: 0.5, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] (1 return sequence) + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(2); + expect(outputs.dims[1]).toBeLessThanOrEqual(6); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "beam sampling (do_sample=true)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + do_sample: true, + top_k: 10, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] (1 return sequence) + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(2); + expect(outputs.dims[1]).toBeLessThanOrEqual(6); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "num_return_sequences > 1", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + num_return_sequences: 3, + max_new_tokens: 5, + }); + // Output should have shape [3, seq_len] + expect(outputs.dims[0]).toEqual(3); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "beam search produces different output from greedy", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const greedy_outputs = await model.generate({ + ...inputs, + max_new_tokens: 5, + }); + const beam_outputs = await model.generate({ + ...inputs, + num_beams: 4, + max_new_tokens: 5, + }); + // Both should succeed and produce valid tensors + expect(greedy_outputs.dims[0]).toEqual(1); + expect(beam_outputs.dims[0]).toEqual(1); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "beam search rejects classifier-free guidance", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + await expect( + model.generate({ + ...inputs, + num_beams: 4, + guidance_scale: 2, + max_new_tokens: 5, + }), + ).rejects.toThrow("Classifier-free guidance (guidance_scale > 1) is not supported with beam search."); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { 
+ await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe(`decoder-only`, () => { + const model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"; + const DUMMY_TEXT = "hello"; + + let model; + let tokenizer; + beforeAll(async () => { + model = await AutoModelForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS); + tokenizer = await AutoTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "basic beam search (num_beams=4)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] + expect(outputs.dims[0]).toEqual(1); + // seq_len = prompt_len + generated (up to max_new_tokens) + expect(outputs.dims[1]).toBeGreaterThanOrEqual(3); // at least BOS + prompt + 1 token + expect(outputs.dims[1]).toBeLessThanOrEqual(7); // at most BOS + prompt + 5 tokens + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "diverse beam search (num_beam_groups=2)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + num_beam_groups: 2, + diversity_penalty: 0.5, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(3); + expect(outputs.dims[1]).toBeLessThanOrEqual(7); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "beam sampling (do_sample=true)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + do_sample: true, + top_k: 10, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(3); + expect(outputs.dims[1]).toBeLessThanOrEqual(7); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "num_return_sequences > 1", + async () => { + const inputs = 
tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + num_return_sequences: 3, + max_new_tokens: 5, + }); + // Output should have shape [3, seq_len] + expect(outputs.dims[0]).toEqual(3); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("error cases", () => { + it("num_return_sequences > num_beams throws", () => { + expect(() => { + new BeamSearchScorer(1, 4, { num_return_sequences: 5 }); + }).toThrow("num_return_sequences (5) must be <= num_beams (4)"); + }); + }); +});