@@ -464,7 +464,11 @@ def to_device(tensor: torch.Tensor):
464464 for field_name in LoRAKernelMeta .__dataclass_fields__ :
465465 field = getattr (self .lora_kernel_meta , field_name )
466466 assert isinstance (field , torch .Tensor )
467- setattr (self .lora_kernel_meta , field_name , to_device (field ))
467+ setattr (
468+ self .lora_kernel_meta ,
469+ field_name ,
470+ to_device (field ) if field_name != "no_lora_flag_cpu" else field ,
471+ )
468472
469473 def metadata (self ) -> tuple [int , int , int ]:
470474 """
@@ -512,6 +516,7 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]:
512516 "lora_token_start_loc" : self .lora_kernel_meta .lora_token_start_loc ,
513517 "lora_ids" : self .lora_kernel_meta .active_lora_ids ,
514518 "scaling" : 1.0 ,
519+ "no_lora_flag_cpu" : self .lora_kernel_meta .no_lora_flag_cpu ,
515520 }
516521
517522 def as_lora_expand_kwargs (self , add_inputs : bool ) -> dict [str , Any ]:
@@ -552,6 +557,7 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
552557 "lora_ids" : self .lora_kernel_meta .active_lora_ids ,
553558 "offset_start" : 0 ,
554559 "add_inputs" : add_inputs ,
560+ "no_lora_flag_cpu" : self .lora_kernel_meta .no_lora_flag_cpu ,
555561 }
556562
557563 def bench_fn_kwargs (
0 commit comments