Merged
268 commits
1dbfcd9
Fixed lora_input preparation for actual execution
Akshat-Tripathi Jan 23, 2025
1bb2578
Fixed wrong model bug
Akshat-Tripathi Jan 24, 2025
ddc4cbc
Moved if statements outside of for loops in PunicaWrapperTPU
Akshat-Tripathi Jan 24, 2025
48a6944
Added early exits to PunicaWrapperTPU lora functions
Akshat-Tripathi Jan 28, 2025
7802e84
Added torch ops for tpu (Static prefill sizes)
Akshat-Tripathi Jan 30, 2025
ab5396b
XLA bgmv operations are now imported from the default torch_ops
Akshat-Tripathi Jan 30, 2025
fdf29d3
Removed TODOs
Akshat-Tripathi Jan 31, 2025
c2b4139
Removed old code
Akshat-Tripathi Jan 31, 2025
f31b7d1
Linting
Akshat-Tripathi Jan 31, 2025
87ff73e
Fixed import error
Akshat-Tripathi Feb 3, 2025
96c3dde
lint
Akshat-Tripathi Feb 4, 2025
4e72ede
Abstracted out infinity values
Akshat-Tripathi Mar 3, 2025
e4d35ce
Moved and modified bgmv ops from the cpu backend to the tpu backend, …
Akshat-Tripathi Feb 7, 2025
3cf0680
Removed total_size for linting
Akshat-Tripathi Feb 7, 2025
a8ab0c9
Reverted changes to torch_ops
Akshat-Tripathi Feb 7, 2025
d73f1ce
Lint
Akshat-Tripathi Feb 7, 2025
e01d9a4
Replaced in-place buffer updates with direct returning
Akshat-Tripathi Mar 3, 2025
0c1bfb9
PunicaWrapperTPU now returns unchanged buffer if no loras are needed
Akshat-Tripathi Feb 11, 2025
46ce7fa
Simplified TPU prefill
Akshat-Tripathi Feb 12, 2025
5d0cc37
Removed sgmv kernels from TPU implementation
Akshat-Tripathi Feb 12, 2025
7590b0e
Fix bug
Akshat-Tripathi Feb 12, 2025
e7f75b5
Added torch.compiles to PunicaWrapperTPU functions
Akshat-Tripathi Feb 12, 2025
fe193f7
Replaced "x[x==-1] = y" with "x = torch.where(x == - 1, y)"
Akshat-Tripathi Feb 14, 2025
52e3911
Revert "Added torch.compiles to PunicaWrapperTPU functions"
Akshat-Tripathi Feb 14, 2025
33a70b0
Fix linting
Akshat-Tripathi Feb 14, 2025
67446b2
Added lora hotswapping test
Akshat-Tripathi Feb 18, 2025
0db19b1
Fixed hotswapping test prompt
Akshat-Tripathi Feb 18, 2025
a4c3b0a
Fixed bug in tpu lora test
Akshat-Tripathi Feb 18, 2025
9d6c388
Merged set_no_lora() functionality with _update_prefill_metadata
Akshat-Tripathi Feb 14, 2025
2a9978e
Added Multi-LoRA functionality to TPU V1
Akshat-Tripathi Feb 14, 2025
b8c65bc
Added test that verifies switching
Akshat-Tripathi Feb 17, 2025
942ef07
Added bgmv kernel test code
Akshat-Tripathi Feb 4, 2025
56529b9
Added some dynamic lora selection
Akshat-Tripathi Feb 6, 2025
735073f
Moved and modified bgmv ops from the cpu backend to the tpu backend, …
Akshat-Tripathi Feb 7, 2025
1067b50
Added bgmv kernel test
Akshat-Tripathi Feb 10, 2025
d897f87
Made bgmv kernel fully functional (WIP on supporting smaller ranks) (…
Akshat-Tripathi Feb 10, 2025
d6eca29
Updated bgmv_kernel to work with ranks that aren't exact multiples of…
Akshat-Tripathi Feb 17, 2025
d97aae5
Removed interpreted mode on kernel
Akshat-Tripathi Feb 18, 2025
3ac0f63
Added pallas kernel benchmarking script
Akshat-Tripathi Feb 18, 2025
a620e58
Fixed mosaic kernel compilation issue
Akshat-Tripathi Feb 24, 2025
00d6dfd
Added reference kernel benchmarking
Akshat-Tripathi Feb 24, 2025
fb0601d
Registered the custom op
Akshat-Tripathi Feb 24, 2025
89b062e
Integrated bgmv kernel
Akshat-Tripathi Feb 24, 2025
ef2ef8c
Fixed model compilation bugs
Akshat-Tripathi Feb 24, 2025
a79e19d
Minor changes
Akshat-Tripathi Feb 25, 2025
cc8cdf6
Removed scratch files
Akshat-Tripathi Mar 4, 2025
ad8c565
Minor pallas kernel fixes
Akshat-Tripathi Mar 5, 2025
8d83065
integrate ragged paged attn v2
yaochengji Mar 3, 2025
dea7d02
fix precompile
yaochengji Mar 5, 2025
0cf0eaa
Merge branch 'chengji/ragged_attn_v2_new' into multi_lora_tpu_v1
Akshat-Tripathi Mar 6, 2025
6249307
Fixed padding issue with v1
Akshat-Tripathi Mar 6, 2025
af0a6a9
Added temporary patch over pallas kernel routing bug
Akshat-Tripathi Mar 6, 2025
264d36a
Updated kernel test
Akshat-Tripathi Mar 6, 2025
b725c6a
Lint
Akshat-Tripathi Mar 6, 2025
038465c
Removed duplicate method
Akshat-Tripathi Mar 6, 2025
2004369
Lint
Akshat-Tripathi Mar 6, 2025
71a1cdd
More linting
Akshat-Tripathi Mar 6, 2025
3dba9e0
Linting
Akshat-Tripathi Mar 6, 2025
f7f95e4
Lint
Akshat-Tripathi Mar 6, 2025
adfdcdb
Fixed bug related to consecutive pallas kernels
Akshat-Tripathi Mar 6, 2025
a6d5c01
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 7, 2025
5a27785
Removed v0 TPU LoRA implementation
Akshat-Tripathi Mar 7, 2025
5d15fbc
Fixed VocabParallelEmbeddingWithLoRA compilation error
Akshat-Tripathi Mar 8, 2025
ca3d810
Fixed LogitsProcessorWithLoRA layer compilation issue
Akshat-Tripathi Mar 10, 2025
12f71ce
Slightly sped up the kernel
Akshat-Tripathi Mar 10, 2025
d040ee8
Lint
Akshat-Tripathi Mar 10, 2025
e696144
Fixed bug with higher batch sizes
Akshat-Tripathi Mar 10, 2025
d110613
Lint
Akshat-Tripathi Mar 10, 2025
f8d5da2
Removed TODO in bgmv pallas test
Akshat-Tripathi Mar 11, 2025
d114377
Fixed PunicaWrapperBase typing
Akshat-Tripathi Mar 11, 2025
430bae9
Fixed bug where vLLM crashes on decode
Akshat-Tripathi Mar 11, 2025
fb36fd6
Fixed NaN bug with LogitsProcessor
Akshat-Tripathi Mar 11, 2025
c454062
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 12, 2025
23b14d1
Updated LoRALogitsProcessor to work with the TPU
Akshat-Tripathi Mar 12, 2025
27d6f70
Lint
Akshat-Tripathi Mar 12, 2025
b547271
Fixed batched logits processing
Akshat-Tripathi Mar 12, 2025
f5138b8
Updated kernel test
Akshat-Tripathi Mar 12, 2025
ad14872
Added kernel benchmark (dev only, remove later)
Akshat-Tripathi Mar 12, 2025
7418b5a
Tuned bgmv kernel block sizes
Akshat-Tripathi Mar 13, 2025
2aacb34
Improved lora output masking
Akshat-Tripathi Mar 13, 2025
6ee0b57
Skipped matmuls where no loras are needed
Akshat-Tripathi Mar 13, 2025
d9e415f
Renamed variables for better readability
Akshat-Tripathi Mar 17, 2025
460e808
Moved inner loop into grid spec
Akshat-Tripathi Mar 17, 2025
12ac3b8
Revert "Moved inner loop into grid spec"
Akshat-Tripathi Mar 17, 2025
1bb152f
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 18, 2025
af15bd1
Added comment
Akshat-Tripathi Mar 18, 2025
41555d1
Lint
Akshat-Tripathi Mar 18, 2025
4ac7aa9
Added a fused shrink/expand kernel
Akshat-Tripathi Mar 18, 2025
9f5a497
Revert "Added a fused shrink/expand kernel"
Akshat-Tripathi Mar 18, 2025
54344b7
Added some autotuning for kernels
Akshat-Tripathi Mar 18, 2025
c5a42e2
Renamed padding variables
Akshat-Tripathi Mar 18, 2025
e66067c
Used a static ones vector, gives a 5%ish perf boost
Akshat-Tripathi Mar 18, 2025
7c79683
Restricted block sizes to prevent memory from blowing up
Akshat-Tripathi Mar 19, 2025
d7338f8
Removed larger lora/dim block sizes since they reduce perf outside of…
Akshat-Tripathi Mar 19, 2025
2bb8868
Allowed smaller LoRA blocks if necessary
Akshat-Tripathi Mar 19, 2025
27ad793
Replaced torch.cat operations with F.pad
Akshat-Tripathi Mar 20, 2025
a82f3fe
Added fused lora transpose [experimental]
Akshat-Tripathi Mar 20, 2025
640420b
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 20, 2025
de6746a
Separated bgmv_shrink and bgmv_expand kernels to avoid unnecessary d…
Akshat-Tripathi Mar 20, 2025
19b9089
Removed redundant branch
Akshat-Tripathi Mar 20, 2025
a02d0e9
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 24, 2025
e07d6fb
Moved punica related `mark_dynamic` to the TPUModelRunner to allow th…
Akshat-Tripathi Mar 24, 2025
5b4ba1b
Moved `maybe_dummy_run_with_lora` to the `_dummy_run` method
Akshat-Tripathi Mar 24, 2025
efbdc62
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Mar 24, 2025
b64dc31
Lint
Akshat-Tripathi Mar 24, 2025
49a8102
Minor fixes + lint
Akshat-Tripathi Mar 24, 2025
c1be5f9
Lint
Akshat-Tripathi Mar 24, 2025
bf44d65
Fixed mark_dynamic placement for eager/compiled modes
Akshat-Tripathi Mar 25, 2025
15ff074
Fixed mark_dynamic placement for eager/compiled modes
Akshat-Tripathi Mar 25, 2025
d9f89b6
Temporary fix to LogitsProcessorWithLoRA pipeline bubble issue
Akshat-Tripathi Mar 25, 2025
81775d3
Sampler is now compiled with LoRA
Akshat-Tripathi Mar 25, 2025
ab036e0
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 26, 2025
2bc00b8
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi Mar 26, 2025
829028d
Removed early exits since they cause eager execution
Akshat-Tripathi Mar 26, 2025
b6af323
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 26, 2025
df69c52
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi Mar 26, 2025
5638e7d
Removed some recompilations when updating LoRA metadata
Akshat-Tripathi Mar 27, 2025
bae61a2
Aligned lora codepath with recompilation fixes
Akshat-Tripathi Mar 27, 2025
dc8b940
Disabled add_lora_logits temporarily
Akshat-Tripathi Mar 27, 2025
eb804a0
Added the LoRA Laning optimisation + tests + explanation
Akshat-Tripathi Mar 27, 2025
fbb902a
Updated kernel benchmarking script with lora laning
Akshat-Tripathi Mar 27, 2025
8ba2749
Added error for when someone tries to use LoRA adapters on the V0 TPU…
Akshat-Tripathi Mar 27, 2025
51d87a5
Added test to buildkite
Akshat-Tripathi Mar 27, 2025
bf52dbd
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 27, 2025
fce044a
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Mar 27, 2025
8b1dae8
Lint
Akshat-Tripathi Mar 27, 2025
aad109b
Optimised single lora kernels
Akshat-Tripathi Mar 27, 2025
b09d595
Fixed compilation bug
Akshat-Tripathi Mar 27, 2025
151fde4
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 31, 2025
a1df8c8
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Mar 31, 2025
72d95c6
Fixed LoRA Laning bug
Akshat-Tripathi Mar 31, 2025
be0915c
Fixed extra recompilations
Akshat-Tripathi Mar 31, 2025
478a8bb
Lint
Akshat-Tripathi Mar 31, 2025
4178e58
Lint
Akshat-Tripathi Mar 31, 2025
8a3009d
Added type annotation to lora_output
Akshat-Tripathi Mar 31, 2025
493e73f
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Mar 31, 2025
1d6085a
Removed unused function/parameter
Akshat-Tripathi Mar 31, 2025
f208234
Removed redundant padding in kernel for larger lora/dim sizes
Akshat-Tripathi Apr 1, 2025
ec0e181
Moved xm.mark_step() calls to more appropriate places
Akshat-Tripathi Apr 1, 2025
38de473
Reduced number of graphs compiled
Akshat-Tripathi Apr 1, 2025
8dabfab
Fixed memory usage problem
Akshat-Tripathi Apr 1, 2025
f5949a7
Lint
Akshat-Tripathi Apr 1, 2025
a7ae288
Lint
Akshat-Tripathi Apr 1, 2025
2e67aa8
Removed first inference recompilation
Akshat-Tripathi Apr 1, 2025
27b3c52
Fixed more recompilations
Akshat-Tripathi Apr 1, 2025
d1452af
Added flag to disable add_lora_logits()
Akshat-Tripathi Apr 1, 2025
1cc89a5
Lint
Akshat-Tripathi Apr 1, 2025
93d3e8f
Fixed performance issue where the sampler would face long stalls
Akshat-Tripathi Apr 2, 2025
9fb50b9
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 2, 2025
592c62f
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 2, 2025
e1aaed6
Fixed laning integration bug
Akshat-Tripathi Apr 2, 2025
62500e1
Lint
Akshat-Tripathi Apr 2, 2025
eb72ab6
Removed LoRA vocab padding for TPU
Akshat-Tripathi Apr 4, 2025
49157b1
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 4, 2025
54db22d
Fixed 0 padding issue with LoRA
Akshat-Tripathi Apr 4, 2025
5232785
Changed TPU lora_vocab_padding_size to 1
Akshat-Tripathi Apr 4, 2025
1b4c2f2
Fixed bug in bgmv_expand kernel - outputs weren't being written with …
Akshat-Tripathi Apr 4, 2025
c8f68d7
Changed TPU lora_vocab_padding_size to 1
Akshat-Tripathi Apr 4, 2025
ed3b245
Enabled lora bias
Akshat-Tripathi Apr 4, 2025
4855791
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 7, 2025
9d35414
Replaced `enable_laning` flag with dim comparison
Akshat-Tripathi Apr 7, 2025
54c00c3
Enabled fully sharded loras
Akshat-Tripathi Apr 7, 2025
f3e48a6
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 7, 2025
a4b2e27
Removed test benchmarking file
Akshat-Tripathi Apr 7, 2025
9f0fdbe
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 7, 2025
fbddd3c
Refactored add_shrink to return a tensor not a tuple
Akshat-Tripathi Apr 8, 2025
2012bbd
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 9, 2025
d1c11c8
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 9, 2025
1803135
Removed tuple return in add_shrink()
Akshat-Tripathi Apr 9, 2025
0eeb72c
Removed extra compilation
Akshat-Tripathi Apr 9, 2025
c1be9fe
Replaced copies with buffer donation to reduce memory usage
Akshat-Tripathi Apr 9, 2025
de5da33
Added explicit compilation in add_lora
Akshat-Tripathi Apr 10, 2025
5adc67f
Removed LoRA ID collision
Akshat-Tripathi Apr 10, 2025
342ff8b
Fix pre-commit
Akshat-Tripathi Apr 10, 2025
fc65edb
Reduced number of iterations in test_lora
Akshat-Tripathi Apr 10, 2025
2f1da29
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 10, 2025
7daaafa
Lint
Akshat-Tripathi Apr 10, 2025
893ac04
Reduced pallas kernel test size
Akshat-Tripathi Apr 11, 2025
2a0fce7
Added/removed comments
Akshat-Tripathi Apr 11, 2025
4d42844
Fixed pallas kernel test
Akshat-Tripathi Apr 11, 2025
50a06fc
Made LoRA e2e test more robust
Akshat-Tripathi Apr 11, 2025
9b78b74
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 14, 2025
177fced
Added mark_steps to set_lora to break up large graphs
Akshat-Tripathi Apr 22, 2025
12ff364
Stopped index based recompilation for multi-lora
Akshat-Tripathi Apr 22, 2025
491578d
Restored original maybe_dummy_run_with_lora
Akshat-Tripathi Apr 22, 2025
5cb4724
Split up into lora setup and lora selection functions
Akshat-Tripathi Apr 22, 2025
ca68ce6
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 22, 2025
e91774a
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 22, 2025
f4be6cc
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 22, 2025
155c2ad
Merge branch 'multi_lora_tpu_v0' of https://github.com/krai/vllm into…
Akshat-Tripathi Apr 22, 2025
317a131
Removed mark_compiled from punica_tpu
Akshat-Tripathi Apr 22, 2025
cba8267
refactor mark step from layers into tpu model runner
xihajun Apr 22, 2025
b482ec8
Split TPU LoRA test into several smaller ones
Akshat-Tripathi Apr 22, 2025
2f26dd9
Fix lora spelling
Akshat-Tripathi Apr 22, 2025
d4b3707
refactor mark step from layers into tpu model runner
jdefreitas02 Apr 22, 2025
afd7690
Merge branch 'better_tpu_multilora_compilation' of https://github.com…
xihajun Apr 23, 2025
8ccbaa8
Added comment explaining how multi-lora test adapters were trained
Akshat-Tripathi Apr 24, 2025
d227381
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 24, 2025
b65f60e
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 25, 2025
8a45758
Moved TPU lora tests into tests/tpu/lora
Akshat-Tripathi Apr 25, 2025
987589a
Updated TPU tests
Akshat-Tripathi Apr 25, 2025
bc49d0f
Fixed tpu-test script
Akshat-Tripathi Apr 25, 2025
50e9738
Fixed pallas kernel dtype in test
Akshat-Tripathi Apr 25, 2025
ed1738a
mask creation moved outside of matmul loop
Apr 25, 2025
f81111e
Moved mask setup outside of lora running loop
Akshat-Tripathi Apr 28, 2025
4a07cf6
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 29, 2025
8cd5cb7
Disabled LoRA serving for now
Akshat-Tripathi Apr 29, 2025
94bc0e2
compile reset_lora function as separate graph
Apr 29, 2025
b0bfc7a
update base metadata padding moved to cpu
Apr 30, 2025
6282cd5
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 30, 2025
975935e
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 30, 2025
1846ef3
Temporarily disabled the TPU lora tests
Akshat-Tripathi Apr 30, 2025
a6791a2
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi Apr 30, 2025
a006f6b
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 1, 2025
539366c
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 1, 2025
966e800
fix size of sampler indices
May 1, 2025
d72a86b
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 6, 2025
139e9b1
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi May 6, 2025
22bafee
remove add lora
May 6, 2025
aff7414
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 6, 2025
dae6c40
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi May 6, 2025
e487ecb
Fixed incorrect torch.wheres
Akshat-Tripathi May 7, 2025
20c5981
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 7, 2025
df67053
Lint
Akshat-Tripathi May 7, 2025
f9f9011
Merge branch 'multi_lora_tpu_v1' into tpu_bgmv_optimisation
Akshat-Tripathi May 7, 2025
8f5c64d
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 7, 2025
9f249bf
Updated comment + variable name fixes
Akshat-Tripathi May 8, 2025
0fe1866
Removed unused function
Akshat-Tripathi May 8, 2025
1e25225
Fixed spelling inconsistency
Akshat-Tripathi May 8, 2025
a5aba61
Removed redundant `mark_step`
Akshat-Tripathi May 8, 2025
bae5615
Merge branch 'better_tpu_multilora_compilation' into tpu_bgmv_optimis…
Akshat-Tripathi May 15, 2025
19dc26f
Removed unlowered ops when resetting loras
Akshat-Tripathi May 15, 2025
fff8b03
Fixed extra recompilation due to incorrect vocab size
Akshat-Tripathi May 15, 2025
d9aa141
Removed recompilations in logprobs gathering
Akshat-Tripathi May 15, 2025
f5b72ce
Replaced the pallas kernel with einsums
Akshat-Tripathi May 15, 2025
17fcaef
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 15, 2025
cb18346
Reenabled TPU LoRA test
Akshat-Tripathi May 15, 2025
dd448d4
Fixed precommit
Akshat-Tripathi May 15, 2025
69b5a33
Switched jax implementation to use a custom op
Akshat-Tripathi May 16, 2025
71624c6
Fixed bug when setting LoRAs. This introduces recompilations
Akshat-Tripathi May 19, 2025
e6dfd00
Removed extra `add_lora` function
Akshat-Tripathi May 19, 2025
a242f31
Lint
Akshat-Tripathi May 19, 2025
6770adf
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 20, 2025
820d9f6
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 21, 2025
376b1a1
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 27, 2025
c575c0e
Addressed PR comments
Akshat-Tripathi May 27, 2025
53ab8c4
Moved LoRA recompilation check to the test script
Akshat-Tripathi May 27, 2025
6edf6fb
Lint
Akshat-Tripathi May 27, 2025
8a87529
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 28, 2025
a3dcb2c
Merge branch 'main' into tpu_bgmv_optimisation
Akshat-Tripathi May 28, 2025
6 changes: 2 additions & 4 deletions .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -122,10 +122,8 @@ run_and_track_test 11 "test_struct_output_generate.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
run_and_track_test 12 "test_moe_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"

# Disable the TPU LoRA tests until the feature is activated
# run_and_track_test 13 "test_lora (directory)" \
# "python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/"
run_and_track_test 13 "test_lora.py" \
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"

# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
73 changes: 0 additions & 73 deletions tests/tpu/lora/test_pallas_kernels.py

This file was deleted.

4 changes: 2 additions & 2 deletions vllm/lora/models.py
Expand Up @@ -200,7 +200,7 @@ def from_local_checkpoint(
weights_mapper: Optional[WeightsMapper] = None,
tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.

Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
@@ -620,7 +620,7 @@ def _match_target_modules(self, module_name: str):
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
140 changes: 89 additions & 51 deletions vllm/lora/ops/xla_ops/lora_ops.py
@@ -1,63 +1,99 @@
# SPDX-License-Identifier: Apache-2.0

import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
import torch_xla.core.xla_builder as xb
from torch.library import impl
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard

# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import

@jax.jit
def bgmv_jax(inputs, loras, idxs):
return jnp.einsum(
"td,tX,Xld->tl",
inputs,
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
loras,
)

def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):

XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")


@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)

jax_import_guard()
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))


@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape

return torch.empty((T, L), device=inputs.device)


def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape

lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape

output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]

lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""

outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)

limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1

outputs = torch.cat(
(outputs,
torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
device=outputs.device)),
dim=1)
if output_tensor.shape[1] > outputs.shape[1]:
outputs = F.pad(outputs,
(0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
Comment on lines +74 to +76

Collaborator: is this traced?

Contributor Author: It should be traced, I'm not sure if we actually hit this case though, it's based off the CPU implementation

if add_inputs:
return output_tensor + outputs[:limit, :]
return output_tensor + outputs[:limit, :output_tensor.shape[1]]
else:
return outputs[:limit, :]
return outputs[:limit, :output_tensor.shape[1]]
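The review thread above asks whether the padding branch is traced. A minimal sketch in plain PyTorch (shapes invented for illustration, not taken from this PR) of what that branch computes: both widths are concrete at trace time, so the Python `if` picks one side, and the pad amounts are plain ints, meaning the pad should lower to a static op under torch_xla rather than anything data-dependent.

import torch
import torch.nn.functional as F

# Hypothetical shapes: 4 tokens, a bgmv result with 8 columns, a 12-column target buffer.
outputs = torch.randn(4, 8)
output_tensor = torch.zeros(4, 12)

if output_tensor.shape[1] > outputs.shape[1]:
    # Pad only the right side of the last dimension so the widths line up;
    # the pad amounts are Python ints derived from static shapes.
    outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))

assert outputs.shape == output_tensor.shape  # torch.Size([4, 12])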


def bgmv_shrink(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
scaling (float, optional): Scalar multiplier applied to the output.
"""
@@ -66,39 +102,41 @@ def bgmv_shrink(inputs: torch.Tensor,
lora_indices_tensor)


def bgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape

lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape

output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]

lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)

outputs = torch.cat((
torch.zeros((n_tokens, slice_offset), device=outputs.device),
outputs,
torch.zeros(
(n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
device=outputs.device),
),
dim=1)
outputs = F.pad(
outputs,
(
slice_offset,
output_tensor.shape[1] - (slice_offset + slice_size),
0,
0,
),
)

if add_inputs:
return output_tensor + outputs
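To tie the new ops together, here is a rough, self-contained sketch of the math the shrink/expand pair implements. It is not code from this PR: the function name and the shapes are illustrative assumptions, and the plain einsums stand in for the registered custom op, whereas the real implementation routes the selection through torch.ops.xla.bgmv so torch_xla can dispatch it to JAX via xb.call_jax.

import torch
import torch.nn.functional as F

def lora_forward_sketch(x, lora_a, lora_b, lora_indices, base_out, scaling=1.0):
    """Per-token multi-LoRA update mirroring the bgmv_shrink -> bgmv_expand split.

    x:            [num_tokens, hidden_size]            activations
    lora_a:       [num_loras, lora_rank, hidden_size]  shrink ("A") weights
    lora_b:       [num_loras, out_size, lora_rank]     expand ("B") weights
    lora_indices: [num_tokens]                         adapter index per token
    base_out:     [num_tokens, out_size]               output of the base linear layer
    """
    # One-hot selection of each token's adapter, as in the bgmv_jax einsum above.
    # (lora_indices is assumed to be a LongTensor of valid indices; this sketch
    # does not handle tokens with no active adapter.)
    one_hot = F.one_hot(lora_indices, num_classes=lora_a.shape[0]).to(x.dtype)  # [T, L]

    # Shrink: project activations down to the LoRA rank (bgmv_shrink, with scaling).
    shrunk = torch.einsum("td,tl,lrd->tr", x, one_hot, lora_a) * scaling        # [T, r]

    # Expand: project back up and accumulate into the base layer's output
    # (the add_inputs=True path of bgmv_expand).
    expanded = torch.einsum("tr,tl,lor->to", shrunk, one_hot, lora_b)           # [T, o]
    return base_out + expanded

In the PR itself the selection einsum lives behind the XLA_LIB-registered "bgmv" op; the CompositeExplicitAutograd registration that returns an empty [T, L] tensor appears to exist so that non-XLA tracing still sees a correctly shaped placeholder.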