@@ -52,8 +52,14 @@ def generate_sequential_greedy(model, first_token, prefill_len, kv_backup, num_t
5252
5353
5454def generate_jacobi_original (
55- model , first_token , prefill_len , kv_backup , num_tokens ,
56- n_tokens = 8 , max_iter = 3 , init_strategy = "repeat"
55+ model ,
56+ first_token ,
57+ prefill_len ,
58+ kv_backup ,
59+ num_tokens ,
60+ n_tokens = 8 ,
61+ max_iter = 3 ,
62+ init_strategy = "repeat" ,
5763):
5864 """Generate tokens using Jacobi decoding (original, with CPU copies)."""
5965 model .restore_kv_cache (kv_backup )
@@ -74,7 +80,9 @@ def generate_jacobi_original(
7480 break
7581
7682 accepted , new_pos , stats = model .decode_step_jacobi (
77- tokens [- 1 ], position , context_len ,
83+ tokens [- 1 ],
84+ position ,
85+ context_len ,
7886 n_tokens = current_n ,
7987 max_iter = max_iter ,
8088 init_strategy = init_strategy ,
@@ -95,8 +103,7 @@ def generate_jacobi_original(
95103
96104
97105def generate_jacobi_lookahead (
98- model , first_token , prefill_len , num_tokens ,
99- n_tokens = 8 , max_iter = 3 , init_strategy = "repeat"
106+ model , first_token , prefill_len , num_tokens , n_tokens = 8 , max_iter = 3 , init_strategy = "repeat"
100107):
101108 """Generate tokens using Jacobi decoding with lookahead KV (GPU-side)."""
102109 # Set confirmed position after prefill
@@ -195,9 +202,7 @@ def main():
195202 print (f"\n --- Sequential Baseline ({ GEN_TOKENS } tokens) ---" )
196203
197204 start_event .record ()
198- seq_tokens = generate_sequential_greedy (
199- model , first_token , prefill_len , kv_backup , GEN_TOKENS
200- )
205+ seq_tokens = generate_sequential_greedy (model , first_token , prefill_len , kv_backup , GEN_TOKENS )
201206 stop_event .record ()
202207 stop_event .synchronize ()
203208
@@ -215,8 +220,14 @@ def main():
215220
216221 start_event .record ()
217222 jacobi_orig_tokens , avg_iter_o , conv_rate_o = generate_jacobi_original (
218- model , first_token , prefill_len , kv_backup , GEN_TOKENS ,
219- n_tokens = 8 , max_iter = 3 , init_strategy = "repeat"
223+ model ,
224+ first_token ,
225+ prefill_len ,
226+ kv_backup ,
227+ GEN_TOKENS ,
228+ n_tokens = 8 ,
229+ max_iter = 3 ,
230+ init_strategy = "repeat" ,
220231 )
221232 stop_event .record ()
222233 stop_event .synchronize ()
@@ -239,8 +250,7 @@ def main():
239250
240251 start_event .record ()
241252 jacobi_look_tokens , avg_iter_l , conv_rate_l = generate_jacobi_lookahead (
242- model , first_token , prefill_len , GEN_TOKENS ,
243- n_tokens = 8 , max_iter = 3 , init_strategy = "repeat"
253+ model , first_token , prefill_len , GEN_TOKENS , n_tokens = 8 , max_iter = 3 , init_strategy = "repeat"
244254 )
245255 stop_event .record ()
246256 stop_event .synchronize ()
@@ -263,8 +273,7 @@ def main():
263273
264274 start_event .record ()
265275 jacobi_greedy_tokens , avg_iter_g , conv_rate_g = generate_jacobi_lookahead (
266- model , first_token , prefill_len , GEN_TOKENS ,
267- n_tokens = 8 , max_iter = 3 , init_strategy = "greedy"
276+ model , first_token , prefill_len , GEN_TOKENS , n_tokens = 8 , max_iter = 3 , init_strategy = "greedy"
268277 )
269278 stop_event .record ()
270279 stop_event .synchronize ()
@@ -291,9 +300,15 @@ def main():
291300 print (f"\n { 'Method' :<35} { 'Time (ms)' :<12} { 'tok/s' :<10} { 'Speedup' :<10} { 'Match' } " )
292301 print ("-" * 77 )
293302 print (f"{ 'Sequential (baseline)' :<35} { seq_time :<12.1f} { seq_tps :<10.2f} { '1.00x' :<10} { 'N/A' } " )
294- print (f"{ 'Jacobi Original (CPU copies)' :<35} { jacobi_orig_time :<12.1f} { jacobi_orig_tps :<10.2f} { speedup_orig :.2f} x{ '' :<5} { 'YES' if match_orig else 'NO' } " )
295- print (f"{ 'Jacobi Lookahead (GPU-side)' :<35} { jacobi_look_time :<12.1f} { jacobi_look_tps :<10.2f} { speedup_look :.2f} x{ '' :<5} { 'YES' if match_look else 'NO' } " )
296- print (f"{ 'Jacobi Lookahead (greedy init)' :<35} { jacobi_greedy_time :<12.1f} { jacobi_greedy_tps :<10.2f} { (seq_time / jacobi_greedy_time ):.2f} x{ '' :<5} { 'YES' if match_greedy else 'NO' } " )
303+ print (
304+ f"{ 'Jacobi Original (CPU copies)' :<35} { jacobi_orig_time :<12.1f} { jacobi_orig_tps :<10.2f} { speedup_orig :.2f} x{ '' :<5} { 'YES' if match_orig else 'NO' } "
305+ )
306+ print (
307+ f"{ 'Jacobi Lookahead (GPU-side)' :<35} { jacobi_look_time :<12.1f} { jacobi_look_tps :<10.2f} { speedup_look :.2f} x{ '' :<5} { 'YES' if match_look else 'NO' } "
308+ )
309+ print (
310+ f"{ 'Jacobi Lookahead (greedy init)' :<35} { jacobi_greedy_time :<12.1f} { jacobi_greedy_tps :<10.2f} { (seq_time / jacobi_greedy_time ):.2f} x{ '' :<5} { 'YES' if match_greedy else 'NO' } "
311+ )
297312
298313 print (f"\n Lookahead vs Original speedup: { speedup_look_vs_orig :.2f} x" )
299314
0 commit comments