
Conversation

Contributor

@jdefreitas02 jdefreitas02 commented Apr 22, 2025

This PR improves LoRA compilation time and memory usage by splitting the large graph created by setting LoRAs into smaller sub-graphs. It also stops recompilations caused by indexing multiple LoRAs.

It further optimises the work done in #15655
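
(Editor's sketch, not the PR's actual code: a minimal torch.compile illustration of the two ideas above, namely compiling the LoRA-setting work as several small functions so each produces its own small graph, and marking the index tensor's token dimension as dynamic so a different number of indexed tokens does not force a recompile. The real change targets the TPU/XLA path; all names and shapes here are hypothetical.)

import torch

# Hypothetical per-layer update: compiling it separately keeps each call a
# small graph instead of one large graph that covers every layer at once.
@torch.compile
def set_layer_lora(dst_a, dst_b, src_a, src_b):
    dst_a.copy_(src_a)
    dst_b.copy_(src_b)

@torch.compile
def gather_loras(loras, idxs):
    return loras[idxs]

loras = torch.randn(4, 8, 16)              # [max_loras, rank, hidden]
idxs = torch.zeros(32, dtype=torch.long)   # token -> LoRA slot mapping
# Mark the token dimension as dynamic so a different number of tokens
# (or active LoRAs) reuses the same compiled graph instead of recompiling.
torch._dynamo.mark_dynamic(idxs, 0)
selected = gather_loras(loras, idxs)       # [32, rank, hidden]

set_layer_lora(torch.zeros(8, 16), torch.zeros(16, 8),
               torch.randn(8, 16), torch.randn(16, 8))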


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which exercises a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build, v1, and tpu (Related to Google TPUs) labels Apr 22, 2025

mergify bot commented Apr 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jdefreitas02.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 22, 2025
@jdefreitas02 jdefreitas02 changed the title Better tpu multilora compilation [Hardware][TPU][V1] Better tpu multilora compilation Apr 22, 2025
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator

@yaochengji yaochengji Apr 27, 2025

@bythew3i , could you help review the multi-lora kernels?

By way of introduction, @bythew3i is a Pallas and TPU expert and also the main author of the ragged paged attention kernel in vLLM.

Contributor

Thanks Chengji for the introduction. Also thanks @jdefreitas02 for the detailed comments in the Pallas kernel.

I wonder what the motivation is for writing a Pallas kernel to implement multi-LoRA? It seems to me that a normal PyTorch implementation can achieve better performance; not every case needs a kernel. Unless we want to manually fuse the LoRA kernel into the attention kernel, I do not see the Pallas kernel outperforming, given that it cannot be naturally fused with other ops by XLA on TPU.

(cc: @yarongmu-google )

Contributor

Hi @bythew3i, there were a few reasons for writing the kernel. In the original PR I started out with a PyTorch implementation similar to the one in the CPU backend, but it was extremely slow, which led me down this route.

I didn't look at the IR or HLO for it, but my guess is that the index_select operation causes a lot of data copies, which we're able to avoid in a kernel.

This kernel also has the LoRA laning feature, which allows us to pack multiple adapters into one TPU register, reducing the number of matrix multiplications we need to do by a large factor.
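
(Hedged, editor-added illustration of the laning idea at the shape level: several rank-r adapters are laid side by side so one wide matmul covers all of them instead of one narrow matmul per adapter. This is a paraphrase of the concept, not the Pallas kernel itself, and all shapes are made up.)

import torch

T, D, r, N = 16, 256, 8, 4            # tokens, hidden dim, LoRA rank, adapters
x = torch.randn(T, D)
loras = torch.randn(N, D, r)          # N separate shrink matrices

# Unpacked: one small [D, r] matmul per adapter, wasting most of the wide
# matrix unit when r is far smaller than the 128-lane register width.
per_adapter = [x @ loras[i] for i in range(N)]

# Packed ("laned"): concatenate the adapters along the last dim so a single
# [D, N*r] matmul computes all of them, then slice out each token's adapter.
packed = loras.permute(1, 0, 2).reshape(D, N * r)    # [D, N*r]
all_out = x @ packed                                  # [T, N*r]
idx = torch.randint(0, N, (T,))                       # adapter id per token
cols = idx[:, None] * r + torch.arange(r)             # that adapter's columns
out = torch.gather(all_out, 1, cols)                  # [T, r]
assert torch.allclose(out[0], per_adapter[idx[0]][0], atol=1e-4)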

Contributor

So @jdefreitas02 ran a benchmark comparing a PyTorch implementation against the kernels for Llama 3.1 8B, with 1 LoRA, 1024 input tokens, and 1024 output tokens. The results were:

PyTorch: 406 tok/s
Pallas: 1407 tok/s

Contributor

Can you please share the baseline implementation?

Contributor

Sure, the kernels are replaced with this function:

import torch

def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
    # Pick one LoRA matrix per token: shape [T, (1,) L, D]
    selected_loras = loras[idxs]
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)

    T, L, D = selected_loras.shape
    # Batched matrix-vector product: [T, L, D] @ [T, D, 1] -> [T, L]
    return (selected_loras @ inputs.reshape((T, D, 1))).reshape((T, L))

@jdefreitas02 do you still have the code where these are integrated?

Contributor

@bythew3i bythew3i May 8, 2025

Thanks! Can you also please share all the inputs' shape and dtype that you used for benchmarking?

Contributor Author

Both the input and output sizes were 1024 and the dtype was fp16. Also we used Nvidia's GenAI perf image.
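
(Editor's sketch tying the pieces together: a hypothetical call to the ref_bgmv baseline above with 1 LoRA and 1024 tokens as described. The rank and hidden size are assumptions, and float32 is used so the snippet runs anywhere; the actual benchmark ran in fp16 on TPU.)

import torch

T = 1024                   # tokens, as in the benchmark
D = 4096                   # hidden size (Llama 3.1 8B), assumed here
L = 8                      # LoRA rank, purely illustrative
N = 1                      # a single LoRA adapter

inputs = torch.randn(T, D)
loras = torch.randn(N, 1, L, D)            # the extra dim is squeezed inside
idxs = torch.zeros(T, dtype=torch.long)    # every token maps to adapter 0

out = ref_bgmv(inputs, loras, idxs)        # shape [T, L]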

Contributor

Thanks for the info!

Now I see that the original TPU gather is too slow, and the mask solution is a nice optimization! I put the other comments in #15655 (comment), PTAL. Thanks!
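
(Editor's illustration of the contrast being discussed: a gather formulation that materialises per-token copies of the selected adapters versus a one-hot-mask formulation that turns the selection into matmuls, which XLA on TPU generally handles better. This is a reading of the "mask solution" idea, not the kernel's actual implementation; shapes are made up.)

import torch
import torch.nn.functional as F

T, N, L, D = 8, 4, 16, 64            # tokens, adapters, rank, hidden
loras = torch.randn(N, L, D)
inputs = torch.randn(T, D)
idxs = torch.randint(0, N, (T,))

# Gather formulation: loras[idxs] copies a [L, D] slice per token.
gathered = torch.einsum('tld,td->tl', loras[idxs], inputs)

# Mask formulation: a one-hot [T, N] mask expresses the same selection as
# matmuls, with no per-token copy of the adapter weights.
mask = F.one_hot(idxs, N).to(inputs.dtype)
masked = torch.einsum('tn,nld,td->tl', mask, loras, inputs)

assert torch.allclose(gathered, masked, atol=1e-4)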

Collaborator

@yaochengji yaochengji left a comment

Thanks for your contribution and the continuous improvements from the old PR!

I tested the test_lora.py LoRA test locally, but got the following error after pre-compilation finished. Do you know why?

RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: open(/dev/vfio/0): Device or resource busy: Device or resource busy; Couldn't open iommu group /dev/vfio/0

bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias = torch.where(indices[:, None] == -1, 0, bias)

Collaborator

nit: remove the empty line.

"cpu",
long_lora_context,
)
self._token_lora_indices[:base_indices.shape[0]] = base_indices.to(
Collaborator

Can we use self._token_lora_indices = base_indices.to(self.device) here? The underlying implementation is different from the GPU one: the current slice assignment creates intermediate buffers.

Contributor Author

Yes, I have now implemented this and it removed a couple of subgraphs in the process.
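
(Editor's sketch of the two patterns being compared; the variable names mirror the snippet above but the setup is hypothetical.)

import torch

device = "cpu"    # stand-in; the discussion is about the TPU/XLA device
base_indices = torch.arange(5, dtype=torch.int32)
token_lora_indices = torch.zeros(16, dtype=torch.int32, device=device)

# Before: copy into a slice of a pre-allocated buffer. Per the review
# comment above, on the TPU backend this goes through intermediate buffers.
token_lora_indices[:base_indices.shape[0]] = base_indices.to(device)

# After: rebind to the transferred tensor directly, avoiding the
# intermediate copy (and, per the author, a couple of subgraphs).
token_lora_indices = base_indices.to(device)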

self.lora_config.max_lora_rank = _get_padded_lora_rank(
self.lora_config.max_lora_rank, self.lora_config.max_loras)

if self.lora_config is not None:
Collaborator

Seems redundant.

self.lora_config, self.device)
replace_set_lora(model)
punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper
if not self.enforce_eager:
Collaborator

We can call mark_compiled even when enforce_eager is set.

Contributor

Yep, I've changed this in the original PR #14238; I'm planning on merging it once it's accepted.

def get_input_embeddings(self, *args, **kwargs):
return self.model.get_input_embeddings(*args, **kwargs)

def add_lora(self, lora_request: LoRARequest) -> bool:
Contributor

Is this function still needed with the new set_lora function?

@Akshat-Tripathi
Contributor

I tested the test_lora.py LoRA test locally, but got the following error after pre-compilation finished. Do you know why?

RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: open(/dev/vfio/0): Device or resource busy: Device or resource busy; Couldn't open iommu group /dev/vfio/0

That looks like another process is using your TPU; killing it should help.

@Akshat-Tripathi
Contributor

@jdefreitas02 Do you have a list of the graphs we're compiling now?

@jdefreitas02
Contributor Author

@jdefreitas02 Do you have a list of the graphs we're compiling now?

| Stage | # XLA graphs |
| --- | --- |
| backbone | 14 |
| tpu_set_lora | 6 |
| select_hidden_states | 4 |
| sample_from_hidden | 10 |

@yaochengji
Collaborator

That looks like another process is using your TPU; killing it should help.

There's no other process. I guess something is wrong in the program.

Jorge de Freitas added 2 commits April 29, 2025 13:47
@Akshat-Tripathi
Contributor

There's no other process. I guess something is wrong in the program.

Hi @yaochengji, we're not able to reproduce your error on our end; maybe it's something in your environment?

@yaochengji
Collaborator

Hi @yaochengji, we're not able to reproduce your error on our end; maybe it's something in your environment?

But my other program runs fine. NVM, we can focus on the first multi-LoRA PR for now.
