What is a guard in the context of torch.compile? A guard is a boolean runtime check that ensures a compiled graph is still valid for the current inputs; as long as every guard passes, the compiled code can be reused without recompilation.
Dynamic Shape Guard Dropping:
When using torch.compile, the compiler normally inserts guards (runtime checks) to ensure that the compiled graph is valid for the current input's shape. LLM inference has to handle varying sequence lengths (dynamic shapes), so these shape guards can trigger frequent recompilations and a large performance overhead. If we can relax these shape checks, we can win that performance back.
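To see shape guards in action, here is a minimal, self-contained sketch (not taken from the experiments below) that compiles a toy function and feeds it different batch sizes; the recompile log prints the size guard that failed.

```python
import torch

# Log every recompilation together with the guard failure that caused it.
torch._logging.set_logs(recompiles=True)

@torch.compile
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0 + 1.0

# First call compiles a graph specialized to shape (8, 64).
scale(torch.randn(8, 64))
# A different batch size fails the size guard on dim 0 and recompiles;
# by default the recompiled graph treats that dim as dynamic.
scale(torch.randn(16, 64))
scale(torch.randn(32, 64))  # usually reuses the dynamic graph, no recompile
```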
vLLM achieves this through:
- unbacked symints: a part of the graph that depends on the output of a previous graph treats that value as a symbolic integer instead of a fixed number
- guard dropping: the equivalent of dynamo's `dynamic=True`; the example code below showcases this behaviour
- custom operator wrapping: vLLM wraps complex logic such as attention in custom operators, so dynamo does not trace into the wrapped (black-boxed) kernels and does not install shape guards for them (see the sketch after this list)
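A minimal sketch of the custom-operator idea is below, using `torch.library.custom_op` (PyTorch 2.4+); the operator name `demo::naive_attn` and its naive implementation are illustrative assumptions, not vLLM's actual wrappers.

```python
import torch

# A custom op is opaque to dynamo: it shows up as a single node in the
# graph and no shape guards are installed for anything inside its body.
@torch.library.custom_op("demo::naive_attn", mutates_args=())
def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    scores = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
    return scores.softmax(dim=-1) @ v

# Fake (meta) implementation so the compiler can infer output shapes
# without running the real kernel.
@naive_attn.register_fake
def _(q, k, v):
    return torch.empty_like(q)

@torch.compile
def block(q, k, v):
    # The attention call is a single black-box node to the compiler.
    return naive_attn(q, k, v) + q
```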
Experiment 1: Using the `dynamic` parameter of `torch.compile`
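guard_drop.py is not reproduced here; a rough sketch of what such a script could look like (an assumption, not the exact code) is to run the same model with two batch sizes in default mode and with `dynamic=True`, dumping `torch._dynamo.utils.counters` for each case:

```python
import torch
from torch._dynamo.utils import counters

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.relu(self.lin(x))

def run(dynamic):
    torch._dynamo.reset()   # drop previously compiled graphs
    counters.clear()        # start counting from zero
    model = torch.compile(Tiny(), dynamic=dynamic)
    for bs in (8, 16):      # two different batch sizes
        model(torch.randn(bs, 64))
    print(dict(counters))

run(dynamic=None)   # default: the second shape triggers a recompile -> 2 graphs
run(dynamic=True)   # shapes are symbolic from the start -> 1 graph
```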
(.venv_exp) (main) jeromean@jeromean-vm-u24:torch_guards $ python guard_drop.py
================================================================================
*************** Base Stats ****************
================================================================================
----------------------------------------
defaultdict(<class 'collections.Counter'>, {'frames': Counter({'total': 2, 'ok': 2}), 'stats': Counter({'calls_captured': 2, 'unique_graphs': 2}), 'inductor': Counter({'fxgraph_cache_miss': 2}), 'aot_autograd': Counter({'total': 2, 'autograd_cache_miss': 2, 'autograd_cache_saved': 2, 'ok': 2}), 'graph_break': Counter()})
----------------------------------------
================================================================================
*************** Dynamic=True ****************
================================================================================
----------------------------------------
defaultdict(<class 'collections.Counter'>, {'frames': Counter({'total': 1, 'ok': 1}), 'stats': Counter({'calls_captured': 1, 'unique_graphs': 1}), 'aot_autograd': Counter({'total': 1, 'autograd_cache_hit': 1, 'ok': 1}), 'inductor': Counter({'fxgraph_cache_hit': 1})})
----------------------------------------
Experiment 2: Marking the tensor dims as dynamic
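guard_drop_mark_dynamic.py is not shown; the "optimized with guard drop" path presumably marks the batch dimension dynamic before calling the compiled model, roughly as in this sketch (an assumption, not the actual script):

```python
import torch

model = torch.compile(torch.nn.Linear(64, 64))

for bs in (1, 2, 4, 3, 1):
    x = torch.randn(bs, 64)
    # Declare dim 0 dynamic up front so no size guard is installed on it
    # and later batch sizes reuse the same graph.
    torch._dynamo.mark_dynamic(x, 0)
    model(x)
```

Note that in the output below the optimized path still recompiles once when going from batch size 1 to 2, which is consistent with dynamo specializing sizes 0 and 1 even for dims marked dynamic.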
(.venv_exp) (main) jeromean@jeromean-vm-u24:torch_guards $ python guard_drop_mark_dynamic.py
================================================================================
*****Standard with strict guard checks*********
================================================================================
(1, 64)
Step 1: Input (1, 64) -> 6304.69 ms | Compiles so far : 1196
(2, 64)
V0217 15:59:26.145000 33014 torch/_dynamo/guards.py:4181] [0/1] [__recompiles] Recompiling function forward in /home/jeromean/experiment/torch_guards/guard_drop_mark_dynamic.py:28
V0217 15:59:26.145000 33014 torch/_dynamo/guards.py:4181] [0/1] [__recompiles] triggered by the following guard failure(s):
V0217 15:59:26.145000 33014 torch/_dynamo/guards.py:4181] [0/1] [__recompiles] - 0/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Step 2: Input (2, 64) -> 157.36 ms | Compiles so far : 1180
(4, 64)
V0217 15:59:26.303000 33014 torch/_dynamo/guards.py:4181] [0/2] [__recompiles] Recompiling function forward in /home/jeromean/experiment/torch_guards/guard_drop_mark_dynamic.py:28
V0217 15:59:26.303000 33014 torch/_dynamo/guards.py:4181] [0/2] [__recompiles] triggered by the following guard failure(s):
V0217 15:59:26.303000 33014 torch/_dynamo/guards.py:4181] [0/2] [__recompiles] - 0/1: tensor 'x' size mismatch at index 0. expected 2, actual 4
V0217 15:59:26.303000 33014 torch/_dynamo/guards.py:4181] [0/2] [__recompiles] - 0/0: tensor 'x' size mismatch at index 0. expected 1, actual 4
Step 3: Input (4, 64) -> 153.62 ms | Compiles so far : 1332
(3, 64)
V0217 15:59:26.457000 33014 torch/_dynamo/guards.py:4181] [0/3] [__recompiles] Recompiling function forward in /home/jeromean/experiment/torch_guards/guard_drop_mark_dynamic.py:28
V0217 15:59:26.457000 33014 torch/_dynamo/guards.py:4181] [0/3] [__recompiles] triggered by the following guard failure(s):
V0217 15:59:26.457000 33014 torch/_dynamo/guards.py:4181] [0/3] [__recompiles] - 0/2: tensor 'x' size mismatch at index 0. expected 4, actual 3
V0217 15:59:26.457000 33014 torch/_dynamo/guards.py:4181] [0/3] [__recompiles] - 0/1: tensor 'x' size mismatch at index 0. expected 2, actual 3
V0217 15:59:26.457000 33014 torch/_dynamo/guards.py:4181] [0/3] [__recompiles] - 0/0: tensor 'x' size mismatch at index 0. expected 1, actual 3
Step 4: Input (3, 64) -> 151.81 ms | Compiles so far : 1484
(1, 64)
Step 5: Input (1, 64) -> 0.34 ms | Compiles so far : 1484
================================================================================
*****Optimized with guard drop*********
================================================================================
(1, 64)
Step 1: Input (1, 64) -> 184.94 ms | Compiles so far : 1145
(2, 64)
V0217 15:59:26.799000 33014 torch/_dynamo/guards.py:4181] [0/1] [__recompiles] Recompiling function forward in /home/jeromean/experiment/torch_guards/guard_drop_mark_dynamic.py:28
V0217 15:59:26.799000 33014 torch/_dynamo/guards.py:4181] [0/1] [__recompiles] triggered by the following guard failure(s):
V0217 15:59:26.799000 33014 torch/_dynamo/guards.py:4181] [0/1] [__recompiles] - 0/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Step 2: Input (2, 64) -> 123.28 ms | Compiles so far : 1156
(4, 64)
Step 3: Input (4, 64) -> 0.34 ms | Compiles so far : 1156
(3, 64)
Step 4: Input (3, 64) -> 0.31 ms | Compiles so far : 1156
(1, 64)
Step 5: Input (1, 64) -> 0.30 ms | Compiles so far : 1156
##############################
SUMMARY
##############################
Standard Avg : 1353.56 ms
Optimized Avg : 61.83 ms
Experiment 3: Simulate guard dropping for SGLang with a compile wrapper
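The wrapper itself is not included; a minimal sketch of the idea (the class name and structure are assumptions, not the actual guard_drop_wrapper.py) is a thin callable that marks the sequence-length dim dynamic before dispatching to a single compiled forward:

```python
import torch

class DropGuardWrapper:
    """Compile once, then mark the seq-len dim dynamic on every call so a
    new sequence length reuses the existing graph instead of recompiling."""

    def __init__(self, module: torch.nn.Module, seq_dim: int = 1):
        self.compiled = torch.compile(module)
        self.seq_dim = seq_dim

    def __call__(self, hidden: torch.Tensor) -> torch.Tensor:
        torch._dynamo.mark_dynamic(hidden, self.seq_dim)
        return self.compiled(hidden)
```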
(.venv_exp) (main) jeromean@jeromean-vm-u24:torch_guards $ python guard_drop_wrapper.py
**** Benchmark : STANDARD ****
[ STANDARD ] Len: 16 | Latency: 6142.52 ms | COMPILED
[ STANDARD ] Len: 64 | Latency: 171.45 ms | COMPILED
[ STANDARD ] Len: 128 | Latency: 167.20 ms | COMPILED
[ STANDARD ] Len: 256 | Latency: 168.35 ms | COMPILED
**** Benchmark : DROP-GUARD ****
[DROP-GUARD] Len: 16 | Latency: 204.57 ms | COMPILED
[DROP-GUARD] Len: 64 | Latency: 0.63 ms | REUSED
[DROP-GUARD] Len: 128 | Latency: 0.76 ms | REUSED
[DROP-GUARD] Len: 256 | Latency: 0.99 ms | REUSED
Experiment 3.1: Simulate guard dropping with a compile wrapper integrated with sglang
- clone the sglang repo with the changes at https://github.com/jeromean/sglang:drop_guard
- run benchmark tests from sgl python
- run the test once before measuring, as sgl + cudagraph creates graphs for multiple batch sizes on the first run
- Base Run without dropping any guards
srun python -m sglang.bench_one_batch --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --device cuda --enable-torch-compile --batch-size 32 --input-len 128 --output-len 32
#####################################################
max_total_num_tokens=1447477
Warmup ...
######################## forward_extend ############################# none ######
Prefill. latency: 0.21103 s, throughput: 19409.33 token/s
Decode 0. Batch size: 32, latency: 0.46622 s, throughput: 68.64 token/s
Decode 1. Batch size: 32, latency: 0.00446 s, throughput: 7173.88 token/s
Decode 2. Batch size: 32, latency: 0.00431 s, throughput: 7419.12 token/s
Decode 3. Batch size: 32, latency: 0.00431 s, throughput: 7421.65 token/s
Decode 4. Batch size: 32, latency: 0.00428 s, throughput: 7472.47 token/s
Decode. median latency: 0.00428 s, median throughput: 7474.02 token/s
Total. latency: 0.806 s, throughput: 6353.26 token/s
Benchmark ...
######################## forward_extend ############################# none ######
Prefill. latency: 0.07074 s, throughput: 57900.12 token/s
Decode 0. Batch size: 32, latency: 0.00439 s, throughput: 7282.87 token/s
Decode 1. Batch size: 32, latency: 0.00433 s, throughput: 7391.29 token/s
Decode 2. Batch size: 32, latency: 0.00429 s, throughput: 7461.41 token/s
Decode 3. Batch size: 32, latency: 0.00425 s, throughput: 7522.78 token/s
Decode 4. Batch size: 32, latency: 0.00428 s, throughput: 7483.22 token/s
Decode. median latency: 0.00310 s, median throughput: 10317.56 token/s
Total. latency: 0.181 s, throughput: 28290.34 token/s
- Disable all guards and run sglang benchmark
srun python -m sglang.bench_one_batch --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --device cuda --enable-torch-compile --batch-size 32 --input-len 128 --output-len 32 --drop-guard all
#####################################################
******************************** guard drop :: all **********
########################compilewrapper############################# all ######
max_total_num_tokens=1447477
Warmup ...
######################## forward_extend ############################# all ######
**** make seq len dynamic in prefill
Prefill. latency: 13.47556 s, throughput: 303.96 token/s
Decode 0. Batch size: 32, latency: 0.09775 s, throughput: 327.38 token/s
Decode 1. Batch size: 32, latency: 0.00450 s, throughput: 7113.39 token/s
Decode 2. Batch size: 32, latency: 0.00432 s, throughput: 7400.39 token/s
Decode 3. Batch size: 32, latency: 0.00435 s, throughput: 7357.28 token/s
Decode 4. Batch size: 32, latency: 0.00431 s, throughput: 7432.05 token/s
Decode. median latency: 0.00428 s, median throughput: 7468.06 token/s
Total. latency: 13.702 s, throughput: 373.66 token/s
Benchmark ...
######################## forward_extend ############################# all ######
**** make seq len dynamic in prefill
Prefill. latency: 0.04583 s, throughput: 89375.06 token/s
Decode 0. Batch size: 32, latency: 0.00328 s, throughput: 9744.11 token/s
Decode 1. Batch size: 32, latency: 0.00319 s, throughput: 10041.16 token/s
Decode 2. Batch size: 32, latency: 0.00321 s, throughput: 9971.05 token/s
Decode 3. Batch size: 32, latency: 0.00323 s, throughput: 9906.02 token/s
Decode 4. Batch size: 32, latency: 0.00319 s, throughput: 10042.42 token/s
Decode. median latency: 0.00309 s, median throughput: 10343.14 token/s
Total. latency: 0.142 s, throughput: 35947.35 token/s
Throughput Comparison

| Drop Type | Throughput (tokens/s) |
|---|---|
| Base | 28846.02 |
| Shapes | 33786.61 |
| All | 34001.15 |
TBD:
- there are some recompiles which need to be analyzed
- make drop guard more configurable from sglang
- check drop guard with different parallelism strategies
- analyze the number of graphs generated