diff --git a/llama.py b/llama.py
index 1a7a61f..246819e 100644
--- a/llama.py
+++ b/llama.py
@@ -486,8 +486,8 @@ def packed_dataset(tokenizer, dataset: str):
     ds = load_dataset(dataset, split="train")
     all_tokens = []
     for i in tqdm(range(0, len(ds), 4096)):
-        tokens_batch = tokenizer.encode(ds[i:i+4096]["text"], add_eos=True)
-        tokens_batch = [np.array(tokens, dtype=np.uint16) for tokens in tokens_batch]
+        tokens_batch = tokenizer.encode(ds[i:i+4096]["text"])
+        tokens_batch = [np.array([1] + tokens + [2], dtype=np.uint16) for tokens in tokens_batch]
        all_tokens.extend(tokens_batch)
    flattened = np.concatenate(all_tokens)