@@ -102,9 +102,9 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
102102 offset = lora_b .shape [0 ] // 2
103103
104104 left_weight = lora_b [tp_rank * shard_size :(tp_rank + 1 ) *
105- shard_size ,:]
105+ shard_size , :]
106106 right_weight = lora_b [offset + tp_rank * shard_size :offset +
107- (tp_rank + 1 ) * shard_size ,:]
107+ (tp_rank + 1 ) * shard_size , :]
108108 lora_b = torch .cat ([left_weight , right_weight ], dim = 0 )
109109 # Applicable to cases where the base_layer is
110110 # ColumnParallelLinear.
@@ -113,7 +113,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
113113 shard_size = self .output_size
114114 start_idx = tensor_model_parallel_rank * shard_size
115115 end_idx = (tensor_model_parallel_rank + 1 ) * shard_size
116- lora_b = lora_b [start_idx :end_idx ,:]
116+ lora_b = lora_b [start_idx :end_idx , :]
117117 return lora_b
118118
119119 def slice_bias (self , bias : torch .Tensor ) -> torch .Tensor :
@@ -252,7 +252,7 @@ def slice_lora_b(
252252 zip (self .output_ids , self .output_slices )):
253253 if (lora_b_i := lora_b [i ]) is not None :
254254 sliced_lora_b [i ] = lora_b_i [shard_size * shard_id :shard_size *
255- (shard_id + 1 ),:]
255+ (shard_id + 1 ), :]
256256 return sliced_lora_b
257257
258258 def slice_bias (
@@ -346,15 +346,15 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
346346 self .kv_shard_id = tp_rank // self .base_layer .num_kv_head_replicas
347347 lora_b_q = lora_b [self .q_proj_shard_size *
348348 self .q_shard_id :self .q_proj_shard_size *
349- (self .q_shard_id + 1 ),:]
349+ (self .q_shard_id + 1 ), :]
350350 k_offset = self .q_proj_total_size
351351 lora_b_k = lora_b [k_offset +
352352 self .kv_proj_shard_size * self .kv_shard_id :k_offset +
353- self .kv_proj_shard_size * (self .kv_shard_id + 1 ),:]
353+ self .kv_proj_shard_size * (self .kv_shard_id + 1 ), :]
354354 v_offset = k_offset + self .kv_proj_total_size
355- lora_b_v = lora_b [ v_offset +
355+ lora_b_v = lora_b [v_offset +
356356 self .kv_proj_shard_size * self .kv_shard_id :v_offset +
357- self .kv_proj_shard_size * (self .kv_shard_id + 1 ),:]
357+ self .kv_proj_shard_size * (self .kv_shard_id + 1 ), :]
358358 lora_b = torch .cat ([lora_b_q , lora_b_k , lora_b_v ], dim = 0 )
359359 return lora_b
360360
@@ -464,7 +464,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
464464 tp_rank = get_tensor_model_parallel_rank ()
465465 shard_size = self .lora_a_stacked [0 ].shape [2 ]
466466 start_idx = tp_rank * shard_size
467- lora_a = lora_a [start_idx :start_idx + shard_size ,:]
467+ lora_a = lora_a [start_idx :start_idx + shard_size , :]
468468 return lora_a
469469
470470 def apply (self ,
@@ -507,10 +507,10 @@ def slice_lora_a(
507507 output_shard_size = self .lora_a_stacked [0 ].shape [2 ]
508508 output_start_idx = self .tp_rank * output_shard_size
509509 lora_a = [
510- lora_a [0 ][ output_start_idx :output_start_idx +
511- output_shard_size ,:] if lora_a [0 ] is not None else None ,
510+ lora_a [0 ][output_start_idx :output_start_idx +
511+ output_shard_size , :] if lora_a [0 ] is not None else None ,
512512 lora_a [1 ][output_start_idx :output_start_idx +
513- output_shard_size ,:] if lora_a [1 ] is not None else None ,
513+ output_shard_size , :] if lora_a [1 ] is not None else None ,
514514 ]
515515 return lora_a
516516
@@ -550,7 +550,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
550550 tp_rank = get_tensor_model_parallel_rank ()
551551 shard_size = self .lora_a_stacked [0 ].shape [2 ]
552552 start_idx = tp_rank * shard_size
553- lora_a = lora_a [start_idx :start_idx + shard_size ,:]
553+ lora_a = lora_a [start_idx :start_idx + shard_size , :]
554554 return lora_a
555555
556556 def apply (self ,
@@ -589,11 +589,11 @@ def slice_lora_a(
589589 start_idx = [self .tp_rank * shard_size [i ] for i in range (3 )]
590590 lora_a = [
591591 lora_a [0 ][start_idx [0 ]:start_idx [0 ] +
592- shard_size [0 ],:] if lora_a [0 ] is not None else None ,
592+ shard_size [0 ], :] if lora_a [0 ] is not None else None ,
593593 lora_a [1 ][start_idx [1 ]:start_idx [1 ] +
594- shard_size [1 ],:] if lora_a [1 ] is not None else None ,
594+ shard_size [1 ], :] if lora_a [1 ] is not None else None ,
595595 lora_a [2 ][start_idx [2 ]:start_idx [2 ] +
596- shard_size [2 ],:] if lora_a [2 ] is not None else None ,
596+ shard_size [2 ], :] if lora_a [2 ] is not None else None ,
597597 ]
598598 return lora_a
599599
0 commit comments