[Kernel] Adding basic Triton JitCache for triton_attn #16606
Signed-off-by: Burkhard Ringlein <[email protected]>
| (I also removed the jitcache from  | 
Testing this PR on MI-300X:

    VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_ROCM_CUSTOM_PAGED_ATTN=0 lm_eval --model vllm --model_args pretrained=/models/llama-3.1-8b/instruct/ --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

Performance:

    VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_ROCM_CUSTOM_PAGED_ATTN=0 python benchmarks/benchmark_serving.py --model /models/llama-3.1-8b/instruct/ --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --ignore-eos

The Triton launch overhead does not appear to be a bottleneck on MI-300X. We are currently investigating this and may follow up in a separate PR.
The infra looks good. I just want to make sure I'm understanding the key generation logic correctly.
        
          
vllm/triton_utils/jit_cache.py (outdated)

    :param assume_const: A list of parameters that are NOT marked as
                         tl.constexpr but should be treated as constants in
                         this kernel launch.
    :param assume_const: list[str]
Can you use type hints here instead of the "type" comments?
thanks, addressed.
        
          
vllm/triton_utils/jit_cache.py (outdated)

        cache_launch_grid=False,
        assume_const=None,
    ):
        # we depend on the triton version, right now, 3.0 -- 3.2 are supported
Can you make sure this doesn't crash if someone uses Triton 3.3?
Ok, so you prefer that the cache just "disables itself" and does nothing if the triton version is not supported (yet)? I could do that. Also I mean that for Triton 3.3 it would just do nothing and call regular JIT...
(In general I think we can support triton 3.3 as well, but that's not a focus right now)
Yeah, I think that's the best solution for now. Crashing the whole process just because this caching system isn't supported yet seems like too much.
ok, I implemented it
How hard is 3.3 support? Since we've updated to torch 2.7 it would be good to get Triton 3.3 support here. Hopefully it just works?
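The "disable itself on unsupported versions" behaviour agreed on in this thread could look roughly like the sketch below. It is not the code from this PR: `maybe_cached`, `parse_major_minor`, `wrap_with_cache`, and the `SUPPORTED` set are hypothetical names, and the supported range is only taken from the "3.0 -- 3.2" comment in the diff. The version string is passed in explicitly so that the sketch runs without Triton installed.

```python
import re
from typing import Callable

# Assumption: mirrors the "3.0 -- 3.2 are supported" comment in the diff.
SUPPORTED = {(3, 0), (3, 1), (3, 2)}


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string such as '3.2.0+git1234'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def maybe_cached(jit_fn: Callable,
                 wrap_with_cache: Callable[[Callable], Callable],
                 triton_version: str) -> Callable:
    """Wrap jit_fn with the cache only on supported versions.

    On any other (or unparsable) version the original function is returned
    unchanged, so the caller silently falls back to the regular JIT path
    instead of crashing the process.
    """
    try:
        supported = parse_major_minor(triton_version) in SUPPORTED
    except ValueError:
        supported = False
    return wrap_with_cache(jit_fn) if supported else jit_fn


# Stand-ins for a jitted kernel and the cache wrapper (both hypothetical).
def plain_kernel(x):
    return x + 1


def fake_cache(fn):
    def cached(x):
        return fn(x)
    cached.is_cached = True
    return cached


on_32 = maybe_cached(plain_kernel, fake_cache, "3.2.0")  # wrapped
on_33 = maybe_cached(plain_kernel, fake_cache, "3.3.0")  # plain fallback
```

With this shape, adding Triton 3.3 support later would only mean extending the supported set once the cache has been validated against that release.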
    self.base_fn = fn
    while not inspect.isfunction(self.base_fn):
        self.base_fn = self.base_fn.fn
    self.cache_lock = cache_lock
Where is the cache_lock locked?
right now, it is not used (only _dynamic is used in this PR), see also below.
Since only _dynamic is used, can we remove this static vs dynamic distinction and simplify the code?
    self.assume_const = assume_const
    self.kernel_cache: dict[str, PreparedKernel] = {}

    def calc_cache_index(kwargs):
Could you talk a bit about the key generation logic? I'm somewhat confused. It looks like, in this case, we would be constructing a key with the values of the "USE_ALIBI_SLOPES", "SLIDING_WINDOW", and "filter_by_query_len". Since those are the strings in the "check_keys" argument to the decorator. How were these arguments selected?
Yes, you are right :). The developer needs to select these arguments based on their knowledge of the application. Put simply, the jitcache trades the regular JIT's safety in all scenarios (paid for with a high launch overhead) for a low launch overhead with relaxed safety checks that are valid only for a specific application. It is then the job of the developers to ensure that the relaxed safety checks still hold for that application.
In the case of the paged_attention_2d kernel, we assume that only the arguments "USE_ALIBI_SLOPES", "SLIDING_WINDOW", and "filter_by_query_len" change during the lifetime of one vLLM instance (one model within vLLM, to be precise). These are at most the arguments that change; I actually think they barely change and we could reduce the list further. Other arguments like num_heads don't change during a model's lifetime (at least, to my knowledge), hence we don't need to check them every time only to realize "oh, didn't change...I just used some microseconds to ensure this again".
Said differently, if we used this kernel in an application that uses the same Python process / same kernel instance to serve multiple different LLMs (something like attention-as-a-service...just making things up ;) ), then we would need to extend the check_keys list to ensure it also holds in this scenario (or not use the jitcache).
I can try to mention this more explicitly in the doc strings.
I added some more explanations in the docstr of the decorator.
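To make the trade-off described above concrete, here is a minimal sketch of a key function that looks only at the `check_keys` arguments. It is an illustration, not the implementation in `jit_cache.py`: the two-argument signature, the key format, and the argument values are assumptions, and only the three key names and `num_heads` come from this thread.

```python
def calc_cache_index(kwargs: dict, check_keys: list[str]) -> str:
    """Build a cache key from the checked arguments only (hypothetical format)."""
    return "|".join(f"{name}={kwargs[name]!r}" for name in check_keys)


check_keys = ["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"]

# Argument values below are made up for illustration.
launch_a = {"USE_ALIBI_SLOPES": False, "SLIDING_WINDOW": 0,
            "filter_by_query_len": True, "num_heads": 32}
# Same checked values, but an unchecked argument (num_heads) differs.
launch_b = dict(launch_a, num_heads=8)
# A checked argument differs.
launch_c = dict(launch_a, SLIDING_WINDOW=4096)

key_a = calc_cache_index(launch_a, check_keys)
key_b = calc_cache_index(launch_b, check_keys)
key_c = calc_cache_index(launch_c, check_keys)

# key_a == key_b: the num_heads change is invisible to the cache, so the
# binary prepared for launch_a would be reused. This is the relaxed check.
# key_a != key_c: a change in a checked argument selects a different entry.
```

The `launch_b` case is exactly the situation the developer has to rule out: if `num_heads` could change within one process, it would have to be added to `check_keys`.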
Should we be more explicit with the JIT cache here and properly scope it to the model?
I do think that we need to scope these caches to the model. Users are allowed to run multiple models without terminating the process.
Are these caches scoped to the model in the current state of this PR? (I think this is critical to have -- ideally before landing this PR)
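One way to read the scoping request is to keep a separate kernel cache per model, so that a binary prepared under one model's assumptions is never reused for another. The sketch below is only an illustration of that idea; `ScopedKernelCache`, the scope strings, and the `prepare` callbacks are hypothetical and do not describe how this PR resolves the question.

```python
from collections import defaultdict
from typing import Callable


class ScopedKernelCache:
    """Keeps one kernel cache per scope, for example per loaded model."""

    def __init__(self) -> None:
        # scope -> (cache key -> prepared kernel)
        self._caches: dict[str, dict[str, object]] = defaultdict(dict)

    def get_or_prepare(self, scope: str, key: str,
                       prepare: Callable[[], object]) -> object:
        cache = self._caches[scope]
        if key not in cache:
            # First launch for this (scope, key): pay the full JIT cost once.
            cache[key] = prepare()
        return cache[key]

    def drop_scope(self, scope: str) -> None:
        """Forget all prepared kernels of a scope, e.g. when a model is unloaded."""
        self._caches.pop(scope, None)


compiles = []  # records which scope triggered a (simulated) compilation


def prepare_for(model: str) -> Callable[[], str]:
    def prepare() -> str:
        compiles.append(model)
        return f"binary-for-{model}"
    return prepare


cache = ScopedKernelCache()
key = "SLIDING_WINDOW=0"
first = cache.get_or_prepare("model-a", key, prepare_for("model-a"))
again = cache.get_or_prepare("model-a", key, prepare_for("model-a"))
other = cache.get_or_prepare("model-b", key, prepare_for("model-b"))
```

Here the second lookup for `model-a` reuses the prepared entry, while `model-b` gets its own entry even though the key string is identical.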
    self.non_const_vals_lst = []
    self.update_args_index = {}
    for i, arg_n in enumerate(self.non_const_arg_names):
        if arg_n in update_only_arg_names:
I think some comments would be good here. Something like
            # If the argument can change each time the kernel is called, store a dummy value 
            # that will be set each time __call__ is called
            if arg_n in update_only_arg_names:
                self.update_args_index[arg_n] = i
                self.non_const_vals_lst.append("dummy_value")
            # else the argument is assumed to be constant and we can just store its initial value 
            else:
                self.non_const_vals_lst.append(assume_const_vals_dict[arg_n])
thanks, I added something similar.
Is the main value of the non_const_vals_lst that it lets us do fewer copies in the call function? It doesn't affect the caching at all, right?
that is correct, yes
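The role of `non_const_vals_lst` discussed in this thread can be sketched with a toy class. It reuses the names `non_const_vals_lst`, `update_args_index`, and `update_only_arg_names` from the diff, but `PreparedLaunch`, its constructor signature, and the example arguments are assumptions made for illustration; the real `PreparedKernel` also launches the compiled binary, which is omitted here.

```python
class PreparedLaunch:
    """Toy stand-in for a prepared kernel: a fixed argument list in which
    only a few slots are rewritten on each call."""

    def __init__(self, arg_names, first_call_kwargs, update_only_arg_names):
        self.non_const_vals_lst = []
        self.update_args_index = {}
        for i, arg_n in enumerate(arg_names):
            if arg_n in update_only_arg_names:
                # May change on every call: reserve the slot, fill it in __call__.
                self.update_args_index[arg_n] = i
                self.non_const_vals_lst.append(None)
            else:
                # Assumed constant: capture the value seen on the first call.
                self.non_const_vals_lst.append(first_call_kwargs[arg_n])

    def __call__(self, **kwargs):
        # Only the update-only slots are touched; the rest of the list is
        # reused as-is, which is where the saved per-launch work comes from.
        for arg_n, i in self.update_args_index.items():
            self.non_const_vals_lst[i] = kwargs[arg_n]
        return tuple(self.non_const_vals_lst)


# Hypothetical argument names and values.
names = ["query_ptr", "num_heads", "seq_len"]
first = {"query_ptr": 0x1000, "num_heads": 32, "seq_len": 17}
launch = PreparedLaunch(names, first,
                        update_only_arg_names={"query_ptr", "seq_len"})

# num_heads=8 is passed but ignored: it was captured as 32 on the first call.
args = launch(query_ptr=0x2000, num_heads=8, seq_len=42)
```

As confirmed in the reply above, this list does not influence which cache entry is selected; it only reduces how much is rebuilt per launch.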
This reverts commit 450770c.
Some updates on this PR:

From this PR, two todos remain:

Correctness:

Performance:

Tracing: When analyzing vLLM profiles, the jitcache functions as expected and reduces the Triton launch overhead from 148us to 26us:

H100: With PyTorch 2.7 and Triton 3.3, the performance of the Triton attention with this PR on an H100 drops to: with (so roughly only a 2% increase in total performance due to the jitcache).

MI300: With PyTorch 2.7 and Triton 3.3, the performance of the Triton attention with this PR on an H100 drops to: with
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Burkhard Ringlein <[email protected]> Signed-off-by: Thomas Parnell <[email protected]>
I worked a bit on this PR today to:

I don't have write perms on @bringlein's fork so I pushed the latest code to my fork here:

I've done quite a lot of benchmarking on H100 and unfortunately I'm not really seeing any significant performance improvement anymore. I believe the "big" performance improvement we were seeing in the past (with Triton 3.2) was an artifact of not handling the specialization correctly.

These are the results of running the ShareGPT serving benchmark on H100, repeated 10 times, with and without jitcache:

The only benefit seems to be that the performance is more "stable" (standard deviation is lower).
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

This pull request has merge conflicts that must be resolved before it can be


In this PR we improve the performance of the V1 triton_attn backend. We do not change the Triton kernels, but we reduce the well-known launch overhead of Triton kernels by caching the JIT artifacts of the decode kernel.
Performance
All of the results below are for meta-llama/Llama-3.1-8B-Instruct on an NVIDIA H100 GPU. All experiments are done with --no-enable-prefix-caching and current upstream:
with this PR
to compare, using the V1 FlashAttention backend:
So, this PR improves the performance of the V1 triton_attn backend by 27% and outperforms the FlashAttention-3 baseline by 5% for the serving benchmark.
We are in the process of evaluating the performance on AMD GPUs.
Correctness
Using the jitcache still produces correct results.
Using FlashAttention on H100 we see:
with this PR, we see
Details / How did we achieve this performance improvement
The launch overhead of Triton kernels is a well-known problem (see e.g. 1, 2, 3). Part of the launch overhead comes from the fact that the Triton JIT checks very carefully whether an existing binary is safe to use.
In many scenarios, these checks can be relaxed to cover only a subset of the parameters.
This PR adds such a cache with relaxed checks, implemented by jitcache. It is implemented as a decorator that can be used in front of the triton.jit decorator:
In short, the jitcache follows the steps of the Triton JIT compiler to produce a binary, but does this only once for each version (indicated by check_keys). For all subsequent invocations, only the non-constant arguments are updated (and copied to the GPU), skipping most parts of the compiler.
A detailed usage description can be found here.
This reduces the launch overhead of the paged_attention_2d kernel of the Triton backend from ~186us down to ~24us.
Discussion
We have added this new JIT cache to vllm/triton_utils/ because we expect that the vLLM community prefers to have it as an internal tool rather than as an external dependency. However, we also published it as part of our triton-dejavu framework, so vLLM could also import it from there, if this is preferred.
Ideally, something like the jitcache could be added as a feature to Triton itself, but we expect this to be a lengthier process. In either case, we would need to update the jitcache for every new Triton release.