Closed
Changes from all commits · 26 commits
a3fd485
add file settings.py to store settings
thiswillbeyourgithub Dec 4, 2024
30b8923
add a utils.py file that contains a make_dataset function
thiswillbeyourgithub Dec 4, 2024
f65111d
move DatasetEntry to utils.py
thiswillbeyourgithub Dec 4, 2024
962365a
feat: give more controls to how the user can select the layers to modify
thiswillbeyourgithub Dec 4, 2024
06a77fe
feat: add joblib memory for caching model activations
thiswillbeyourgithub Dec 4, 2024
f129548
fix: numpy v2 compatibility
thiswillbeyourgithub Dec 4, 2024
ee358c0
feat: add arg norm_type when training that allows setting l1 or l2 (o…
thiswillbeyourgithub Dec 4, 2024
a63a93b
feat: add function to accept chat template as inputs + autocorrect
thiswillbeyourgithub Dec 4, 2024
07c5b1f
fix: forgot an import
thiswillbeyourgithub Dec 4, 2024
5b96bab
minor: add tqdm for getting activations and applying the new directions
thiswillbeyourgithub Dec 4, 2024
f238580
perf: potentially faster one liner to tokenize
thiswillbeyourgithub Dec 4, 2024
af8dd6d
feat: use flag LOW_MEMORY to reduce the amount of memory needed when …
thiswillbeyourgithub Dec 4, 2024
766ac7a
fix: imports
thiswillbeyourgithub Dec 4, 2024
f6efc48
update default model in example from mistral 7B v0.1 to v0.3
thiswillbeyourgithub Dec 4, 2024
6123b86
docs: add more details on how to load the model, including quantizati…
thiswillbeyourgithub Dec 4, 2024
50bdc42
docs: use chat messages in example
thiswillbeyourgithub Dec 4, 2024
d5c3baa
docs: in example, show how to login for models that require it
thiswillbeyourgithub Dec 4, 2024
f9c90df
docs: add missing declaration of tokenizer
thiswillbeyourgithub Dec 4, 2024
63baad5
fix: in example, move the tokens directly to the correct device
thiswillbeyourgithub Dec 4, 2024
d523235
minor: changed default in the example
thiswillbeyourgithub Dec 4, 2024
9136c5b
docs: add link to github issue about OOM for gguf files
thiswillbeyourgithub Dec 4, 2024
0e07984
docs: add more examples of models
thiswillbeyourgithub Dec 4, 2024
6f25c17
fix: autocorrect chat templates
thiswillbeyourgithub Dec 4, 2024
ef30c30
docs: mention the new way to specify the layers
thiswillbeyourgithub Dec 4, 2024
ec3e28c
fix: template function for non chats
thiswillbeyourgithub Dec 10, 2024
21652a8
docs: mention in readme that qwen2.5 7B works out of the box
thiswillbeyourgithub Dec 11, 2024
115 changes: 101 additions & 14 deletions README.md
@@ -14,43 +14,129 @@ _For a full example, see the notebooks folder or [the blog post](https://vgel.me
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# from transformers import BitsAndBytesConfig


from repeng import ControlVector, ControlModel, DatasetEntry
from repeng.utils import make_dataset, autocorrect_chat_templates

# # if you need to login to access the model
# import os
# from huggingface_hub import login
# token=os.environ["HUGGINGFACE_API_TOKEN"]
# assert token
# login(token=token)

# load and wrap model
# model_name = "mistralai/Mistral-7B-Instruct-v0.1"
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# model_name = "mistralai/Mistral-Nemo-Base-2407"
# model_name = "mistralai/Mistral-Nemo-Instruct-2407"
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "meta-llama/Llama-3.2-3B-Instruct"
# model_name = "Qwen/Qwen2.5-7B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
model_name,
# device_map="auto", # 'cuda' means use 1 GPU, 'auto' means use all VRAM available including on multiple GPUs
# low_cpu_mem_usage=True, # True to reduce the cpu RAM needed to load the model to VRAM. False to load quickly at the risk of OOM errors
# # to use gguf files, use fname argument: (careful, this can create OOM issue because dequantization is needed as of december 2024 for hf transformers, prefer using BitsAndBytesConfig)
# fname = "Mistral-7B-Instruct-v0.3.Q2_K.gguf"
# # to use quantization:
# quantization_config=BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.bfloat16,
# bnb_4bit_use_double_quant=True,
# ),
# # don't load the model in full precision:
# torch_dtype=torch.float16,
)

# wrap the model to give us control
model = ControlModel(
model,
# layer_ids="all", # all layers
# layer_ids="middle", # from 25% to 75% depth
# layer_ids="only_middle", # only the single layer at the middle
layer_ids="0.1-0.3", # from 10% depth to 30% depth
# layer_ids="0.5-0.9", # 50% to 90%
# layer_ids="0.1-0.5", # 10% to 50%
# layer_ids=list(range(-5, -18, -1)) # specific layer numbers
)

tokenizer = AutoTokenizer.from_pretrained(
model_name,
# gguf_file="Mistral-7B-Instruct-v0.3.Q2_K.gguf", # needed if the model was loaded from a gguf file
)

# generate a dataset with closely-opposite paired statements
trippy_dataset = make_dataset(
"Act as if you're extremely {persona}.",
["high on psychedelic drugs"],
["sober from psychedelic drugs"],
truncated_output_suffixes,
# you can use either chat templates...
# template=[
# {"role": "system", "content": "You talk like you are {persona}."},
# {"role": "user", "content": "{suffix}"},
# ],
# ...or strings directly:
template="Act as if you're {persona}. Someone comes at you and says '{suffix}'.",
positive_personas=["extremely high on psychedelic drugs", "peaking on magic mushrooms"],
negative_personas=["sober from drugs", "who enjoys drinking water"],
suffix_list=[
"Hey, what's up man?",
"Hey, what's up girl?",
"Welcome Mr Musk, come this way.",
"How have you been feeling lately with the medications?",
],
)

# train the vector—takes less than a minute!
trippy_vector = ControlVector.train(model, tokenizer, trippy_dataset)

# Now set up the scenario for the generation we want to steer.
# Either as chat messages...
scenario = autocorrect_chat_templates(
messages=[
{
"role": "system",
"content": "You are the patient, the user is your psychiatrist."
},
{
"role": "user",
"content": "Now let's talk about your mood. How do you feel?",
},
{
"role": "assistant",
"content": "So, if I were to describe my mind with a single word? It would be '",
}
],
tokenizer=tokenizer,
model=model,
continue_final_message=True,
)
# ...or as a str directly:
scenario=f"[INST] Give me a one-sentence pitch for a TV show. [/INST]",

# set the control strength and let inference rip!
for strength in (-2.2, 1, 2.2):
print(f"strength={strength}")
model.set_control(trippy_vector, strength)
out = model.generate(
**tokenizer(
scenario,
return_tensors="pt"
).to(model.device),
do_sample=False,
# temperature=1.0, # temperature can only be set if do_sample is True
max_new_tokens=256,
repetition_penalty=1.1,
use_cache=True, # defaults to True anyway
)
print(tokenizer.decode(out.squeeze()).strip())
# print(tokenizer.decode(out.squeeze(), skip_special_tokens=False).strip()) # if you want to display the special tokens
print()
```

@@ -69,6 +155,7 @@ For a more detailed explanation of how the library works and what it can do, see

* For a list of changes by version, see the [CHANGELOG](https://github.com/vgel/repeng/blob/main/CHANGELOG).
* For quantized use, you may be interested in [llama.cpp#5970](https://github.com/ggerganov/llama.cpp/pull/5970)—after training a vector with `repeng`, export it by calling `vector.export_gguf(filename)` and then use it in `llama.cpp` with any quant! (A short sketch follows this list.)
* Loading gguf files directly into `transformers` can run into OOM errors; see [this GitHub issue](https://github.com/huggingface/transformers/issues/34417) for more.
* Vector training *currently does not work* with MoE models (such as Mixtral). (This is theoretically fixable with some work, let me know if you're interested.)
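
For example, a minimal export sketch, assuming `model`, `tokenizer`, and `trippy_dataset` are set up as in the example above (the output filename is illustrative):

```python
from repeng import ControlVector

# assumes `model`, `tokenizer`, and `trippy_dataset` from the example above
trippy_vector = ControlVector.train(model, tokenizer, trippy_dataset)

# write the trained vector to disk in gguf format for llama.cpp (illustrative filename)
trippy_vector.export_gguf("trippy_vector.gguf")
```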

## Notice
4 changes: 3 additions & 1 deletion repeng/__init__.py
@@ -5,5 +5,7 @@
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from . import control, extract
from .extract import ControlVector
from .control import ControlModel
from . import utils
from .utils import DatasetEntry
37 changes: 34 additions & 3 deletions repeng/control.py
@@ -16,20 +16,51 @@ class ControlModel(torch.nn.Module):
A wrapped language model that can have controls set on its layers with `self.set_control`.
"""

def __init__(
self,
model: PreTrainedModel,
layer_ids: typing.Union[typing.Iterable[int], typing.Literal['all', 'middle', 'only_middle'], str] = "all",
):
"""
**This mutates the wrapped `model`! Be careful using `model` after passing it to this class.**

Build a new ControlModel around a model instance, initializing control on
the layers specified in `layer_ids`.
"""

`layer_ids` can also be a string: "all", "middle" (25% to 75% depth), "only_middle"
(the single middle layer), or "start-end" where start and end are floats between 0 and 1
giving the fractional depth range of layers to select.
"""
super().__init__()
self.model = model
num_layers = model.config.num_hidden_layers

if not layer_ids or layer_ids == "all":
layer_ids = range(-1, -num_layers, -1)
elif layer_ids == "middle": # keep only the middle half
layer_ids = [li for li in range(-1, -num_layers, -1)]
layer_ids = layer_ids[len(layer_ids)//4:-len(layer_ids)//4]
elif layer_ids == "only_middle": # keep only the middle layer
layer_ids = [li for li in range(-1, -num_layers, -1)]
layer_ids = [layer_ids[len(layer_ids)//2]]
elif isinstance(layer_ids, str) and '-' in layer_ids:
start, end = map(float, layer_ids.split('-'))
if not (0 <= start < end <= 1):
raise ValueError("Invalid percentage range. Must be 0 <= start < end <= 1")
start_idx = max(0, min(num_layers - 1, int(start * num_layers)))
end_idx = max(0, min(num_layers - 1, int(end * num_layers)))
if start_idx == end_idx:
layer_ids = [start_idx]
else:
layer_ids = list(range(start_idx, end_idx + 1))
if not layer_ids:
raise ValueError("The specified range doesn't include any layers")
else:
assert isinstance(layer_ids, list) and all(isinstance(item, int) for item in layer_ids), "unexpected value for layer_ids"

layers = model_layer_list(model)
self.layer_ids = [i if i >= 0 else len(layers) + i for i in layer_ids]

for layer_id in self.layer_ids:
layer = layers[layer_id]
if not isinstance(layer, ControlModule):
layers[layer_id] = ControlModule(layer)
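
To make the percentage form concrete, here is a small sketch of how `"0.1-0.3"` resolves on a hypothetical 32-layer model, following the parsing logic above:

```python
# hypothetical 32-layer model; mirrors the parsing shown above
num_layers = 32
start, end = 0.1, 0.3
start_idx = max(0, min(num_layers - 1, int(start * num_layers)))  # -> 3
end_idx = max(0, min(num_layers - 1, int(end * num_layers)))      # -> 9
print(list(range(start_idx, end_idx + 1)))  # [3, 4, 5, 6, 7, 8, 9]
```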
87 changes: 65 additions & 22 deletions repeng/extract.py
@@ -2,6 +2,7 @@
import os
import typing
import warnings
from joblib import Memory

import gguf
import numpy as np
@@ -12,13 +13,34 @@

from .control import ControlModel, model_layer_list
from .saes import Sae


from .settings import VERBOSE, LOW_MEMORY
from .utils import autocorrect_chat_templates, DatasetEntry, get_model_name

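# numpy v2 removed np.float_; alias it for dependencies that still expect it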
if not hasattr(np, "float_"):
np.float_ = np.float64

# Setup cache
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "controlvector")
memory = Memory(cache_dir, verbose=0)

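# `ignore` keeps the unhashable model and tensor batch out of the cache key;
# model_name and encoded_batch_str are hashed instead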
@memory.cache(ignore=["model", "encoded_batch"])
def cached_forward(model, encoded_batch, model_name: str, encoded_batch_str: str):
if VERBOSE:
print("cache bypassed")
return model(**encoded_batch, output_hidden_states=True)

def _model_forward(model, encoded_batch, use_cache=True):
"""Model forward pass with optional caching"""
if use_cache:
# joblib can't hash the model or the tensor batch, so the model name and the stringified batch stand in as cache keys
return cached_forward(
model=model,
encoded_batch=encoded_batch,
model_name=get_model_name(model),
encoded_batch_str=str(dict(encoded_batch)),
)
else:
return model(**encoded_batch, output_hidden_states=True)
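
# note: the on-disk activation cache can be dropped with `memory.clear()`,
# for example after changing models or upgrading transformers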

@dataclasses.dataclass
class ControlVector:
@@ -45,6 +67,8 @@ def train(
Defaults to 32. Try reducing this if you're running out of memory.
method (str, optional): The training method to use. Can be either
"pca_diff" or "pca_center". Defaults to "pca_diff".
norm_type (str, optional): The type of normalization to use when projecting
onto the direction vector. Can be either "l1" or "l2". Defaults to "l2".

Returns:
ControlVector: The trained vector.
@@ -247,10 +271,12 @@ def read_representations(
hidden_layers: typing.Iterable[int] | None = None,
batch_size: int = 32,
method: typing.Literal["pca_diff", "pca_center", "umap"] = "pca_diff",
use_cache: bool = True,
transform_hiddens: (
typing.Callable[[dict[int, np.ndarray]], dict[int, np.ndarray]] | None
) = None,
norm_type: typing.Literal["l1", "l2"] = "l2",
) -> dict[int, np.ndarray]:
"""
Extract the representations based on the contrast dataset.
"""
@@ -262,18 +288,22 @@
hidden_layers = [i if i >= 0 else n_layers + i for i in hidden_layers]

# the order is [positive, negative, positive, negative, ...]
train_strs = autocorrect_chat_templates(
messages=[s for ex in inputs for s in (ex.positive, ex.negative)],
tokenizer=tokenizer,
model=model,
)

layer_hiddens = batched_get_hiddens(
model, tokenizer, train_strs, hidden_layers, batch_size, use_cache=use_cache
)

if transform_hiddens is not None:
layer_hiddens = transform_hiddens(layer_hiddens)

# get directions for each layer using PCA
directions: dict[int, np.ndarray] = {}
for layer in tqdm.tqdm(hidden_layers, desc="Altering direction"):
h = layer_hiddens[layer]
assert h.shape[0] == len(inputs) * 2

@@ -303,7 +333,7 @@ def read_representations(
directions[layer] = np.sum(train * embedding, axis=0) / np.sum(embedding)

# calculate sign
projected_hiddens = project_onto_direction(h, directions[layer], norm_type=norm_type)

# order is [positive, negative, positive, negative, ...]
positive_smaller_mean = np.mean(
@@ -331,6 +361,7 @@ def batched_get_hiddens(
inputs: list[str],
hidden_layers: list[int],
batch_size: int,
use_cache: bool = True,
) -> dict[int, np.ndarray]:
"""
Using the given model and tokenizer, pass the inputs through the model and get the hidden
@@ -343,11 +374,11 @@
]
hidden_states = {layer: [] for layer in hidden_layers}
with torch.no_grad():
for batch in tqdm.tqdm(batched_inputs, desc="Getting activations"):
# get the last token, handling right padding if present
encoded_batch = tokenizer(batch, padding=True, return_tensors="pt")
encoded_batch = encoded_batch.to(model.device)
out = model(**encoded_batch, output_hidden_states=True)
encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
out = _model_forward(model, encoded_batch, use_cache=use_cache)

attention_mask = encoded_batch["attention_mask"]
for i in range(len(batch)):
last_non_padding_index = (
@@ -357,18 +388,30 @@
hidden_idx = layer + 1 if layer >= 0 else layer
hidden_state = (
out.hidden_states[hidden_idx][i][last_non_padding_index]
# .cpu()
# .float()
# .numpy()
)
if LOW_MEMORY:
if len(hidden_states[layer]):
hidden_states[layer] = np.vstack((hidden_states[layer], hidden_state.cpu().float().numpy()))
else:
hidden_states[layer].append(hidden_state.cpu().float().numpy())
else:
hidden_states[layer].append(hidden_state)
del out

if LOW_MEMORY:
return hidden_states
else:
return {k: torch.vstack(v).cpu().float().numpy() for k, v in hidden_states.items()}


def project_onto_direction(H, direction, norm_type: str = "l2"):
"""Project matrix H (n, d_1) onto direction vector (d_2,)"""
if norm_type == "l2":
mag = np.linalg.norm(direction) # l2 is the default
elif norm_type == "l1":
mag = np.linalg.norm(direction, ord=1)
else:
raise ValueError(f"unsupported norm_type: {norm_type}")
assert not np.isinf(mag)
return (H @ direction) / mag
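
To make `norm_type` concrete, a quick toy check of the projection helper as defined above:

```python
import numpy as np

H = np.array([[1.0, 0.0], [0.0, 2.0]])  # two hidden states with d = 2
direction = np.array([3.0, 4.0])        # H @ direction = [3, 8]

print(project_onto_direction(H, direction))                  # l2 norm is 5 -> [0.6, 1.6]
print(project_onto_direction(H, direction, norm_type="l1"))  # l1 norm is 7 -> [0.4286..., 1.1428...]
```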
5 changes: 5 additions & 0 deletions repeng/settings.py
@@ -0,0 +1,5 @@
VERBOSE = False

LOW_MEMORY = False # When extracting the representations for a given activation,
# if LOW_MEMORY is True we stack arrays as they arrive; if False, we stack them all at the end.
# This shouldn't matter in most cases, except perhaps with very many examples.