Commit 9875be6

[LoRA][2/2]Remove LoRA extra vocab (#28545)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent df44df0 commit 9875be6

28 files changed: +133, -528 lines

tests/lora/conftest.py

Lines changed: 10 additions & 0 deletions
@@ -250,6 +250,16 @@ def olmoe_lora_files():
     return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")


+@pytest.fixture(scope="session")
+def qwen3_lora_files():
+    return snapshot_download(repo_id="charent/self_cognition_Alice")
+
+
+@pytest.fixture(scope="session")
+def llama32_lora_files():
+    return snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
+
+
 @pytest.fixture
 def reset_default_device():
     """

tests/lora/test_layers.py

Lines changed: 13 additions & 176 deletions
@@ -136,7 +136,6 @@ def populate_loras(
     id_to_index: list[int | None],
     layer: BaseLayerWithLoRA,
     layer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora layers with lora weights.
@@ -148,8 +147,6 @@ def populate_loras(
         layer: the LoRAlayer to populate.
         layer_weights: the PyTorch tensor containing the layer's
            weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.
@@ -171,7 +168,6 @@ def populate_loras(
                sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                    module_name=f"fake_{i}",
                    weight=layer_weights,
-                   generate_embeddings_tensor=generate_embeddings_tensor,
                )
                sublora.lora_b = sublora.lora_b[
                    (sublora_len * i) : (sublora_len * (i + 1)), :
@@ -185,7 +181,6 @@ def populate_loras(
                slot_idx,
                lora_a=lora.lora_a,
                lora_b=lora.lora_b,
-               embeddings_tensor=lora.embeddings_tensor,
            )

            lora_dict[lora_id] = lora
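With per-LoRA embeddings tensors gone, populate_loras loses its generate_embeddings_tensor parameter and the layer's set_lora call no longer receives an embeddings_tensor. A minimal sketch of the reduced call shape, assuming a layer and weight tensor already built by the surrounding test (lora_layer and layer_weights below are placeholders, not new helpers):

lora_dict, sublora_dict = populate_loras(
    id_to_index,
    layer=lora_layer,             # placeholder: the LoRA-wrapped layer under test
    layer_weights=layer_weights,  # placeholder: the base layer's weight tensor
)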
@@ -306,7 +301,6 @@ def create_random_embedding_layer():
            id_to_index,
            max_loras,
            vocab_size,
-           lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))
@@ -344,7 +338,6 @@ def create_random_embedding_layer():
            id_to_index,
            max_loras,
            vocab_size,
-           lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))
@@ -354,149 +347,6 @@ def create_random_embedding_layer():
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


-@torch.inference_mode()
-# @pytest.mark.skip(
-#     reason="Fails when loras are in any slot other than the first.")
-@pytest.mark.parametrize("num_loras", [1, 2, 4])
-@pytest.mark.parametrize("device", DEVICES)
-@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
-@pytest.mark.parametrize("stage", STAGES)
-def test_embeddings_with_new_embeddings(
-    dist_init, num_loras, device, vocab_size, stage
-) -> None:
-    if current_platform.is_cuda_alike():
-        torch.cuda.set_device(device)
-
-    torch.set_default_device(device)
-    max_loras = 8
-    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
-    assert check_punica_wrapper(punica_wrapper)
-    lora_config = LoRAConfig(
-        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
-    )
-
-    def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(vocab_size, 256)
-        embedding_data = torch.rand_like(embedding.weight.data)
-        embedding.weight.data = embedding_data
-        embedding.weight.data[vocab_size:, :] = 0
-        expanded_embedding = VocabParallelEmbedding(
-            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
-            256,
-            org_num_embeddings=vocab_size,
-        )
-        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
-        # We need to deepcopy the embedding as it will be modified
-        # in place
-        lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
-        lora_embedding.create_lora_weights(max_loras, lora_config)
-
-        return expanded_embedding, lora_embedding
-
-    for i in range(NUM_RANDOM_SEEDS):
-        set_random_seed(i)
-
-        id_to_index = get_random_id_to_index(num_loras, max_loras)
-        expanded_embedding, lora_embedding = create_random_embedding_layer()
-        lora_dict, _ = populate_loras(
-            id_to_index,
-            layer=lora_embedding,
-            layer_weights=torch.zeros(
-                (256, vocab_size + lora_config.lora_extra_vocab_size)
-            ),
-            generate_embeddings_tensor=256,
-        )
-
-        lora_embedding.set_mapping(punica_wrapper)
-        # All embeddings tensors have the same shape.
-        embeddings_tensors = [
-            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
-        ]
-        embeddings_tensor_len = embeddings_tensors[0].shape[0]
-
-        # Add empty embeddings_tensors for unoccupied lora slots.
-        for _ in range(max_loras - len(embeddings_tensors)):
-            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
-
-        inputs, index_mapping, prompt_mapping = create_random_inputs(
-            active_lora_ids=list(lora_dict.keys()),
-            num_inputs=num_loras * 3,
-            input_size=(200,),
-            input_range=(1, vocab_size),
-            device=device,
-        )
-        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            id_to_index,
-            max_loras,
-            vocab_size,
-            lora_config.lora_extra_vocab_size,
-        )
-        original_inputs = deepcopy(inputs)
-
-        # Force some of the inputs to be in the extended embeddings range
-        # to guarantee that their behavior is tested.
-        for input_, original_input_, lora_id in zip(
-            inputs, original_inputs, prompt_mapping
-        ):
-            embedding_id = lora_id - 1
-            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
-            original_input_[-1] = vocab_size
-            input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
-            original_input_[-2] = vocab_size + embeddings_tensor_len - 1
-
-        expanded_embedding.weight[
-            vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
-        ] = torch.cat(embeddings_tensors)
-
-        lora_result = lora_embedding(torch.cat(original_inputs))
-
-        expected_results: list[torch.Tensor] = []
-        for input_, original_input_, lora_id in zip(
-            inputs, original_inputs, prompt_mapping
-        ):
-            lora = lora_dict[lora_id]
-            result = expanded_embedding(input_)
-            after_a = F.embedding(
-                original_input_,
-                lora.lora_a.T,
-            )
-            result += after_a @ lora.lora_b.T
-            expected_results.append(result)
-        expected_result = torch.cat(expected_results)
-
-        rtol, atol = TOLERANCES[lora_result.dtype]
-        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
-        # Check that resetting the lora weights succeeds
-
-        for slot_idx in range(max_loras):
-            lora_embedding.reset_lora(slot_idx)
-
-        inputs, index_mapping, prompt_mapping = create_random_inputs(
-            active_lora_ids=[0],
-            num_inputs=num_loras * 3,
-            input_size=(200,),
-            input_range=(1, vocab_size),
-            device=device,
-        )
-        original_inputs = deepcopy(inputs)
-        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            id_to_index,
-            max_loras,
-            vocab_size,
-            lora_config.lora_extra_vocab_size,
-        )
-        lora_result = lora_embedding(torch.cat(original_inputs))
-        expected_result = expanded_embedding(torch.cat(inputs))
-
-        rtol, atol = TOLERANCES[lora_result.dtype]
-        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
-
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4])
 @pytest.mark.parametrize("device", DEVICES)
@@ -518,16 +368,13 @@ def test_lm_head_logits_processor(

    def _pretest():
        linear = ParallelLMHead(
-           vocab_size + lora_config.lora_extra_vocab_size,
-           1024,
-           vocab_size,
+           num_embeddings=vocab_size,
+           embedding_dim=1024,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        linear.weight.data[:, vocab_size:] = 0
-       logits_processor = LogitsProcessor(
-           vocab_size + lora_config.lora_extra_vocab_size, vocab_size
-       )
+       logits_processor = LogitsProcessor(vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
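After this hunk, the test LM head and logits processor are sized by the base vocabulary alone; previously both were padded with lora_config.lora_extra_vocab_size. A before/after sketch, limited to the constructor arguments visible in this diff:

# before: vocabulary padded with the per-LoRA extra tokens
# linear = ParallelLMHead(
#     vocab_size + lora_config.lora_extra_vocab_size, 1024, vocab_size,
#     params_dtype=torch.float16,
# )
# logits_processor = LogitsProcessor(
#     vocab_size + lora_config.lora_extra_vocab_size, vocab_size
# )

# after: base vocabulary only
linear = ParallelLMHead(
    num_embeddings=vocab_size,
    embedding_dim=1024,
    params_dtype=torch.float16,
)
logits_processor = LogitsProcessor(vocab_size)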
@@ -541,15 +388,12 @@ def _pretest():
        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
-       # NOTE: all the generated loras share the same embeddings tensor.
+
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
-           generate_embeddings_tensor=1024,
        )
-       embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
-       embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
@@ -565,7 +409,6 @@ def _pretest():
            id_to_index,
            max_loras,
            vocab_size,
-           lora_config.lora_extra_vocab_size,
        )
        input_ = torch.rand(20, 1024)

@@ -575,23 +418,16 @@ def _pretest():

        original_lm_head = deepcopy(linear)

-       linear.weight[
-           logits_processor.org_vocab_size : logits_processor.org_vocab_size
-           + embeddings_tensor_len
-       ] = embeddings_tensor
-
-       logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(
                hidden_states=input_, lm_head=linear, embedding_bias=None
            )
-           result[:, vocab_size + embeddings_tensor_len :] = float("-inf")
+
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
-       logits_processor.org_vocab_size = vocab_size

        # Check that resetting the lora weights succeeds

@@ -612,7 +448,6 @@ def _pretest():
            id_to_index,
            max_loras,
            vocab_size,
-           lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
@@ -694,7 +529,6 @@ def create_random_linear_replicated_layer():
            id_to_index,
            max_loras,
            512,
-           lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
@@ -726,7 +560,10 @@ def create_random_linear_replicated_layer():
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
-           lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
+           lora_mapping,
+           id_to_index,
+           max_loras,
+           512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
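The same simplification applies to every update_metadata call in this file: the trailing extra-vocab-size argument is dropped, leaving the LoRA mapping, the slot-to-id index list, the number of slots, and the vocabulary size. A sketch of the new call shape as it appears in these linear-layer tests (variable names taken from the surrounding code, 512 being the vocab size these tests use):

punica_wrapper.update_metadata(
    lora_mapping,
    id_to_index,
    max_loras,
    512,  # vocab_size
)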
@@ -817,7 +654,6 @@ def create_random_linear_parallel_layer():
            id_to_index,
            max_loras,
            512,
-           lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
@@ -849,7 +685,10 @@ def create_random_linear_parallel_layer():
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
-           lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
+           lora_mapping,
+           id_to_index,
+           max_loras,
+           512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
@@ -963,7 +802,6 @@ class FakeConfig:
            id_to_index,
            max_loras,
            512,
-           lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
@@ -1000,7 +838,6 @@ class FakeConfig:
            id_to_index,
            max_loras,
            512,
-           lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
