
Commit f75ff40

support auto_host2device on RTN and GPTQ (#1894)
Signed-off-by: He, Xin3 <[email protected]>
1 parent b9e73f5 commit f75ff40

File tree: 6 files changed (+65, -13 lines)
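For orientation, here is a minimal sketch of what this commit enables, modeled on the new tests further down (the two-layer toy model and input shapes are stand-ins; the tests use a tiny GPT-J instead). A model that starts on the CPU is quantized with RTN, its modules visit the auto-detected accelerator one at a time during convert, and the quantized model comes back on the CPU:

import torch
from neural_compressor.torch.quantization import convert, get_default_rtn_config, prepare

# Hypothetical toy model kept on the CPU; the commit's own tests use a tiny GPT-J instead.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2)).to("cpu")
example_inputs = torch.randn(4, 32)

quant_config = get_default_rtn_config()
model = prepare(model, quant_config)
model = convert(model)  # layers visit the accelerator one by one, then return to the CPU

out = model(example_inputs)  # inference still runs with CPU inputs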

neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 9 additions & 1 deletion

@@ -27,7 +27,13 @@
 import torch.nn as nn
 from tqdm import tqdm

-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
+from neural_compressor.torch.utils import (
+    get_accelerator,
+    get_model_device,
+    is_transformers_imported,
+    logger,
+    set_module,
+)
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

 from .modules import WeightOnlyLinear

@@ -995,6 +1001,7 @@ def prepare(
         if use_layer_wise:  # pragma: no cover
             assert model_path is not None, "model_path should not be None when use layer wise mode"

+        self.model_device = get_model_device(model)  # return model on the same device
         self.gptq_quantizer = RAWGPTQuantizer(
             model,
             weight_config=self.quant_config,

@@ -1013,6 +1020,7 @@ def convert(self, model, *args, **kwargs):
         self.gptq_quantizer.model = model
         self.gptq_quantizer.remove_prepare_for_calibration()
         q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+        q_model = q_model.to(self.model_device)
         q_model.gptq_config = gptq_config
         logger.info("GPTQ quantizing done.")
         return q_model
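The GPTQ-side change follows a simple remember-and-restore pattern: prepare() records the device the model arrived on, the heavy quantization runs wherever the quantizer places it, and convert() hands the quantized model back on the recorded device. A generic sketch of that pattern (hypothetical helper, not the GPTQ implementation itself):

import torch

def run_and_restore(model: torch.nn.Module, transform, work_device: str = "cuda") -> torch.nn.Module:
    # Same idea as get_model_device(): remember where the incoming model lives.
    original_device = next(model.parameters()).device
    model = transform(model.to(work_device))  # expensive work (e.g. quantization) on the accelerator
    return model.to(original_device)          # hand the result back where it started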

neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 6 additions & 6 deletions

@@ -270,8 +270,8 @@ def recover(self):

     def pack_tensor_with_torch(self, raw_tensor):
         target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
-        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(raw_tensor.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(raw_tensor.device)
         for j in range(packed_tensor.shape[1]):
             start = self.n_pack * j
             end = self.n_pack * (j + 1)

@@ -286,8 +286,8 @@ def pack_tensor_with_torch(self, raw_tensor):
     def unpack_tensor_with_torch(self, packed_tensor):
         target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
         target_len = packed_tensor.shape[1] * self.n_pack
-        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(packed_tensor.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(packed_tensor.device)
         for j in range(packed_tensor.shape[1]):
             for e in range(self.n_pack):
                 index = j * self.n_pack + e

@@ -338,13 +338,13 @@ def unpack_tensor_with_numpy(self, packed_tensor):
         return unpacked_tensor

     def pack_tensor(self, raw_tensor):
-        if "cuda" in self.device:
+        if "cuda" in raw_tensor.device.type:
             return self.pack_tensor_with_torch(raw_tensor)
         else:
             return self.pack_tensor_with_numpy(raw_tensor)

     def unpack_tensor(self, packed_tensor):
-        if "cuda" in self.device:
+        if "cuda" in packed_tensor.device.type:
             return self.unpack_tensor_with_torch(packed_tensor)
         else:
             return self.unpack_tensor_with_numpy(packed_tensor)
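Note that pack_tensor/unpack_tensor now dispatch on the tensor's own device rather than the module's recorded self.device, so a tensor that has been moved to CUDA takes the torch path even if the module was created on the CPU. For reference, a self-contained sketch of the packing idea those methods implement, with hypothetical names and the same device-following allocation (the actual WeightOnlyLinear code may differ in detail):

import math
import torch

def pack_int_tensor(raw: torch.Tensor, bits: int = 4, compression_dtype=torch.int32) -> torch.Tensor:
    # Pack n_pack low-bit integers into each element of compression_dtype (e.g. 8 int4 values per int32).
    n_pack = torch.iinfo(compression_dtype).bits // bits
    target_len = math.ceil(raw.shape[1] / n_pack)
    packed = torch.zeros(raw.shape[0], target_len, dtype=compression_dtype, device=raw.device)
    mask = torch.tensor(2**bits - 1, dtype=compression_dtype, device=raw.device)
    for j in range(target_len):
        chunk = raw[:, n_pack * j : n_pack * (j + 1)].to(compression_dtype)
        for e in range(chunk.shape[1]):
            packed[:, j] |= (chunk[:, e] & mask) << (bits * e)  # slot each value into its bit field
    return packed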

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 8 additions & 4 deletions

@@ -28,6 +28,7 @@
 from neural_compressor.torch.utils import (
     get_accelerator,
     get_attr,
+    get_model_device,
     is_transformers_imported,
     logger,
     set_attr,

@@ -99,10 +100,7 @@ def convert(
         """
         weight_config = self.quant_config
         device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
-
-        # Put model on device explicitly
-        # TODO: refine it later, Put module on device one by one instead of the whole model
-        model.to(device)
+        model_device = get_model_device(model)  # return model on the same device

         # for transformers model. If lm_head is tied from embedding, we deepcopy it.
         if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False):

@@ -132,6 +130,8 @@ def convert(
             dtype = weight_config[name].get("dtype", "int")
             if dtype == "fp32":
                 continue
+            # Move modules to the accelerator device layer-by-layer
+            m.to(device)
             ### FP8 cast part
             if dtype in ["fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"]:
                 logger.debug("Cast module {} to FP8 using qdq mode, no scaling".format(name))

@@ -223,4 +223,8 @@ def convert(
                 return new_module
             else:
                 set_module(model, name, new_module)
+            # Move modules back to the model device layer-by-layer
+            m.to(model_device)
+            new_module.to(model_device)
+        model.to(model_device)
         return model
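The net effect on RTN: instead of moving the whole model to the accelerator up front (the removed model.to(device)), each module now makes its own round trip, so peak device memory stays near a single layer. A hedged, standalone sketch of that loop (quantize_module and the Linear-only filter are placeholders, not the RTN internals):

import torch

def quantize_layer_by_layer(model: torch.nn.Module, quantize_module, work_device: str = "cuda"):
    original_device = next(model.parameters()).device  # what get_model_device() reports
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        module.to(work_device)               # the layer visits the accelerator
        new_module = quantize_module(module)
        module.to(original_device)           # ...and comes straight back
        new_module.to(original_device)
        # set_module(model, name, new_module) would swap the quantized layer in here
    return model.to(original_device)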

neural_compressor/torch/utils/utility.py

Lines changed: 13 additions & 0 deletions

@@ -265,3 +265,16 @@ def dump_model_op_stats(mode, tune_cfg):
         output_data.append(field_results)

     Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
+
+
+def get_model_device(model: torch.nn.Module):
+    """Get the device.
+
+    Args:
+        model (torch.nn.Module): the input model.
+
+    Returns:
+        device (str): a string.
+    """
+    for n, p in model.named_parameters():
+        return p.data.device.type  # p.data.device == device(type='cpu')
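A small usage note on the new helper: it returns the device type string of the first named parameter it finds (and, as written, None for a parameter-less module), which is all the host-to-device round trip above needs:

import torch
from neural_compressor.torch.utils import get_model_device

model = torch.nn.Linear(8, 8)            # parameters live on the CPU by default
assert get_model_device(model) == "cpu"  # device.type of the first parameter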

test/3x/torch/quantization/weight_only/test_gptq.py

Lines changed: 16 additions & 2 deletions

@@ -36,6 +36,20 @@ def setup_class(self):
     def teardown_class(self):
         shutil.rmtree("saved_results", ignore_errors=True)

+    @pytest.mark.skipif(device == "cpu", reason="no available accelerator")
+    def test_auto_host2device(self):
+        # if model is on CPU, we move it to device layer-by-layer for acceleration,
+        # and then move it back to CPU after quantization.
+        model = copy.deepcopy(self.tiny_gptj).to("cpu")
+        example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
+        quant_config = get_default_gptq_config()
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
+        gptq_label = model(example_inputs)[0]
+        gptq_atol = (gptq_label - self.label.to("cpu")).amax()
+        assert gptq_atol < 0.06, "GPTQ should have low atol."
+
     def test_accuracy_improvement(self):
         # test_default_rtn_config
         model = copy.deepcopy(self.tiny_gptj)

@@ -215,9 +229,9 @@ def test_conv1d(self):
         from transformers import GPT2Model, GPT2Tokenizer

         tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
-        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
+        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2").to(device)
         text = "Replace me by any text you'd like."
-        encoded_input = tokenizer(text, return_tensors="pt")
+        encoded_input = tokenizer(text, return_tensors="pt").to(device)

         def run_fn_conv1d(model):
             model(**encoded_input)

test/3x/torch/quantization/weight_only/test_rtn.py

Lines changed: 13 additions & 0 deletions

@@ -352,3 +352,16 @@ def mock_is_transformers_imported():
         model = convert(model)
         out = model(self.example_inputs)[0]
         assert torch.allclose(out, self.label, atol=1e-1), "Accuracy gap atol > 0.1 is unexpected."
+
+    @pytest.mark.skipif(device == "cpu", reason="no available accelerator")
+    def test_auto_host2device(self):
+        # if model is on CPU, we move it to device layer-by-layer for acceleration,
+        # and then move it back to CPU after quantization.
+        model = copy.deepcopy(self.tiny_gptj).to("cpu")
+        example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
+        quant_config = get_default_rtn_config()
+        model = prepare(model, quant_config)
+        model = convert(model)
+        rtn_label = model(example_inputs)[0]
+        rtn_atol = (rtn_label - self.label.to("cpu")).amax()
+        assert rtn_atol < 0.08, "RTN should have low atol."
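To exercise both new tests locally (they skip themselves when no accelerator is available), something along these lines should work from the repository root:

pytest test/3x/torch/quantization/weight_only/test_rtn.py -k auto_host2device
pytest test/3x/torch/quantization/weight_only/test_gptq.py -k auto_host2device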
