6 changes: 6 additions & 0 deletions config.default.toml
@@ -133,6 +133,12 @@ refusal_markers = [
# System prompt to use when prompting the model.
system_prompt = "You are a helpful assistant."

# Move intermediate analysis tensors (such as residuals and logprobs)
# to CPU memory as soon as possible.
# This lowers peak VRAM usage during residual analysis and evaluation,
# but may slightly reduce performance due to host/device transfers.
offload_outputs_to_cpu = true

# Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).
[good_prompts]
dataset = "mlabonne/harmless_alpaca"
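For a rough sense of the savings this flag targets, here is a hedged back-of-the-envelope sketch. The (prompt, layer, hidden) shape is inferred from the reductions later in this diff; the prompt count, layer count, hidden size, and fp16 storage are illustrative assumptions rather than values taken from this repository.

# Illustrative estimate of the residual tensor that offload_outputs_to_cpu can move off the GPU.
# Shape assumption: (num_prompts, num_layers, hidden_size); dtype assumption: fp16 (2 bytes/value).
def residuals_bytes(num_prompts: int, num_layers: int, hidden_size: int,
                    bytes_per_value: int = 2) -> int:
    return num_prompts * num_layers * hidden_size * bytes_per_value

# Example with assumed sizes: 400 prompts, 32 layers, hidden size 4096.
print(residuals_bytes(400, 32, 4096) / 2**20, "MiB")  # 100.0 MiB that can live in host RAM instead
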
8 changes: 8 additions & 0 deletions src/heretic/config.py
@@ -346,6 +346,14 @@ class Settings(BaseSettings):
description="System prompt to use when prompting the model.",
)

offload_outputs_to_cpu: bool = Field(
default=True,
description=(
"Whether to move intermediate analysis tensors (such as residuals and logprobs) "
"to CPU memory as soon as possible to reduce peak VRAM usage."
),
)

good_prompts: DatasetSpecification = Field(
default=DatasetSpecification(
dataset="mlabonne/harmless_alpaca",
40 changes: 26 additions & 14 deletions src/heretic/main.py
@@ -445,13 +445,33 @@ def run():

print()
print("Calculating per-layer refusal directions...")
print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)

good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals

good_residuals = None
bad_residuals = None

if needs_full_residuals:
print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)

good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)

analyzer = Analyzer(settings, model, good_residuals, bad_residuals)

if settings.print_residual_geometry:
analyzer.print_residual_geometry()

if settings.plot_residuals:
analyzer.plot_residuals()
else:
print("* Obtaining residual mean for good prompts...")
good_means = model.get_residuals_mean(good_prompts)
print("* Obtaining residual mean for bad prompts...")
bad_means = model.get_residuals_mean(bad_prompts)

refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)

@@ -466,14 +486,6 @@
)
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)

analyzer = Analyzer(settings, model, good_residuals, bad_residuals)

if settings.print_residual_geometry:
analyzer.print_residual_geometry()

if settings.plot_residuals:
analyzer.plot_residuals()

# We don't need the residuals after computing refusal directions.
del good_residuals, bad_residuals, analyzer
empty_cache()
44 changes: 42 additions & 2 deletions src/heretic/model.py
@@ -632,6 +632,9 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor:
max_new_tokens=1,
output_hidden_states=True,
return_dict_in_generate=True,
# KV cache is unnecessary here because we only need the hidden states
# for the first generated token.
use_cache=False,
Owner:
This applies to the logprobs also.

)

# This cast is valid because GenerateDecoderOnlyOutput is the return type
@@ -665,7 +668,11 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor:
dim=2,
keepdim=True,
)
return torch.clamp(residuals, -thresholds, thresholds)
residuals = torch.clamp(residuals, -thresholds, thresholds)

if self.settings.offload_outputs_to_cpu:
residuals = residuals.cpu()
empty_cache()

return residuals

@@ -677,6 +684,30 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:

return torch.cat(residuals, dim=0)

def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor:
if not prompts:
raise ValueError("prompts must not be empty")

running_sum = None
total_count = 0

for batch in batchify(prompts, self.settings.batch_size):
batch_residuals = self.get_residuals(batch)

# Accumulate in high precision on CPU to reduce peak VRAM usage.
batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu()

if running_sum is None:
running_sum = batch_sum
else:
running_sum += batch_sum

total_count += batch_residuals.shape[0]

assert running_sum is not None

return (running_sum / total_count).to(torch.float32)
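
A small sanity sketch of the accumulation above, using synthetic shapes that are not tied to this repository's data: summing each batch in float64 and dividing by the total prompt count matches the mean over the concatenated residuals, even when batch sizes differ.

import torch

# Synthetic (prompt, layer, hidden) batches of unequal size.
batches = [torch.randn(3, 4, 8), torch.randn(5, 4, 8)]

running_sum = torch.zeros(4, 8, dtype=torch.float64)
total_count = 0
for batch in batches:
    running_sum += batch.sum(dim=0, dtype=torch.float64)
    total_count += batch.shape[0]

streamed_mean = (running_sum / total_count).to(torch.float32)
reference_mean = torch.cat(batches, dim=0).mean(dim=0)
assert torch.allclose(streamed_mean, reference_mean, atol=1e-6)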

# We work with logprobs rather than probabilities for numerical stability
# when computing the KL divergence.
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
@@ -687,6 +718,7 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
max_new_tokens=1,
output_scores=True,
return_dict_in_generate=True,
use_cache=False,
)

# This cast is valid because GenerateDecoderOnlyOutput is the return type
@@ -698,7 +730,15 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
logits = cast(tuple[FloatTensor], outputs.scores)[0]

# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
logprobs = F.log_softmax(logits, dim=-1)

del outputs

if self.settings.offload_outputs_to_cpu:
logprobs = logprobs.cpu()
Owner:
Does it really make sense to offload logprobs?

Typical vocabulary sizes are around 250k today, so even at 32 bits per logprob, that's just 8 Megabytes. Which is essentially a rounding error even on very small GPUs.

Contributor Author:
Individually small, but across batches and repeated evaluations they can add up and contribute to allocator pressure. Offloading keeps VRAM usage more predictable during longer runs.
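
A quick check of the sizes under discussion (the vocabulary size comes from the comment above; the batch of 8 prompts is an illustrative assumption, not a value from this repository):

vocab_size = 250_000      # vocabulary size mentioned above
bytes_per_value = 4       # 32-bit logprobs

per_prompt = vocab_size * bytes_per_value   # 1,000,000 bytes, about 1 MB per prompt
per_batch = per_prompt * 8                  # about 8 MB for an assumed batch of 8 prompts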

empty_cache()

return logprobs

def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
logprobs = []