Skip to content

Commit 95e792f

Browse files
torch get device with index of CUDA_VISIBLE_DEVICES, not value of it (#1096)
* torch get device with index of CUDA_VISIBLE_DEVICES, not value of it * revert rocm text
1 parent 620fcf1 commit 95e792f

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,7 @@ def validate_device(cls, device: DEVICE):
322322
if CUDA_VISIBLE_DEVICES is None:
323323
has_cuda_v8 = all(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count()))
324324
else:
325-
has_cuda_v8 = all(torch.cuda.get_device_capability(int(i))[0] >= 8 for i in CUDA_VISIBLE_DEVICES.split(","))
326-
325+
has_cuda_v8 = all(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(len(CUDA_VISIBLE_DEVICES.split(","))))
327326
if not has_cuda_v8:
328327
raise NotImplementedError("Marlin kernel only supports compute capability >= 8.0.")
329328

0 commit comments

Comments
 (0)