@@ -136,7 +136,6 @@ def populate_loras(
     id_to_index: list[int | None],
     layer: BaseLayerWithLoRA,
     layer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora layers with lora weights.
@@ -148,8 +147,6 @@ def populate_loras(
         layer: the LoRAlayer to populate.
         layer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -171,7 +168,6 @@ def populate_loras(
                 sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=layer_weights,
-                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
                 sublora.lora_b = sublora.lora_b[
                     (sublora_len * i) : (sublora_len * (i + 1)), :
@@ -185,7 +181,6 @@ def populate_loras(
                 slot_idx,
                 lora_a=lora.lora_a,
                 lora_b=lora.lora_b,
-                embeddings_tensor=lora.embeddings_tensor,
             )
 
             lora_dict[lora_id] = lora
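
Note on the `populate_loras` change above: with `generate_embeddings_tensor` gone, the helper is driven only by the slot mapping, the LoRA layer, its reference weights, and (for packed layers) `repeats`. A minimal usage sketch of the trimmed call shape, following the call sites later in this diff; `lora_layer` and `base_weights` are placeholder names, not identifiers from the file:

```python
# Sketch only: the populate_loras() call shape after this change.
# lora_layer is any BaseLayerWithLoRA under test; base_weights is the
# torch.Tensor whose shape and device seed the random LoRA weights.
lora_dict, _ = populate_loras(
    id_to_index,
    layer=lora_layer,
    layer_weights=base_weights,
    repeats=1,  # set >1 only for column-parallel packed layers
)
```
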
@@ -306,7 +301,6 @@ def create_random_embedding_layer():
             id_to_index,
             max_loras,
             vocab_size,
-            lora_config.lora_extra_vocab_size,
         )
 
         lora_result = lora_embedding(torch.cat(inputs))
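
The same edit repeats at every `punica_wrapper.update_metadata(...)` call site below: the trailing `lora_config.lora_extra_vocab_size` argument is dropped, leaving four positional arguments. A sketch of the resulting call, using values from these tests; the inline comments are explanatory assumptions, not documentation of the wrapper:

```python
# Sketch only: update_metadata() call shape after dropping the
# extra-vocab-size argument (all names come from the surrounding tests).
punica_wrapper.update_metadata(
    lora_mapping,  # LoRAMapping built from index_mapping / prompt_mapping
    id_to_index,   # slot index -> active LoRA id (None for free slots)
    max_loras,     # number of LoRA slots configured for the wrapper
    vocab_size,    # vocab size of the layer under test
)
```
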
@@ -344,7 +338,6 @@ def create_random_embedding_layer():
             id_to_index,
             max_loras,
             vocab_size,
-            lora_config.lora_extra_vocab_size,
         )
 
         lora_result = lora_embedding(torch.cat(inputs))
@@ -354,149 +347,6 @@ def create_random_embedding_layer():
         torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
 
 
-@torch.inference_mode()
-# @pytest.mark.skip(
-#     reason="Fails when loras are in any slot other than the first.")
-@pytest.mark.parametrize("num_loras", [1, 2, 4])
-@pytest.mark.parametrize("device", DEVICES)
-@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
-@pytest.mark.parametrize("stage", STAGES)
-def test_embeddings_with_new_embeddings(
-    dist_init, num_loras, device, vocab_size, stage
-) -> None:
-    if current_platform.is_cuda_alike():
-        torch.cuda.set_device(device)
-
-    torch.set_default_device(device)
-    max_loras = 8
-    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
-    assert check_punica_wrapper(punica_wrapper)
-    lora_config = LoRAConfig(
-        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
-    )
-
-    def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(vocab_size, 256)
-        embedding_data = torch.rand_like(embedding.weight.data)
-        embedding.weight.data = embedding_data
-        embedding.weight.data[vocab_size:, :] = 0
-        expanded_embedding = VocabParallelEmbedding(
-            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
-            256,
-            org_num_embeddings=vocab_size,
-        )
-        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
-        # We need to deepcopy the embedding as it will be modified
-        # in place
-        lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
-        lora_embedding.create_lora_weights(max_loras, lora_config)
-
-        return expanded_embedding, lora_embedding
-
-    for i in range(NUM_RANDOM_SEEDS):
-        set_random_seed(i)
-
-        id_to_index = get_random_id_to_index(num_loras, max_loras)
-        expanded_embedding, lora_embedding = create_random_embedding_layer()
-        lora_dict, _ = populate_loras(
-            id_to_index,
-            layer=lora_embedding,
-            layer_weights=torch.zeros(
-                (256, vocab_size + lora_config.lora_extra_vocab_size)
-            ),
-            generate_embeddings_tensor=256,
-        )
-
-        lora_embedding.set_mapping(punica_wrapper)
-        # All embeddings tensors have the same shape.
-        embeddings_tensors = [
-            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
-        ]
-        embeddings_tensor_len = embeddings_tensors[0].shape[0]
-
-        # Add empty embeddings_tensors for unoccupied lora slots.
-        for _ in range(max_loras - len(embeddings_tensors)):
-            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
-
-        inputs, index_mapping, prompt_mapping = create_random_inputs(
-            active_lora_ids=list(lora_dict.keys()),
-            num_inputs=num_loras * 3,
-            input_size=(200,),
-            input_range=(1, vocab_size),
-            device=device,
-        )
-        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            id_to_index,
-            max_loras,
-            vocab_size,
-            lora_config.lora_extra_vocab_size,
-        )
-        original_inputs = deepcopy(inputs)
-
-        # Force some of the inputs to be in the extended embeddings range
-        # to guarantee that their behavior is tested.
-        for input_, original_input_, lora_id in zip(
-            inputs, original_inputs, prompt_mapping
-        ):
-            embedding_id = lora_id - 1
-            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
-            original_input_[-1] = vocab_size
-            input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
-            original_input_[-2] = vocab_size + embeddings_tensor_len - 1
-
-        expanded_embedding.weight[
-            vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
-        ] = torch.cat(embeddings_tensors)
-
-        lora_result = lora_embedding(torch.cat(original_inputs))
-
-        expected_results: list[torch.Tensor] = []
-        for input_, original_input_, lora_id in zip(
-            inputs, original_inputs, prompt_mapping
-        ):
-            lora = lora_dict[lora_id]
-            result = expanded_embedding(input_)
-            after_a = F.embedding(
-                original_input_,
-                lora.lora_a.T,
-            )
-            result += after_a @ lora.lora_b.T
-            expected_results.append(result)
-        expected_result = torch.cat(expected_results)
-
-        rtol, atol = TOLERANCES[lora_result.dtype]
-        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
-        # Check that resetting the lora weights succeeds
-
-        for slot_idx in range(max_loras):
-            lora_embedding.reset_lora(slot_idx)
-
-        inputs, index_mapping, prompt_mapping = create_random_inputs(
-            active_lora_ids=[0],
-            num_inputs=num_loras * 3,
-            input_size=(200,),
-            input_range=(1, vocab_size),
-            device=device,
-        )
-        original_inputs = deepcopy(inputs)
-        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            id_to_index,
-            max_loras,
-            vocab_size,
-            lora_config.lora_extra_vocab_size,
-        )
-        lora_result = lora_embedding(torch.cat(original_inputs))
-        expected_result = expanded_embedding(torch.cat(inputs))
-
-        rtol, atol = TOLERANCES[lora_result.dtype]
-        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
-
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4])
 @pytest.mark.parametrize("device", DEVICES)
@@ -518,16 +368,13 @@ def test_lm_head_logits_processor(
 
     def _pretest():
         linear = ParallelLMHead(
-            vocab_size + lora_config.lora_extra_vocab_size,
-            1024,
-            vocab_size,
+            num_embeddings=vocab_size,
+            embedding_dim=1024,
             params_dtype=torch.float16,
         )
         linear.weight.data = torch.rand_like(linear.weight.data)
         linear.weight.data[:, vocab_size:] = 0
-        logits_processor = LogitsProcessor(
-            vocab_size + lora_config.lora_extra_vocab_size, vocab_size
-        )
+        logits_processor = LogitsProcessor(vocab_size)
         lora_logits_processor = LogitsProcessorWithLoRA(
             logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
         )
@@ -541,15 +388,12 @@ def _pretest():
         id_to_index = get_random_id_to_index(num_loras, max_loras)
         linear, logits_processor, lora_logits_processor = _pretest()
         lora_logits_processor.set_mapping(punica_wrapper)
-        # NOTE: all the generated loras share the same embeddings tensor.
+
         lora_dict, _ = populate_loras(
             id_to_index,
             layer=lora_logits_processor,
             layer_weights=linear.weight,
-            generate_embeddings_tensor=1024,
         )
-        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
-        embeddings_tensor_len = embeddings_tensor.shape[0]
 
         inputs, index_mapping, prompt_mapping = create_random_inputs(
             active_lora_ids=list(lora_dict.keys()),
@@ -565,7 +409,6 @@ def _pretest():
             id_to_index,
             max_loras,
             vocab_size,
-            lora_config.lora_extra_vocab_size,
         )
 
         input_ = torch.rand(20, 1024)
@@ -575,23 +418,16 @@ def _pretest():
 
         original_lm_head = deepcopy(linear)
 
-        linear.weight[
-            logits_processor.org_vocab_size : logits_processor.org_vocab_size
-            + embeddings_tensor_len
-        ] = embeddings_tensor
-
-        logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
         expected_results: list[torch.Tensor] = []
         for input_, lora_id in zip(inputs, prompt_mapping):
             lora = lora_dict[lora_id]
             result = logits_processor._get_logits(
                 hidden_states=input_, lm_head=linear, embedding_bias=None
             )
-            result[:, vocab_size + embeddings_tensor_len :] = float("-inf")
+
             result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
-        logits_processor.org_vocab_size = vocab_size
 
         # Check that resetting the lora weights succeeds
 
@@ -612,7 +448,6 @@ def _pretest():
             id_to_index,
             max_loras,
             vocab_size,
-            lora_config.lora_extra_vocab_size,
         )
 
         lora_result = lora_logits_processor._get_logits(
@@ -694,7 +529,6 @@ def create_random_linear_replicated_layer():
             id_to_index,
             max_loras,
             512,
-            lora_config.lora_extra_vocab_size,
         )
 
         lora_result = lora_linear(torch.cat(inputs))[0]
@@ -726,7 +560,10 @@ def create_random_linear_replicated_layer():
         lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
         punica_wrapper.update_metadata(
-            lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            512,
         )
 
         lora_result = lora_linear(torch.cat(inputs))[0]
@@ -817,7 +654,6 @@ def create_random_linear_parallel_layer():
             id_to_index,
             max_loras,
             512,
-            lora_config.lora_extra_vocab_size,
         )
 
         lora_result = lora_linear(torch.cat(inputs))[0]
@@ -849,7 +685,10 @@ def create_random_linear_parallel_layer():
         lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
         punica_wrapper.update_metadata(
-            lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            512,
         )
 
         lora_result = lora_linear(torch.cat(inputs))[0]
@@ -963,7 +802,6 @@ class FakeConfig:
             id_to_index,
             max_loras,
             512,
-            lora_config.lora_extra_vocab_size,
         )
 
         lora_result = lora_linear(torch.cat(inputs))[0]
@@ -1000,7 +838,6 @@ class FakeConfig:
             id_to_index,
             max_loras,
             512,
-            lora_config.lora_extra_vocab_size,
         )
 
         lora_result = lora_linear(torch.cat(inputs))[0]