Skip to content

Commit 3bfb76d

Browse files
changwangss and yiliu30 authored
Support PyTorch eager mode BF16 MixedPrecision (#1321)
Signed-off-by: changwangss <[email protected]> Signed-off-by: yiliu30 <[email protected]> Co-authored-by: yiliu30 <[email protected]>
1 parent dc9328c commit 3bfb76d

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

neural_compressor/adaptor/torch_utils/bf16_convert.py

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def __init__(self, module):
3030
super(BF16ModuleWrapper, self).__init__()
3131
self.add_module("module", module)
3232
self.train(module.training)
33+
# WA for TransformerEncoder to access its Linear's weights and bias
34+
if isinstance(module, nn.Linear):
35+
self.weight = self.module.weight if hasattr(self.module, "weight") else None
36+
self.bias = self.module.bias if hasattr(self.module, "bias") else None
3337

3438
def forward(self, X):
3539
"""Convert dtype."""

neural_compressor/strategy/strategy.py

+1
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,7 @@ def _set_framework_info(self, q_dataloader, q_func=None):
15191519
elif self.config.backend == "default":
15201520
framework = "pytorch_fx"
15211521
if self.mixed_precision_mode:
1522+
framework = "pytorch"
15221523
framework_specific_info.update({"approach": "post_training_dynamic_quant"})
15231524
framework_specific_info.update({"recipes": self.config.recipes})
15241525
framework_specific_info.update({"q_dataloader": q_dataloader})

0 commit comments

Comments (0)