[Hardware][TPU][V1] Better tpu multilora compilation #16989
base: main
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
Thanks Chengji for the introduction. Also thanks @jdefreitas02 for the detailed comments in the Pallas kernel.
I wonder what the motivation is for writing a Pallas kernel to implement multi-LoRA? It seems to me that a normal PyTorch implementation could achieve better performance. Not every case needs a kernel. Unless we want to manually fuse the LoRA kernel into the attention kernel, I don't see how the Pallas kernel would outperform, given that it cannot be naturally fused with other ops by XLA on TPU.
(cc: @yarongmu-google)
Hi @bythew3i, there were a few reasons for writing the kernel. In the original PR I started out with a PyTorch implementation like the one used for CPU, but it was extremely slow, which led me down this route.
I didn't look at the IR or HLO for it, but my guess is that the index_select operation causes a lot of data copies, which we're able to avoid in a kernel.
This kernel also has the LoRA laning feature, which allows us to pack multiple adapters into one TPU register, reducing the number of matrix multiplications we need to do by a large factor.
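To make the laning idea concrete, here is a rough sketch with made-up shapes; it is only an illustration of packing adapters into one matmul, not the actual register-level logic in the Pallas kernel:

import torch

# Illustrative only: pack several rank-r adapters so one matmul covers all of
# them, instead of doing one small matmul per adapter. Shapes are made up.
N, r, D, T = 4, 8, 256, 16            # num adapters, rank, hidden dim, tokens
loras = torch.randn(N, r, D)          # N separate [r, D] LoRA "A" matrices
x = torch.randn(T, D)

# Unpacked: N small matmuls, one per adapter.
unpacked = [x @ loras[i].T for i in range(N)]            # each [T, r]

# Packed ("laned"): one [N*r, D] matrix, a single matmul for all adapters.
packed_out = x @ loras.reshape(N * r, D).T               # [T, N*r]
assert torch.allclose(packed_out[:, :r], unpacked[0], atol=1e-4)

The real laning happens at the register level inside the kernel, so the mechanics differ, but the effect on the matmul count is the same idea.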
So @jdefreitas02 ran a benchmark comparing a PyTorch implementation against the kernels for Llama 3.1 8B, with 1 LoRA, 1024 input tokens, and 1024 output tokens. The results were:
PyTorch: 406 tok/s
Pallas: 1407 tok/s
Can you please share the baseline implementation?
Sure, the kernels are replaced with this function:

import torch

def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
    # Gather the LoRA weights for each token, dropping the extra unit dim if present.
    selected_loras = loras[idxs]
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(axis=1)
    T, L, D = selected_loras.shape
    # Batched matrix-vector product: [T, L, D] @ [T, D, 1] -> [T, L].
    return (selected_loras @ inputs.reshape((T, D, 1))).reshape((T, L))
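For a quick sanity check of how the shapes line up, here is an example call using the ref_bgmv above; the sizes are arbitrary, not the benchmark configuration:

import torch

T, N, L, D = 16, 4, 8, 1024     # tokens, adapters, LoRA output dim, hidden dim
inputs = torch.randn(T, D)
loras = torch.randn(N, 1, L, D)            # the unit dim gets squeezed inside
idxs = torch.randint(0, N, (T,))           # which adapter each token uses

out = ref_bgmv(inputs, loras, idxs)
print(out.shape)                           # torch.Size([16, 8])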
@jdefreitas02 do you still have the code where these are integrated?
Thanks! Can you also please share all the input shapes and dtypes that you used for benchmarking?
Both the input and output sizes were 1024 and the dtype was fp16. We also used NVIDIA's GenAI-Perf image.
Thanks for the info!
Now I see that the original TPU gather is too slow; nice optimization in your mask solution! I put my other comments in #15655 (comment), PTAL. Thanks!
Thanks for your contribution and the continuous improvements from the old PR!
I tested the test_lora.py LoRA test locally but got the following error after pre-compilation finished, do you know why?
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: open(/dev/vfio/0): Device or resource busy: Device or resource busy; Couldn't open iommu group /dev/vfio/0
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias = torch.where(indices[:, None] == -1, 0, bias)

nit: remove the empty line.
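As an aside on the snippet above, here is a minimal illustration (made-up shapes) of why the torch.where is needed: tokens whose LoRA index is -1, meaning no adapter, would otherwise pick up the last adapter's bias row via negative indexing, so they are zeroed explicitly.

import torch

bias = torch.arange(12, dtype=torch.float32).reshape(3, 4)   # 3 adapters, bias dim 4
indices = torch.tensor([0, -1, 2])                           # per-token adapter ids, -1 = no LoRA

gathered = bias[indices]                                      # index -1 silently selects the last row
masked = torch.where(indices[:, None] == -1, 0, gathered)     # so it is explicitly zeroed out
print(masked)   # the row for the -1 token is all zeros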
"cpu", | ||
long_lora_context, | ||
) | ||
self._token_lora_indices[:base_indices.shape[0]] = base_indices.to( |
Can we use self._token_lora_indices = base_indices.to(self.device) here? The underlying implementation is different from GPUs; it creates intermediate buffers.
Yes, I have now implemented this, and it removed a couple of subgraphs in the process.
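For reference, the two idioms being compared look roughly like this as a standalone sketch (placeholder shapes and a CPU device stand-in; the intermediate-buffer behaviour discussed above is specific to the XLA/TPU backend):

import torch

device = "cpu"   # stand-in for the TPU/XLA device
token_lora_indices = torch.zeros(64, dtype=torch.int32, device=device)
base_indices = torch.randint(0, 4, (16,), dtype=torch.int32)

# Before: in-place write into a slice of the persistent buffer.
token_lora_indices[:base_indices.shape[0]] = base_indices.to(device)

# After: rebind the name to the transferred tensor directly, which is what
# the suggestion above amounts to.
token_lora_indices = base_indices.to(device)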
self.lora_config.max_lora_rank = _get_padded_lora_rank(
    self.lora_config.max_lora_rank, self.lora_config.max_loras)

if self.lora_config is not None:
Seems redundant.
self.lora_config, self.device)
replace_set_lora(model)
punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper
if not self.enforce_eager:
We can mark_compiled even when enforce_eager is set.
Yep, I've changed this in the original PR #14238; I'm planning to merge it in once it's accepted.
vllm/v1/worker/tpu_model_runner.py
def get_input_embeddings(self, *args, **kwargs):
    return self.model.get_input_embeddings(*args, **kwargs)

def add_lora(self, lora_request: LoRARequest) -> bool:
Is this function still needed with the new set_lora function?
That looks like there's another process using your TPU; killing it should help.
@jdefreitas02 Do you have a list of the graphs we're compiling now?
There's no other process. I guess something is wrong in the program.
Hi @yaochengji, we're not able to reproduce your error on our end; maybe it's something with your environment?
But my other program looks good. NVM, we can focus on the first multi-LoRA PR first.
This PR improves LoRA compilation time and memory usage by splitting the large graph created by setting LoRAs into smaller sub-graphs. It also stops the recompilations caused by indexing multiple LoRAs.
It further optimises the work done in #15655.
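As background on the recompilation point: on XLA, shape-driven recompilations are generally avoided by keeping tensor shapes static, for example by padding the per-token LoRA indices to a fixed length. The sketch below only illustrates that general idea with made-up names and sizes; it is not the PR's actual code.

import torch

MAX_NUM_TOKENS = 16   # hypothetical fixed bucket size

def pad_lora_indices(token_lora_indices: torch.Tensor) -> torch.Tensor:
    # Pad to a fixed length so XLA always sees the same input shape and does
    # not recompile as the number of scheduled tokens varies. Padding slots
    # use -1, i.e. "no LoRA".
    padded = torch.full((MAX_NUM_TOKENS,), -1, dtype=token_lora_indices.dtype)
    padded[:token_lora_indices.shape[0]] = token_lora_indices
    return padded

print(pad_lora_indices(torch.tensor([0, 2, 1])).shape)   # always torch.Size([16])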