Skip to content

Commit 3da495d

Browse files
dolpm, jeejeelee
authored and committed
[fix] lora benchmarks pass no_lora_flag_cpu (vllm-project#23774)
Signed-off-by: Dylan Maloy <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: charlifu <[email protected]>
1 parent b0b2bc0 commit 3da495d

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

benchmarks/kernels/benchmark_lora.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -464,7 +464,11 @@ def to_device(tensor: torch.Tensor):
464464
for field_name in LoRAKernelMeta.__dataclass_fields__:
465465
field = getattr(self.lora_kernel_meta, field_name)
466466
assert isinstance(field, torch.Tensor)
467-
setattr(self.lora_kernel_meta, field_name, to_device(field))
467+
setattr(
468+
self.lora_kernel_meta,
469+
field_name,
470+
to_device(field) if field_name != "no_lora_flag_cpu" else field,
471+
)
468472

469473
def metadata(self) -> tuple[int, int, int]:
470474
"""
@@ -512,6 +516,7 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]:
512516
"lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
513517
"lora_ids": self.lora_kernel_meta.active_lora_ids,
514518
"scaling": 1.0,
519+
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
515520
}
516521

517522
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
@@ -552,6 +557,7 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
552557
"lora_ids": self.lora_kernel_meta.active_lora_ids,
553558
"offset_start": 0,
554559
"add_inputs": add_inputs,
560+
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
555561
}
556562

557563
def bench_fn_kwargs(

0 commit comments

Comments
 (0)