-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchat.py
More file actions
149 lines (120 loc) · 4.72 KB
/
chat.py
File metadata and controls
149 lines (120 loc) · 4.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
"""Interactive chat with a trained Hybrid Liquid-Dense model.
Usage:
python chat.py --checkpoint checkpoints_local/step_48828.pt
python chat.py --checkpoint checkpoints_local/step_48828.pt --temperature 0.9
"""
import argparse
import os
import sys
import glob
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hybrid_liquid_dense.config import HybridLiquidDenseConfig
from hybrid_liquid_dense.model import HybridLiquidDenseModel
def load_model(checkpoint_path: str, device: str = "cuda"):
    """Restore a HybridLiquidDenseModel and its config from a checkpoint.

    Args:
        checkpoint_path: Path to a ``.pt`` file saved by the trainer.
        device: Torch device string the weights are mapped onto.

    Returns:
        ``(model, config)`` with the model in eval mode on ``device``.
    """
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)
    # Reconstruct the config, keeping only keys the dataclass still declares
    # (checkpoints from older code versions may carry extra fields).
    saved_cfg = ckpt["config"]
    known_fields = HybridLiquidDenseConfig.__dataclass_fields__
    config = HybridLiquidDenseConfig(
        **{key: val for key, val in saved_cfg.items() if key in known_fields}
    )
    model = HybridLiquidDenseModel(config).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    param_count = model.num_parameters()
    step = ckpt.get("step", "?")
    print(f"Loaded: {param_count/1e6:.1f}M params, step {step}")
    print(f"Config: d={config.d_model}, ff={config.d_ff}, L={config.n_layers}, seq={config.max_seq_len}")
    return model, config
def generate_text(model, enc, prompt, max_tokens=200, temperature=0.8, top_k=50, device="cuda"):
    """Generate a continuation for *prompt* and return only the NEW text.

    Args:
        model: Model exposing ``generate(input_ids, max_new_tokens,
            temperature, top_k)`` which returns the prompt tokens followed
            by the newly sampled tokens.
        enc: Tokenizer with ``encode``/``decode`` methods and an
            ``eot_token`` id.
        prompt: Text to condition generation on.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        device: Torch device for the input tensor.

    Returns:
        The decoded continuation (prompt text excluded), truncated at the
        first end-of-text token if one was generated.
    """
    ids = enc.encode(prompt)
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )
    # BUG FIX: generate() returns prompt + continuation. Previously the full
    # sequence was decoded, so the caller's "Model> {prompt}{output}" printed
    # the prompt twice and the tok/s stats counted prompt tokens. Slice the
    # prompt off before decoding so only new text is returned.
    tokens = generated[0].tolist()[len(ids):]
    # Stop at end-of-text token if generated
    eot = enc.eot_token
    if eot in tokens:
        tokens = tokens[:tokens.index(eot)]
    return enc.decode(tokens)
def interactive_chat(checkpoint: str = None, device: str = "cuda",
                     temperature: float = 0.8, top_k: int = 50,
                     max_tokens: int = 200):
    """Run a REPL that samples continuations from a trained checkpoint.

    Args:
        checkpoint: Path to a checkpoint file. When None or missing, the
            checkpoint with the highest step number under
            ``checkpoints_local/`` is used.
        device: Torch device string.
        temperature: Initial sampling temperature (adjustable via /temp).
        top_k: Initial top-k cutoff (adjustable via /topk).
        max_tokens: Initial generation length (adjustable via /tokens).
    """
    import time
    import tiktoken
    enc = tiktoken.get_encoding("gpt2")
    # Auto-find latest checkpoint if not specified
    if checkpoint is None or not os.path.exists(checkpoint):
        ckpts = glob.glob("checkpoints_local/step_*.pt")
        if not ckpts:
            print("No checkpoint found. Train first: python train_local.py")
            return
        # BUG FIX: plain lexicographic sort picks step_9999.pt over
        # step_10000.pt; sort by the numeric step instead.
        def _step_num(path):
            stem = os.path.splitext(os.path.basename(path))[0]
            try:
                return int(stem.split("_")[-1])
            except ValueError:
                return -1  # malformed names sort first, never win
        checkpoint = max(ckpts, key=_step_num)
    model, config = load_model(checkpoint, device)
    print()
    print("=" * 50)
    print(" Hybrid Liquid-Dense Language Model")
    print(f" {model.num_parameters()/1e6:.1f}M params | temp={temperature} | top_k={top_k}")
    print("=" * 50)
    print()
    print("Type a prompt and press Enter. Type 'quit' to exit.")
    print("Commands: /temp 0.9 /topk 40 /tokens 300")
    print()
    while True:
        try:
            prompt = input("You> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not prompt:
            continue
        if prompt.lower() in ("quit", "exit", "/quit", "/exit"):
            print("Bye!")
            break
        # Commands — guard conversions so a typo ("/temp abc") doesn't
        # crash the whole session (previously an uncaught ValueError).
        if prompt.startswith("/temp "):
            try:
                temperature = float(prompt.split()[1])
            except ValueError:
                print(" Invalid value; usage: /temp 0.9")
                continue
            print(f" Temperature set to {temperature}")
            continue
        if prompt.startswith("/topk "):
            try:
                top_k = int(prompt.split()[1])
            except ValueError:
                print(" Invalid value; usage: /topk 40")
                continue
            print(f" Top-k set to {top_k}")
            continue
        if prompt.startswith("/tokens "):
            try:
                max_tokens = int(prompt.split()[1])
            except ValueError:
                print(" Invalid value; usage: /tokens 300")
                continue
            print(f" Max tokens set to {max_tokens}")
            continue
        # Generate (proper local import above replaces the old
        # __import__("time") inline hack)
        t0 = time.time()
        output = generate_text(model, enc, prompt, max_tokens, temperature, top_k, device)
        elapsed = time.time() - t0
        n_tok = len(enc.encode(output))
        print(f"\nModel> {prompt}{output}")
        print(f" [{n_tok} tokens, {elapsed:.1f}s, {n_tok/max(elapsed,0.01):.0f} tok/s]\n")
def main():
    """Parse command-line arguments and launch the interactive chat loop."""
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser = argparse.ArgumentParser(description="Chat with Hybrid Liquid-Dense LM")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint (auto-finds latest if omitted)")
    parser.add_argument("--device", default=default_device)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--max_tokens", type=int, default=200)
    opts = parser.parse_args()
    interactive_chat(
        checkpoint=opts.checkpoint,
        device=opts.device,
        temperature=opts.temperature,
        top_k=opts.top_k,
        max_tokens=opts.max_tokens,
    )


if __name__ == "__main__":
    main()