@@ -270,8 +270,8 @@ def recover(self):

     def pack_tensor_with_torch(self, raw_tensor):
         target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
-        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(raw_tensor.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(raw_tensor.device)
         for j in range(packed_tensor.shape[1]):
             start = self.n_pack * j
             end = self.n_pack * (j + 1)
@@ -286,8 +286,8 @@ def pack_tensor_with_torch(self, raw_tensor):
     def unpack_tensor_with_torch(self, packed_tensor):
         target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
         target_len = packed_tensor.shape[1] * self.n_pack
-        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(packed_tensor.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(packed_tensor.device)
         for j in range(packed_tensor.shape[1]):
             for e in range(self.n_pack):
                 index = j * self.n_pack + e
@@ -338,13 +338,13 @@ def unpack_tensor_with_numpy(self, packed_tensor):
         return unpacked_tensor

     def pack_tensor(self, raw_tensor):
-        if "cuda" in self.device:
+        if "cuda" in raw_tensor.device.type:
             return self.pack_tensor_with_torch(raw_tensor)
         else:
             return self.pack_tensor_with_numpy(raw_tensor)

     def unpack_tensor(self, packed_tensor):
-        if "cuda" in self.device:
+        if "cuda" in packed_tensor.device.type:
             return self.unpack_tensor_with_torch(packed_tensor)
         else:
             return self.unpack_tensor_with_numpy(packed_tensor)
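For context, here is a minimal self-contained sketch of the pack/unpack scheme these methods implement, with the device taken from the tensor being processed rather than from `self.device`, as the patch makes explicit. The values `bits=4`, `n_pack=8`, and `compression_dtype=torch.int32` are assumed for illustration, and the shift/OR loop bodies elided from the diff are filled in with the usual bit-packing pattern, so this approximates the class's behavior rather than reproducing its exact code:

```python
import math
import torch

# Assumed parameters for this sketch; the real class takes them from its config.
bits, n_pack, compression_dtype = 4, 8, torch.int32

def pack_tensor(raw_tensor):
    # Allocate the packed buffer on the input tensor's device, mirroring the patch.
    target_len = math.ceil(raw_tensor.shape[1] / n_pack)
    packed = torch.zeros(raw_tensor.shape[0], target_len, dtype=compression_dtype).to(raw_tensor.device)
    mask = torch.tensor(2**bits - 1, dtype=compression_dtype).to(raw_tensor.device)
    for j in range(packed.shape[1]):
        start, end = n_pack * j, n_pack * (j + 1)
        chunk = raw_tensor[:, start:end].to(compression_dtype) & mask
        for e in range(chunk.shape[1]):
            # Each low-bit value occupies its own `bits`-wide slot of the packed word.
            packed[:, j] |= chunk[:, e] << (bits * e)
    return packed

def unpack_tensor(packed_tensor):
    target_len = packed_tensor.shape[1] * n_pack
    unpacked = torch.zeros(packed_tensor.shape[0], target_len, dtype=torch.int8).to(packed_tensor.device)
    mask = torch.tensor(2**bits - 1, dtype=compression_dtype).to(packed_tensor.device)
    for j in range(packed_tensor.shape[1]):
        for e in range(n_pack):
            index = j * n_pack + e
            # Shift the slot back down and mask off the neighbouring values.
            unpacked[:, index] = ((packed_tensor[:, j] >> (bits * e)) & mask).to(torch.int8)
    return unpacked

x = torch.randint(0, 2**bits, (2, 16), dtype=torch.int8)
assert torch.equal(unpack_tensor(pack_tensor(x)), x)  # round-trip on the tensor's own device
```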