Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes #11
catherinelee274 wants to merge 12 commits into triton-lang:main from
Conversation
import pandas as pd

def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
any reason you are deleting this?
import pytest
from kernels.fused_softmax import triton_softmax

@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
Thanks for adding tests!
FYI, this might not be ideal because we are not calling softmax from triton.ops like the other tests. I ran into issues with doing it that way.
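Whichever import path the tests end up using, they need a trusted reference to compare `triton_softmax` against. A minimal sketch of such a reference, written in NumPy for illustration (the name `ref_softmax` is hypothetical, not part of this PR):

```python
import numpy as np

def ref_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last dimension.

    Subtracting the row max before exponentiating avoids overflow
    without changing the result, since softmax is shift-invariant.
    """
    shifted = x - x.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

# Each row of the output is a probability distribution (sums to 1).
x = np.random.randn(4, 8)
y = ref_softmax(x)
```

In the actual test one would call `triton_softmax` on a CUDA tensor of each parametrized `input_size` and assert closeness to a reference like this (or to `torch.softmax`) within a float tolerance.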
models/llama/llama/math_ops.py
Outdated
def softmax(self, x, dim):
    if self.use_triton:
        return F.softmax(x, dim=-1)
    if self.use_triton and len(x) == 2:
It looks like you're trying to check the number of dimensions here, right? len(x) doesn't do that: on a tensor it returns the size of the first dimension (x.shape[0]), not the rank. I think you want x.dim() or x.ndim.
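To make the distinction concrete, here is a small sketch using NumPy (torch tensors behave the same way: `len(t)` is `t.shape[0]`, and `t.dim()`/`t.ndim` give the rank):

```python
import numpy as np

# A 2-D array with 3 rows and 4 columns.
x = np.zeros((3, 4))

# len() only looks at the first dimension, so it cannot
# distinguish "2-D input" from "1-D input of length 3".
print(len(x))   # 3  -- size of dim 0, NOT the number of dimensions
print(x.ndim)   # 2  -- number of dimensions, what the check wants
print(x.size)   # 12 -- total element count (torch's x.numel())
```

So `len(x) == 2` would wrongly route a 1-D tensor of two elements to the Triton path and skip a genuine 2-D input with more than two rows.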
models/llama/llama/math_ops.py
Outdated
| if self.use_triton: | ||
| return F.softmax(x, dim=-1) | ||
| if self.use_triton and len(x) == 2: | ||
| return triton_softmax(x, dim=-1) |
Why are we passing dim=-1 to these calls, when we receive dim as an argument? Let's pass it through properly instead of overriding it. (Also, does the fused Triton kernel actually handle dim != -1 correctly?)
Currently it does not handle dim != -1. Looking into it (checking how llama.cpp does this); let me know if you have any pointers.
models/llama/llama/math_ops.py
Outdated
@Profiler.profiling_decorator("argmax")
def argmax(self, x, dim):
    if self.use_triton:
        # TODO: change
Instead of adding a TODO to the code here, would you mind creating an issue to track it?
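When that issue gets picked up, the tricky part of a Triton argmax is matching tie-breaking semantics. A reference in NumPy that pins down the expected behavior (`ref_argmax` is a hypothetical helper, not part of this PR; `np.argmax` and `torch.argmax` both return the first occurrence of the maximum):

```python
import numpy as np

def ref_argmax(x: np.ndarray, dim: int) -> np.ndarray:
    """Index of the maximum along `dim`; ties resolve to the first
    occurrence, matching np.argmax and torch.argmax."""
    return np.argmax(x, axis=dim)

x = np.array([[0.1, 0.9, 0.3],
              [2.0, -1.0, 2.0]])
print(ref_argmax(x, 1))  # [1 0] -- row 1's tie resolves to index 0
```

A parallel reduction in the kernel must carry (value, index) pairs and prefer the lower index on equal values to reproduce this.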
- Rename certain functions to conform with the naming scheme
- Current Triton softmax does not handle > 2 dimensions; will need to investigate (probably by looking at llama.cpp)
- Call dist.destroy_process_group() to remove a warning during benchmarking

Results from calling:

python3 main.py llama_chat_completion --benchmark --ckpt_dir <model_checkpoint_path> --tokenizer_path <model_tokenizer_path>

With No Changes:
With just softmax
With softmax and argmax