feat: add configurable residual processing to reduce peak VRAM usage #239
Changes from all commits: f902fd5, 4309c38, 30685de, 89d45e1, 67bf79b, a7c8f09, ba17216
```diff
@@ -632,6 +632,9 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor:
             max_new_tokens=1,
             output_hidden_states=True,
             return_dict_in_generate=True,
+            # KV cache is unnecessary here because we only need the hidden
+            # states for the first generated token.
+            use_cache=False,
         )
 
         # This cast is valid because GenerateDecoderOnlyOutput is the return type
```

**Owner** (on `use_cache=False`): This applies to the logprobs also.
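For context, a minimal standalone sketch of why the cache is dead weight in this call pattern; the model name and inputs here are illustrative, not from this repo:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any small causal LM works for demonstration purposes.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1,
        output_hidden_states=True,
        return_dict_in_generate=True,
        # Generating a single token means a single forward pass, so the
        # KV cache would be written once and never read; skip the allocation.
        use_cache=False,
    )

# hidden_states is a tuple (one entry per generated token) of tuples
# (one entry per layer); take the last layer of the only generation step.
last_hidden = outputs.hidden_states[0][-1]
print(last_hidden.shape)  # (batch, sequence, hidden_size)
```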
```diff
@@ -665,7 +668,11 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor:
             dim=2,
             keepdim=True,
         )
-        return torch.clamp(residuals, -thresholds, thresholds)
+        residuals = torch.clamp(residuals, -thresholds, thresholds)
+        if self.settings.offload_outputs_to_cpu:
+            residuals = residuals.cpu()
+            empty_cache()
+        return residuals
```
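The offload-and-release pattern added here generalizes; below is a hedged sketch of it as a standalone helper, assuming `empty_cache` is `torch.cuda.empty_cache` as the surrounding code suggests:

```python
import torch
from torch.cuda import empty_cache

def maybe_offload(t: torch.Tensor, offload_to_cpu: bool) -> torch.Tensor:
    """Move a result tensor to host memory and release the freed VRAM.

    Hypothetical helper mirroring the pattern in this diff; not part of
    the repo's API.
    """
    if offload_to_cpu and t.is_cuda:
        t = t.cpu()    # copy off the GPU; the CUDA buffer becomes garbage
        empty_cache()  # hand cached allocator blocks back to the driver
    return t
```

One design note: `empty_cache()` does not free live tensors; it only returns blocks the caching allocator holds in reserve, which is why the `.cpu()` copy has to happen first.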
```diff
@@ -677,6 +684,30 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
         return torch.cat(residuals, dim=0)
 
+    def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor:
+        if not prompts:
+            raise ValueError("prompts must not be empty")
+
+        running_sum = None
+        total_count = 0
+
+        for batch in batchify(prompts, self.settings.batch_size):
+            batch_residuals = self.get_residuals(batch)
+
+            # Accumulate in high precision on CPU to reduce peak VRAM usage.
+            batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu()
+
+            if running_sum is None:
+                running_sum = batch_sum
+            else:
+                running_sum += batch_sum
+
+            total_count += batch_residuals.shape[0]
+
+        assert running_sum is not None
+
+        return (running_sum / total_count).to(torch.float32)
+
     # We work with logprobs rather than probabilities for numerical stability
     # when computing the KL divergence.
     def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
```
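To sanity-check the accumulation logic, here is a self-contained sketch of the same streaming mean over fake residual batches (shapes and batch contents are made up):

```python
import torch

def streaming_mean(batches: list[torch.Tensor]) -> torch.Tensor:
    running_sum, total_count = None, 0
    for batch in batches:
        # Sum each batch in float64 on the CPU, as the diff does, so only
        # one batch of residuals is ever resident on the GPU at a time.
        batch_sum = batch.sum(dim=0, dtype=torch.float64).cpu()
        running_sum = batch_sum if running_sum is None else running_sum + batch_sum
        total_count += batch.shape[0]
    assert running_sum is not None
    return (running_sum / total_count).to(torch.float32)

batches = [torch.randn(4, 16, 768) for _ in range(5)]
expected = torch.cat(batches, dim=0).mean(dim=0)
assert torch.allclose(streaming_mean(batches), expected, atol=1e-5)
```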
```diff
@@ -687,6 +718,7 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
             max_new_tokens=1,
             output_scores=True,
             return_dict_in_generate=True,
+            use_cache=False,
         )
 
         # This cast is valid because GenerateDecoderOnlyOutput is the return type
```
```diff
@@ -698,7 +730,15 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
         logits = cast(tuple[FloatTensor], outputs.scores)[0]
 
         # The returned tensor has shape (prompt, token).
-        return F.log_softmax(logits, dim=-1)
+        logprobs = F.log_softmax(logits, dim=-1)
+
+        del outputs
+
+        if self.settings.offload_outputs_to_cpu:
+            logprobs = logprobs.cpu()
+            empty_cache()
+
+        return logprobs
 
     def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
         logprobs = []
```

**Owner** (on `logprobs = logprobs.cpu()`): Does it really make sense to offload logprobs? Typical vocabulary sizes are around 250k today, so even at 32 bits per logprob that is only about a megabyte per prompt, which is essentially a rounding error even on very small GPUs.

**Author:** Individually small, but across batches and repeated evaluations they add up and contribute to allocator pressure. Offloading keeps VRAM usage more predictable during longer runs.
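One way to ground this discussion empirically: PyTorch's CUDA allocator exposes peak-memory counters, so the cost of keeping logprobs resident can be measured rather than estimated. A hedged sketch follows (requires a CUDA device; `run_eval` is a stand-in for whatever evaluation loop is being profiled):

```python
import torch

def report_peak_vram(run_eval) -> None:
    # Reset the allocator's high-water mark, run the workload, report the peak.
    torch.cuda.reset_peak_memory_stats()
    run_eval()
    peak = torch.cuda.max_memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"peak allocated: {peak:.1f} MiB, reserved by allocator: {reserved:.1f} MiB")
```

Comparing the reported numbers with `offload_outputs_to_cpu` on and off would show whether the retained logprobs meaningfully move the peak.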