"""
generate.py - Text Generation with MiniGPT
=============================================
Once trained, the model generates text AUTOREGRESSIVELY:
1. Start with a prompt (e.g., "ROMEO:")
2. Feed it through the model to get next-word probabilities
3. Sample a word from those probabilities
4. Append the sampled word to the sequence
5. Repeat from step 2
The quality and diversity of generated text is controlled by
SAMPLING STRATEGIES:
GREEDY DECODING (temperature=0):
Always pick the most probable word. Deterministic but boring --
often produces repetitive, safe text.
Example: "the the the the the..."
TEMPERATURE SCALING:
Divide logits by temperature T before softmax.
T < 1.0 -> SHARPER distribution (more confident, less diverse)
The model strongly prefers its top choices
T = 1.0 -> UNCHANGED (use the model's raw predictions)
T > 1.0 -> FLATTER distribution (less confident, more diverse)
Even unlikely words have a chance
Math insight: softmax(x/T) approaches:
- One-hot (argmax) as T -> 0
- Uniform distribution as T -> infinity
TOP-K SAMPLING:
Only consider the top K most likely words. Set all others to 0 probability.
This prevents the model from ever picking extremely unlikely words
that would produce nonsense.
K = 1: Same as greedy (only top word)
K = 40: Choose from top 40 candidates (good default)
K = 2000: No filtering (full vocabulary)
TOP-P (NUCLEUS) SAMPLING:
Instead of a fixed K, dynamically choose the smallest set of words
whose cumulative probability exceeds P.
If the model is very confident (one word has 95% probability),
nucleus sampling might only consider 1-2 words.
If the model is uncertain, it considers more words.
P = 0.9 is a common choice.
"""
import torch
import torch.nn.functional as F
from model.transformer import MiniGPT
from data.tokenizer import WordTokenizer


@torch.no_grad()
def generate(
    model: MiniGPT,
    tokenizer: WordTokenizer,
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 40,
    device: torch.device = None,
) -> str:
    """
    Generate text from a prompt using the trained model.

    Args:
        model: Trained MiniGPT model
        tokenizer: WordTokenizer with vocabulary
        prompt: Starting text (e.g., "ROMEO:")
        max_tokens: Maximum number of tokens to generate
        temperature: Controls randomness (0 = greedy, 1 = normal, >1 = more random)
        top_k: Only sample from the top-k most likely tokens (0 = no filtering)
        device: Computation device (CPU/GPU)

    Returns:
        Generated text string (prompt + generated continuation)
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    # Encode the prompt into token IDs
    token_ids = tokenizer.encode(prompt)
    if len(token_ids) == 0:
        # If the prompt produces no tokens, start with beginning-of-sequence
        token_ids = [2]  # <bos>

    # Convert to tensor: shape (1, seq_len) -- batch size of 1
    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)

    # Generate tokens one at a time
    generated_ids = list(token_ids)
    for _ in range(max_tokens):
        # Truncate to max_seq_len if the sequence gets too long --
        # the model can only handle max_seq_len tokens of context
        context = tokens[:, -model.config.max_seq_len:]

        # Forward pass: get logits for ALL positions
        logits = model(context)  # (1, seq_len, vocab_size)

        # We only care about the LAST position's prediction
        # (what comes after the last token we've seen)
        next_logits = logits[:, -1, :]  # (1, vocab_size)

        # Apply temperature scaling
        next_logits = apply_temperature(next_logits, temperature)

        # Apply top-k filtering
        if top_k > 0:
            next_logits = apply_top_k(next_logits, top_k)

        # Convert logits to probabilities
        probs = F.softmax(next_logits, dim=-1)

        # Sample from the distribution
        if temperature == 0:
            # Greedy: always pick the most likely word
            next_token = torch.argmax(probs, dim=-1, keepdim=True)
        else:
            # Stochastic: randomly sample proportional to probabilities
            next_token = torch.multinomial(probs, num_samples=1)

        # Append to the sequence
        tokens = torch.cat([tokens, next_token], dim=1)
        generated_ids.append(next_token.item())

        # Stop if we generate an end-of-sequence token
        if next_token.item() == 3:  # <eos>
            break

    # Decode back to text
    return tokenizer.decode(generated_ids)
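
# A minimal usage sketch (hypothetical -- the checkpoint path, config object,
# and loading calls below are assumptions about how training saves artifacts,
# not APIs confirmed by this file; adapt them to your setup):
#
#     model = MiniGPT(config)                      # same config used in training
#     model.load_state_dict(torch.load("checkpoints/model.pt"))
#     tokenizer = WordTokenizer(...)               # rebuilt from the training vocab
#     print(generate(model, tokenizer, "ROMEO:", max_tokens=50,
#                    temperature=0.8, top_k=40))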


def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Scale logits by temperature.

    Temperature controls the "confidence" of the probability distribution:

        logits = [2.0, 1.0, 0.5]
        T=0.5 (sharper): logits/T = [4.0, 2.0, 1.0]  -> probs ≈ [0.84, 0.11, 0.04]
        T=1.0 (normal):  logits/T = [2.0, 1.0, 0.5]  -> probs ≈ [0.63, 0.23, 0.14]
        T=2.0 (flatter): logits/T = [1.0, 0.5, 0.25] -> probs ≈ [0.48, 0.29, 0.23]

    Lower temperature  -> more confident -> less diverse text
    Higher temperature -> less confident -> more diverse text
    """
    if temperature == 0:
        return logits  # Will use argmax anyway
    return logits / temperature
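

# The demo below is a quick sanity check, not part of the generation path:
# it re-derives the probabilities quoted in the docstring above using only
# this file's helpers. _demo_temperature is a name introduced here for
# illustration; run this module directly to execute it.
def _demo_temperature() -> None:
    logits = torch.tensor([[2.0, 1.0, 0.5]])
    for t in (0.5, 1.0, 2.0):
        probs = F.softmax(apply_temperature(logits, t), dim=-1)
        print(f"T={t}: probs = {[round(x, 2) for x in probs.squeeze().tolist()]}")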


def apply_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Keep only the top-k logits; set the rest to -infinity.

    This prevents sampling extremely unlikely tokens that would
    produce nonsensical text.

    Steps:
        1. Find the k-th largest logit value (the threshold)
        2. Set all logits below this threshold to -inf
        3. After softmax, these -inf values become probability 0

    Example with k=3:
        logits = [2.0, 0.5, 1.8, -1.0, 1.5]
        Top 3:  [2.0, 1.8, 1.5]  (threshold = 1.5)
        After:  [2.0, -inf, 1.8, -inf, 1.5]
        Probs:  [0.41, 0, 0.34, 0, 0.25]  (only 3 candidates)
    """
    if k >= logits.size(-1):
        return logits  # No filtering needed

    # Get the k-th largest value
    # topk returns (values, indices); we only need the values
    top_k_values, _ = torch.topk(logits, k, dim=-1)

    # The threshold is the smallest value among the top-k
    threshold = top_k_values[:, -1].unsqueeze(-1)

    # Set everything below the threshold to -infinity
    logits = logits.masked_fill(logits < threshold, float("-inf"))
    return logits
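

# Another illustrative check (a helper introduced here, not used by generate):
# it reproduces the k=3 example from the apply_top_k docstring and shows the
# two filtered tokens landing at exactly zero probability after softmax.
def _demo_top_k() -> None:
    logits = torch.tensor([[2.0, 0.5, 1.8, -1.0, 1.5]])
    filtered = apply_top_k(logits, k=3)
    print(filtered)                     # [2.0, -inf, 1.8, -inf, 1.5]
    print(F.softmax(filtered, dim=-1))  # ≈ [0.41, 0, 0.34, 0, 0.25]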


def apply_top_p(logits: torch.Tensor, p: float) -> torch.Tensor:
    """
    Nucleus (top-p) sampling: keep the smallest set of tokens whose
    cumulative probability exceeds p.

    Unlike top-k, which always considers exactly k tokens, top-p
    ADAPTS to the model's confidence:
        - Confident prediction (one token has 95% prob) -> few candidates
        - Uncertain prediction (many tokens with similar prob) -> many candidates

    Args:
        logits: Raw scores of shape (batch, vocab_size)
        p: Cumulative probability threshold (e.g., 0.9)

    Returns:
        Filtered logits with low-probability tokens set to -inf
    """
    # Sort tokens by probability (descending)
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)

    # Compute cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Find tokens where cumulative probability exceeds p
    # We shift right by 1 so that the token that crosses p is INCLUDED
    sorted_mask = cumulative_probs - sorted_probs > p

    # Set filtered tokens to -inf
    sorted_logits[sorted_mask] = float("-inf")

    # Un-sort back to original order
    logits = logits.scatter(1, sorted_indices, sorted_logits)
    return logits
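

# A final illustrative check (names introduced here for demonstration): it
# shows the adaptive behavior described above -- a confident distribution
# keeps very few candidates at p=0.9, while an uncertain one keeps many.
def _demo_top_p() -> None:
    confident = torch.tensor([[5.0, 1.0, 0.5, 0.0]])  # one dominant token
    uncertain = torch.tensor([[1.0, 0.9, 0.8, 0.7]])  # near-uniform
    for name, logits in [("confident", confident), ("uncertain", uncertain)]:
        kept = torch.isfinite(apply_top_p(logits, p=0.9)).sum().item()
        print(f"{name}: {kept} candidate(s) survive at p=0.9")


if __name__ == "__main__":
    _demo_temperature()
    _demo_top_k()
    _demo_top_p()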