Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions example-chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
import json
from pathlib import Path

from llama import ModelArgs, Transformer, Tokenizer, LLaMA


Expand All @@ -28,10 +29,10 @@ def load(
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)

print("Loading tokenizer")
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words

print("Loading model")
model = Transformer(model_args)

# Original copyright by tloen
Expand Down Expand Up @@ -83,25 +84,40 @@ def main(
ckpt_dir: str = './model',
tokenizer_path: str = './tokenizer/tokenizer.model',
temperature: float = 0.8,
top_p: float = 0.95,
max_seq_len: int = 512,
max_batch_size: int = 32,
top_p: float = 0.95, # use 0.95 or so for top_p sampler, and 0.0 for top_k sampler
top_k: int = 40,
repetition_penalty: float = (1.0 / 0.85), # 1.0 to disable repetition_penalty
sampler: str = 'top_p', # top_p or top_k
max_seq_len: int = 2048,
max_batch_size: int = 1,
):
# torch.manual_seed(1)
# torch.set_default_dtype(torch.bfloat16)

generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)

ctx = """A dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
User: Hello, AI.
AI: Hello! How can I assist you today?
"""

while True:
prompt = input(f'prompt> ')
if len(prompt.strip()) > 0:
prompts = [prompt]
prompt = input(f'User: ')
if ctx != "":
ctx = ctx + "User: " + prompt + "\n"
else:
ctx = prompt + "\n"

ctx = (ctx[-1920:]) if len(ctx) >= 2048 else ctx

if len(ctx.strip()) > 0:
prompts = [ctx]
results = generator.generate(
prompts, max_gen_len=256, temperature=temperature, top_p=top_p
# somehow it did not find top_k?
prompts, max_gen_len=max_seq_len, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, sampler=sampler
# prompts, max_gen_len=max_seq_len, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, sampler=sampler
)

for result in results:
print(result)
ctx = results[0]


if __name__ == "__main__":
Expand Down
94 changes: 75 additions & 19 deletions llama/generation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

# Copyright by Shawn Presser
# https://github.com/shawwn/
# taken from:
# https://github.com/shawwn/llama/commit/40d99d329a5e38d85904d3a6519c54e6dd6ee9e1

from typing import List

import torch
import traceback
from tqdm import tqdm

from llama.tokenizer import Tokenizer
from llama.model import Transformer
from tqdm import trange


class LLaMA:
Expand All @@ -21,11 +28,16 @@ def generate(
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
top_k: int = 40,
repetition_penalty: float = (1.0 / 0.85),
sampler: str = 'top_k',
) -> List[str]:
bsz = len(prompts)
params = self.model.params
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

count_newlines = prompts[0].count("\n")

prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

min_prompt_size = min([len(t) for t in prompt_tokens])
Expand All @@ -36,19 +48,35 @@ def generate(
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cpu().long()
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long()
tokens[k, -1] = self.tokenizer.eos_id
input_text_mask = tokens != self.tokenizer.pad_id

start_pos = min_prompt_size
prev_pos = 0
decoded = [None] * bsz

steps = total_len - start_pos
pbar = tqdm(total=steps)

for cur_pos in range(start_pos, total_len):
for cur_pos in trange(start_pos, total_len, desc="forward"):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
logits_new = logits.clone()
batch_size = len(tokens)
for i in range(batch_size):
for token in set(tokens[i].tolist()):
# if the score is < 0, the repetition penalty has to be multiplied (not divided) to reduce the previous token's probability
if logits[i, token] < 0:
logits_new[i, token] = logits[i, token] * repetition_penalty
else:
logits_new[i, token] = logits[i, token] / repetition_penalty
logits = logits_new

if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
if sampler == 'top_k':
next_token = sample_top_k(probs, top_p=top_p, top_k=top_k)
else:
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
Expand All @@ -59,23 +87,35 @@ def generate(
tokens[:, cur_pos] = next_token
prev_pos = cur_pos

pbar.update(1)

pbar.close()

decoded = []
for i, t in enumerate(tokens.tolist()):
# cut to max gen len
t = t[: len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any
try:
t = t[: t.index(self.tokenizer.eos_id)]
except ValueError:
pass
decoded.append(self.tokenizer.decode(t))
print("-" * 30)
for i, t in enumerate(tokens.tolist()):
# i = cur_pos
# t = next_token
# cut to max gen len
# t = t[: len(prompt_tokens[i]) + max_gen_len]
t = t[: min(cur_pos, len(prompt_tokens[i]) + max_gen_len)]
# cut to eos tok if any
try:
t = t[: t.index(self.tokenizer.eos_id)]
except ValueError:
pass # traceback.print_exc()
try:
d = self.tokenizer.decode(t)
print(d)
decoded[i] = d

result_count_newlines = d.count("\n")
if result_count_newlines > count_newlines:
return decoded

except IndexError:
traceback.print_exc()
print(t)
print("-" * 30)
return decoded


# default sampler
def sample_top_p(probs, p):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
Expand All @@ -85,3 +125,19 @@ def sample_top_p(probs, p):
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token


# sampler by Shawn
def sample_top_k(probs, top_p=0.0, top_k=40):
    """Sample one index per row from the ``top_k`` most probable entries.

    Args:
        probs: Tensor of non-negative weights; candidates are taken along
            the last dimension.
        top_p: When greater than 0.0, an additional nucleus cut is applied
            to the retained candidates: an entry is dropped if the
            cumulative mass of the entries ranked before it already
            exceeds ``top_p``. 0.0 disables the cut.
        top_k: Number of highest-probability candidates to keep. Values
            <= 0 keep every candidate (full descending sort). Values larger
            than the last dimension are clamped to it.

    Returns:
        Tensor of sampled indices into the last dimension of ``probs``,
        with a trailing dimension of size 1.
    """
    if top_k > 0:
        # torch.topk raises if k exceeds the dimension size, so clamp it.
        k = min(top_k, probs.size(-1))
        probs_sort, probs_idx = torch.topk(probs, k)
    else:
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

    if top_p > 0.0:
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # probs_sum - probs_sort is the mass *before* each entry, so the
        # top-ranked entry is always kept.
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0

    # Renormalise the surviving candidates (in place on the sorted copy;
    # the caller's ``probs`` tensor is not modified).
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map positions in the sorted candidate list back to original indices.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
1 change: 1 addition & 0 deletions merge-weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def write_model(input_base_path, model_size):
state_dict = {}

for layer_i in range(n_layers):
print("loading layer "+str(layer_i)+" of "+str(n_layers))
if model_size == "7B":
state_dict |= {
f"layers.{layer_i}.attention.wq.weight": loaded[
Expand Down