From 3ed5a89c5b4b2b76aeb4bfbf1a917560c11d2b5b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 11 Apr 2024 13:23:19 -0400 Subject: [PATCH 1/9] Consolidate best guess and correct tokens --- fms_extras/utils/generation.py | 45 +++++++++------------------------- 1 file changed, 12 insertions(+), 33 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index a31c125..0d1b428 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -218,8 +218,8 @@ def __get_best_candidates( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Find the candidates with the best speculator predictions and get the indices of the - best candidates, the number of tokens correct, and the base model output values - for that candidate (tokens and embeddings) + best candidates, the number of tokens correct, and the base model outputs for those + candidates and values (tokens and embeddings). Args: input_ids: torch.Tensor @@ -230,10 +230,10 @@ def __get_best_candidates( the output embeddings for the best guesses Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - a tensor of the next tokens per best guess (b x 1+h), a tensor of the embeds - per best guess (b x 1+h x d), a tensor of the number of correct tokens per - best guess (b), and a tensor containing the best guess amongst all candidates + Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor] + a list of tensors of the correct tokens per best guess ([b] x n<=1+h), a tensor of the + last correct embed per best guess (b x 1 x d), a tensor of the number of correct tokens + per best guess (b), and a tensor containing the best guess index amongst all candidates per sequence (b) """ batch_size, num_candidates_per_sequence, decode_seq_length = input_ids.shape @@ -254,36 +254,18 @@ def __get_best_candidates( ).squeeze( 1 ) # b 1+h d - return next_vals, embeds, n_correct, best_guess - -def __get_correct_tokens( - next_vals: torch.Tensor, n_correct: torch.Tensor, embeds: torch.Tensor -) -> Tuple[List[torch.Tensor], torch.Tensor]: - """ - extract the correct tokens and the last correct embedding from each candidate, to - be used to start the next set of speculative candidates - - Args: - next_vals: torch.Tensor - a tensor of the next tokens per best guess - n_correct: torch.Tensor - a tensor of the number of correct tokens per best guess - embeds: torch.Tensor - a tensor of the embeds per best guess - - Returns: - Tuple[List[torch.Tensor], torch.Tensor] - a list of tensor of the correct tokens, and a tensor of the correct embeddings - """ + # Remove any wrong speculator tokens from best candidate next_vals_split = list(next_vals) next_vals_split = [ next_vals_split[i][: n_correct[i] + 1] for i in range(len(next_vals_split)) ] # [b] h' + + # Get last correct embedding for use in next round of predictions embeds = embeds.gather( 1, n_correct.view(-1, 1, 1).expand(-1, -1, embeds.size(2)) - ) # Grab last correct embed - return next_vals_split, embeds + ) # b 1 d + return next_vals_split, embeds, n_correct, best_guess def __prune_candidates( @@ -315,7 +297,7 @@ def __prune_candidates( correct embeddings, and a list of the best candidate id per sequence """ # get the best candidates - next_vals, embeds, n_correct, best_guess = __get_best_candidates( + next_vals_split, embeds, n_correct, best_guess = __get_best_candidates( input_ids, next_vals, embeds ) @@ -327,9 +309,6 @@ def __prune_candidates( n_correct, input_ids.size(2), ) - - # Remove any wrong 
speculator tokens from best candidate - next_vals_split, embeds = __get_correct_tokens(next_vals, n_correct, embeds) return next_vals_split, embeds, parent_sequence_ids From 1ac6623adf270212bb25dbbcca44fd122d234409 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 11 Apr 2024 14:52:29 -0400 Subject: [PATCH 2/9] Add candidate-consistent next_val sampling --- fms_extras/utils/generation.py | 100 ++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index 0d1b428..da6296f 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -218,8 +218,8 @@ def __get_best_candidates( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Find the candidates with the best speculator predictions and get the indices of the - best candidates, the number of tokens correct, and the base model outputs for those - candidates and values (tokens and embeddings). + best candidates, the number of tokens correct, and the base model output values + for that candidate (tokens and embeddings) Args: input_ids: torch.Tensor @@ -230,10 +230,10 @@ def __get_best_candidates( the output embeddings for the best guesses Returns: - Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor] - a list of tensors of the correct tokens per best guess ([b] x n<=1+h), a tensor of the - last correct embed per best guess (b x 1 x d), a tensor of the number of correct tokens - per best guess (b), and a tensor containing the best guess index amongst all candidates + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + a tensor of the next tokens per best guess (b x 1+h), a tensor of the embeds + per best guess (b x 1+h x d), a tensor of the number of correct tokens per + best guess (b), and a tensor containing the best guess amongst all candidates per sequence (b) """ batch_size, num_candidates_per_sequence, decode_seq_length = input_ids.shape @@ -254,18 +254,36 @@ def __get_best_candidates( ).squeeze( 1 ) # b 1+h d + return next_vals, embeds, n_correct, best_guess - # Remove any wrong speculator tokens from best candidate + +def __get_correct_tokens( + next_vals: torch.Tensor, n_correct: torch.Tensor, embeds: torch.Tensor +) -> Tuple[List[torch.Tensor], torch.Tensor]: + """ + extract the correct tokens and the last correct embedding from each candidate, to + be used to start the next set of speculative candidates + + Args: + next_vals: torch.Tensor + a tensor of the next tokens per best guess + n_correct: torch.Tensor + a tensor of the number of correct tokens per best guess + embeds: torch.Tensor + a tensor of the embeds per best guess + + Returns: + Tuple[List[torch.Tensor], torch.Tensor] + a list of tensor of the correct tokens, and a tensor of the correct embeddings + """ next_vals_split = list(next_vals) next_vals_split = [ next_vals_split[i][: n_correct[i] + 1] for i in range(len(next_vals_split)) ] # [b] h' - - # Get last correct embedding for use in next round of predictions embeds = embeds.gather( 1, n_correct.view(-1, 1, 1).expand(-1, -1, embeds.size(2)) - ) # b 1 d - return next_vals_split, embeds, n_correct, best_guess + ) # Grab last correct embed + return next_vals_split, embeds def __prune_candidates( @@ -297,7 +315,7 @@ def __prune_candidates( correct embeddings, and a list of the best candidate id per sequence """ # get the best candidates - next_vals_split, embeds, n_correct, best_guess = __get_best_candidates( + next_vals, embeds, n_correct, 
best_guess = __get_best_candidates( input_ids, next_vals, embeds ) @@ -309,6 +327,9 @@ def __prune_candidates( n_correct, input_ids.size(2), ) + + # Remove any wrong speculator tokens from best candidate + next_vals_split, embeds = __get_correct_tokens(next_vals, n_correct, embeds) return next_vals_split, embeds, parent_sequence_ids @@ -336,24 +357,45 @@ def __extract_decode_output( Returns: Tuple[torch.Tensor, torch.Tensor] - the un-flattened next tokens per candidate per sequence, and the - un-flattened output embedding vector + the un-flattened logit scores per token per candidate per sequence, + and the un-flattened output embedding vectors """ logits, _, embeds = model_output # 1 n' v, 1 n' d OR bk 1+h v, bk 1+h d - next_vals = torch.argmax(logits, dim=-1) # 1 n' OR bk 1+h # If we used batch flattening / tree attention, unflatten the outputs if unflat_indices is not None: - next_vals = apply_index_map(next_vals[0], unflat_indices) # b k 1+h + logits = apply_index_map(logits[0], unflat_indices) # b k 1+h v embeds = apply_index_map(embeds[0], unflat_indices) # b k 1+h d else: - next_vals = next_vals.view( - batch_size, n_candidates, decode_seq_length - ) # b k 1+h + logits = logits.view( + batch_size, n_candidates, decode_seq_length, logits.size(2) + ) # b k 1+h v embeds = embeds.view( batch_size, n_candidates, decode_seq_length, embeds.size(2) ) # b k 1+h d - return next_vals, embeds + return logits, embeds + + +def __generate_targets( + logits: torch.Tensor, temperature: int = 1, top_k: int = 5, do_sample: bool = False +) -> torch.Tensor: + if not do_sample: + return logits.argmax(-1) + + # Get sample distributions + logits = logits / temperature + v, _ = logits.topk(logits, top_k) + logits[logits < v[:, :, :, [-1]]] = -float("inf") + probs = logits.softmax(-1) # b k 1+h v + + # Sample candidate-consistent ground truths + key = torch.rand(1, 1, logits.size(2), 1, device=probs.device) + a = probs.cumsum(3).sub(key).sign() # All intervals around/above key + b = ( + probs.flip(3).cumsum(3).flip(3).sub(1 - key).sign() + ) # All intervals around/below key + choice = a.add(b) # The interval around key + return choice.argmax(3) def speculative_generate( @@ -369,6 +411,9 @@ def speculative_generate( decode_model: Optional[Union[Callable, torch.nn.Module]] = None, # todo: This is a WIP to enable cudagraphs, currently its only for batch_size=1 cudagraphs: bool = False, + do_sample: bool = False, + temperature: float = 1.0, + top_k: int = 5, ): """ A reference implementation of speculative decoding generation. @@ -412,6 +457,15 @@ def speculative_generate( if True, cudagraphs is used and all metadata will be padded, otherwise metadata will not be padded unless required. Note: This is a WIP and only works for batch_size=1 + do_sample: bool + non-deterministic, multinomial output sampling. False for greedy. + Provides output diversity, but lowers speculative decoding speedup. + temperature: float + temperature of softmax when sampling. Lowering this should provide + better speculative decoding speedup when do_sample=True. + top_k: int + only search among top k tokens. Lowering this should provide + better speculative decoding speedup when do_sample=True. Returns: result: List of id tensors, possibly different lengths if batching. n_steps: Number of foward passes used to generate provided tokens. 
@@ -497,10 +551,14 @@ def speculative_generate( use_cache=True, ) # 1 n' v OR bk 1+h v - next_vals, embeds = __extract_decode_output( + logits, embeds = __extract_decode_output( output, unflat_indices, bsize, n_candidates, inp_len ) + next_vals = __generate_targets( + logits, temperature=temperature, top_k=top_k, do_sample=do_sample + ) + next_vals_list, embeds, parent_sequence_ids = __prune_candidates( input_ids, next_vals, embeds, kv_cache_manager, child_sequence_ids_list ) From be7e0cc8717e59f5e7a9e2e47af4d5291e55cb35 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Thu, 11 Apr 2024 16:15:57 -0400 Subject: [PATCH 3/9] fixed issue with top_k in sampling; fixed type hint for temperature; added args to script for sampling --- fms_extras/utils/generation.py | 4 ++-- scripts/paged_speculative_inference.py | 24 +++++++++++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index da6296f..98c8d06 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -377,14 +377,14 @@ def __extract_decode_output( def __generate_targets( - logits: torch.Tensor, temperature: int = 1, top_k: int = 5, do_sample: bool = False + logits: torch.Tensor, temperature: float = 1.0, top_k: int = 5, do_sample: bool = False ) -> torch.Tensor: if not do_sample: return logits.argmax(-1) # Get sample distributions logits = logits / temperature - v, _ = logits.topk(logits, top_k) + v, _ = logits.topk(top_k) logits[logits < v[:, :, :, [-1]]] = -float("inf") probs = logits.softmax(-1) # b k 1+h v diff --git a/scripts/paged_speculative_inference.py b/scripts/paged_speculative_inference.py index befd820..d2f24f3 100644 --- a/scripts/paged_speculative_inference.py +++ b/scripts/paged_speculative_inference.py @@ -94,6 +94,23 @@ action="store_true", help="use a batch of prompts as input (note this is still wip for reduce-overhead=True)", ) +parser.add_argument( + "--top_k", + type=int, + default=10, + help="sample only among top k most confident tokens (ignored if do_sample=False)", +) +parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="degree of smoothing for sampling distribution (ignored if do_sample=False)", +) +parser.add_argument( + "--do_sample", + action="store_true", + help="enable non-greedy generation" +) args = parser.parse_args() @@ -232,6 +249,9 @@ def infer(ids, warmup): # todo: we can only reduce-overhead for now when batch size is 1 flattening=not (args.compile and compile_mode == "reduce-overhead"), cudagraphs=cudagraphs, + do_sample=args.do_sample, + temperature=args.temperature, + top_k=args.top_k ) else: result, n_steps, ttft, generated_token_time_out = paged_generate( @@ -240,9 +260,11 @@ def infer(ids, warmup): kv_cache_manager, max_new_tokens=100, max_seq_len=model.config.max_expected_seq_len, - do_sample=False, decode_model=decode_model, cudagraphs=cudagraphs, + do_sample=args.do_sample, + temperature=args.temperature, + top_k=args.top_k ) if not warmup: total_tokens = 0 From 96fdc7d419ba0a64928b0693047e702b1a9fdb52 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 11 Apr 2024 18:51:36 -0400 Subject: [PATCH 4/9] Add docstring for generate_targets --- fms_extras/utils/generation.py | 33 +++++++++++++++++++++++++- scripts/paged_speculative_inference.py | 8 +++---- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index 98c8d06..0b0a466 100644 --- a/fms_extras/utils/generation.py 
+++ b/fms_extras/utils/generation.py @@ -377,8 +377,39 @@ def __extract_decode_output( def __generate_targets( - logits: torch.Tensor, temperature: float = 1.0, top_k: int = 5, do_sample: bool = False + logits: torch.Tensor, + temperature: float = 1.0, + top_k: int = 5, + do_sample: bool = False, ) -> torch.Tensor: + """ + Extracts ground-truth tokens from a set of logits. If performing greedy decoding, + simply returns the most confident tokens. Otherwise, implements consistent multinomial + sampling - two identical distributions will always produce the same (randomized) sample. + Thus by induction, two candidates with identical prefixes will receive the same ground + truth sample up to the point their inputs diverge. This allows us to ensure that at least + one candidate will be accepted, so long as the candidate set covers the top_k options. + + For example, if the base model predicts tokens A and B with equal 50% probability, and the + speculator produces one candidate with A and another with B, with independent sampling there's + a 25% chance of rejecting both, even though one must be correct. Consistent sampling allows us + to avoid this. + + Args: + logits: torch.Tensor + Probability logits for a set of candidate sequences. Expects size + bsize x n_candidates x seq_len x vocab_size + temperature: float + Degree of smoothing on softmax sampling distribution + top_k: int + Sample only among the top_k most confident tokens + do_sample: bool + Enable non-greedy decoding with consistent sampling + + Returns: + torch.Tensor + Tensor of chosen token values for each sequence + """ if not do_sample: return logits.argmax(-1) diff --git a/scripts/paged_speculative_inference.py b/scripts/paged_speculative_inference.py index d2f24f3..e4fce56 100644 --- a/scripts/paged_speculative_inference.py +++ b/scripts/paged_speculative_inference.py @@ -107,9 +107,7 @@ help="degree of smoothing for sampling distribution (ignored if do_sample=False)", ) parser.add_argument( - "--do_sample", - action="store_true", - help="enable non-greedy generation" + "--do_sample", action="store_true", help="enable non-greedy generation" ) args = parser.parse_args() @@ -251,7 +249,7 @@ def infer(ids, warmup): cudagraphs=cudagraphs, do_sample=args.do_sample, temperature=args.temperature, - top_k=args.top_k + top_k=args.top_k, ) else: result, n_steps, ttft, generated_token_time_out = paged_generate( @@ -264,7 +262,7 @@ def infer(ids, warmup): cudagraphs=cudagraphs, do_sample=args.do_sample, temperature=args.temperature, - top_k=args.top_k + top_k=args.top_k, ) if not warmup: total_tokens = 0 From ea646a2cc1446b4feaa8d257b0a4b4eabe900f7d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 11 Apr 2024 18:57:12 -0400 Subject: [PATCH 5/9] Simpler/faster consistent sampling --- fms_extras/utils/generation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index 0b0a466..9c4ce52 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -419,14 +419,14 @@ def __generate_targets( logits[logits < v[:, :, :, [-1]]] = -float("inf") probs = logits.softmax(-1) # b k 1+h v - # Sample candidate-consistent ground truths + # Sample candidate-consistent ground truths: partition number line in [0,1] + # according to given multinomial distribution. Pick a random location + # on that line, return interval containing that location. 
key = torch.rand(1, 1, logits.size(2), 1, device=probs.device) - a = probs.cumsum(3).sub(key).sign() # All intervals around/above key - b = ( - probs.flip(3).cumsum(3).flip(3).sub(1 - key).sign() - ) # All intervals around/below key - choice = a.add(b) # The interval around key - return choice.argmax(3) + a = ( + probs.cumsum(3).sub(key).sign() + ) # Sign flips on probability interval containing key + return a.sub(1).div(-2).sum(3) # Get index of sign-flip def speculative_generate( From f6d96f60eb6221ffdcab86ffe0d2cd99d37a333f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 12 Apr 2024 12:00:42 -0400 Subject: [PATCH 6/9] Vectorized do_sample in generate_targets --- fms_extras/utils/generation.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index 9c4ce52..c9b90ff 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -378,9 +378,9 @@ def __extract_decode_output( def __generate_targets( logits: torch.Tensor, + do_sample: torch.Tensor, temperature: float = 1.0, top_k: int = 5, - do_sample: bool = False, ) -> torch.Tensor: """ Extracts ground-truth tokens from a set of logits. If performing greedy decoding, @@ -399,19 +399,18 @@ def __generate_targets( logits: torch.Tensor Probability logits for a set of candidate sequences. Expects size bsize x n_candidates x seq_len x vocab_size + do_sample: torch.Tensor + A tensor of booleans enabling/disabling non-greedy decoding with consistent + sampling, for each of bsize input sequences temperature: float Degree of smoothing on softmax sampling distribution top_k: int Sample only among the top_k most confident tokens - do_sample: bool - Enable non-greedy decoding with consistent sampling Returns: torch.Tensor Tensor of chosen token values for each sequence """ - if not do_sample: - return logits.argmax(-1) # Get sample distributions logits = logits / temperature @@ -426,7 +425,12 @@ def __generate_targets( a = ( probs.cumsum(3).sub(key).sign() ) # Sign flips on probability interval containing key - return a.sub(1).div(-2).sum(3) # Get index of sign-flip + samples = a.sub(1).div(-2).sum(3) # Get index of sign-flip + + # Composite greedy and non greedy outputs + greedy = logits.argmax(-1) + mask = do_sample[:, None, None].int() + return samples * mask + (1 - mask) * greedy def speculative_generate( @@ -586,8 +590,12 @@ def speculative_generate( output, unflat_indices, bsize, n_candidates, inp_len ) + if do_sample: + do_sample_vector = torch.ones(bsize, device=logits.device) + else: + do_sample_vector = torch.zeros(bsize, device=logits.device) next_vals = __generate_targets( - logits, temperature=temperature, top_k=top_k, do_sample=do_sample + logits, do_sample_vector, temperature=temperature, top_k=top_k ) next_vals_list, embeds, parent_sequence_ids = __prune_candidates( From 65525167004e4fa05e0979525f924922d19e4940 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 13 May 2024 16:36:08 -0400 Subject: [PATCH 7/9] Cleaner sample/nosample masking --- fms_extras/utils/generation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index c9b90ff..a6a7a1b 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -429,8 +429,7 @@ def __generate_targets( # Composite greedy and non greedy outputs greedy = logits.argmax(-1) - mask = do_sample[:, None, None].int() - return samples * mask + (1 - mask) * 
greedy + return torch.where(do_sample[:, None, None], samples, greedy) def speculative_generate( From 2f47c7034f45758c47f00997f75fb16da2811e72 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 17 Jun 2024 14:18:17 -0400 Subject: [PATCH 8/9] Fix do_sample typing issues --- fms_extras/utils/generation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index a6a7a1b..51cfcfb 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -429,6 +429,7 @@ def __generate_targets( # Composite greedy and non greedy outputs greedy = logits.argmax(-1) + samples = samples.to(dtype=greedy.dtype) return torch.where(do_sample[:, None, None], samples, greedy) @@ -590,9 +591,9 @@ def speculative_generate( ) if do_sample: - do_sample_vector = torch.ones(bsize, device=logits.device) + do_sample_vector = torch.ones(bsize, device=logits.device, dtype=torch.bool) else: - do_sample_vector = torch.zeros(bsize, device=logits.device) + do_sample_vector = torch.zeros(bsize, device=logits.device, dtype=torch.bool) next_vals = __generate_targets( logits, do_sample_vector, temperature=temperature, top_k=top_k ) From bc75432828e2fd582850b5c415df012c95ef04d9 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 17 Jun 2024 14:19:24 -0400 Subject: [PATCH 9/9] Linting --- fms_extras/utils/generation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py index 51cfcfb..6e6e218 100644 --- a/fms_extras/utils/generation.py +++ b/fms_extras/utils/generation.py @@ -593,7 +593,9 @@ def speculative_generate( if do_sample: do_sample_vector = torch.ones(bsize, device=logits.device, dtype=torch.bool) else: - do_sample_vector = torch.zeros(bsize, device=logits.device, dtype=torch.bool) + do_sample_vector = torch.zeros( + bsize, device=logits.device, dtype=torch.bool + ) next_vals = __generate_targets( logits, do_sample_vector, temperature=temperature, top_k=top_k )