-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbitnet_layer.py
More file actions
28 lines (23 loc) · 992 Bytes
/
bitnet_layer.py
File metadata and controls
28 lines (23 loc) · 992 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import AutoModelForCausalLM
import torch
import torch.nn as nn
from Bitnet158Model_copy import BitLinear
def convert_to_bitnet(model):
    """Replace every ``nn.Linear`` in *model* with a ``BitLinear`` in place.

    Each replacement keeps the original layer's shape and copies its weight.
    Returns the same *model* object (mutated) for convenience.

    Bug fixed vs. the original: ``nn.Module`` has no ``.parent`` attribute,
    so ``setattr(module.parent, ...)`` raised ``AttributeError`` on the first
    linear layer. The parent is now resolved from the dotted module name via
    ``model.get_submodule``.
    """
    # Snapshot the module list first: named_modules() is a generator, and
    # swapping children while iterating the live tree is unsafe.
    for name, module in list(model.named_modules()):
        # Guard against re-wrapping in case BitLinear subclasses nn.Linear.
        if isinstance(module, nn.Linear) and not isinstance(module, BitLinear):
            bit_linear = BitLinear(module.in_features, module.out_features)
            bit_linear.weight.data = module.weight.data.clone()
            # NOTE(review): module.bias (if any) is dropped here, matching the
            # original behavior — TODO confirm BitLinear handles/needs bias.
            parent_name, _, child_name = name.rpartition(".")
            # Empty parent_name means the linear sits directly on the root model.
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, bit_linear)
    return model
# Load the pretrained checkpoint in half precision, then swap its linear
# layers for BitLinear ones. Requires HF authentication for the gated repo.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
original_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
)
bitnet_model = convert_to_bitnet(original_model)
# Verify conversion
def verify_conversion(model):
    """Print the conversion status of every linear layer in *model*.

    BitLinear modules are reported as converted; any remaining plain
    ``nn.Linear`` is reported as not converted.
    """
    for layer_name, submodule in model.named_modules():
        if isinstance(submodule, BitLinear):
            print(f"Converted: {layer_name}")
            continue
        if isinstance(submodule, nn.Linear):
            print(f"Not converted: {layer_name}")
# Report which layers were (not) converted after the swap above.
verify_conversion(bitnet_model)