 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
+from torchao.quantization.linear_activation_scale import (
+    WeightTensorWithLinearActivationScaleMetadata,
+)
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
 )
+from torchao.quantization.utils import (
+    compute_error as SQNR,
+)
 
 
 class ToyLinearModel(torch.nn.Module):
@@ -34,16 +40,19 @@ def example_inputs(
         dtype=torch.bfloat16,
         device="cuda",
     ):
-        return [
-            torch.randn(
-                1,
-                sequence_length,
-                self.linear1.in_features,
-                dtype=dtype,
-                device=device,
-            )
-            for j in range(batch_size)
-        ]
+        # For SmoothQuant tests, we intentionally inject some outliers into the input features
+        x = torch.randn(
+            batch_size,
+            sequence_length,
+            self.linear1.in_features,
+            dtype=dtype,
+            device=device,
+        )
+        n_outliers = max(1, int(x.size(-1) * 0.1))
+        # Randomly select outlier features
+        outlier_indices = torch.randperm(x.size(-1))[:n_outliers]
+        x[:, :, outlier_indices] *= 10.0
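+        # SmoothQuant migrates such activation outliers into the weights via
+        # per-channel smoothing scales, so these amplified channels make the
+        # accuracy gap over plain quantization measurable in the test below.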
+        return (x,)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -52,7 +61,9 @@ def forward(self, x):
         return x
 
 
-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
 @unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
 class TestSmoothQuant(unittest.TestCase):
     """SmoothQuant tests using only supported quantization configs."""
@@ -72,37 +83,25 @@ def setUpClass(cls):
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
     )
-    @common_utils.parametrize("device", ["cpu", "cuda"])
+    @common_utils.parametrize("device", device_list)
     @common_utils.parametrize("input_dtype", [torch.bfloat16])
     def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         """Test if SmoothQuant achieves lower loss than basic quantization."""
-        in_features = 64
-        out_features = 128
-
-        # Note: This is sanity check. For real run, consider Transformer model to reproduce.
-        X = torch.randn(16, in_features, dtype=input_dtype, device=device)
-        W = torch.randn(out_features, in_features, dtype=input_dtype, device=device)
-
         # Create linear layer
-        linear = (
-            torch.nn.Linear(in_features, out_features, bias=False)
-            .to(device)
-            .to(input_dtype)
-        )
-        with torch.no_grad():
-            linear.weight.copy_(W)
+        m = ToyLinearModel().eval().to(device).to(input_dtype)
+        x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)
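+        # example_inputs returns a 1-tuple, hence the *x unpacking in the calls below.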
 
         # Reference output
-        out_ref = linear(X)
+        out_ref = m(*x)
 
         # Step 1. Basic quantization
-        basic_model = deepcopy(linear)
+        basic_model = deepcopy(m)
         quantize_(basic_model, base_config)
-        out_basic = basic_model(X)
+        out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()
 
-        # SmoothQuant quantization
-        model = deepcopy(linear)
+        # Step 2. SmoothQuant
+        model = deepcopy(m)
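+        # alpha controls how much quantization difficulty is migrated from activations
+        # to weights (per the SmoothQuant paper, s_j = max|X_j|**alpha / max|W_j|**(1 - alpha))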
         config = SmoothQuantConfig(
             base_config=base_config,
             step=SmoothQuantStep.PREPARE,
@@ -111,18 +110,25 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         quantize_(model, config)
 
         # Perform calibration with test data
-        model(X)
+        model(*x)
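+        # The PREPARE step inserts activation observers, so this forward pass is the
+        # calibration run that collects the per-channel activation statistics.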
 
-        # Step 2. SmoothQuant
         config.step = SmoothQuantStep.CONVERT
         quantize_(model, config)
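+        # CONVERT quantizes the weights with base_config and wraps them with the
+        # per-channel activation scales computed during calibration (checked below).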
+        assert isinstance(
+            model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
+        assert isinstance(
+            model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
 
-        out_smoothquant = model(X)
+        out_smoothquant = model(*x)
         loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()
 
         assert loss_smoothquant < loss_base, (
             f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
         )
+        # Make sure the result is reasonable
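+        # (SQNR is the signal-to-quantization-noise ratio in dB; 20 dB corresponds to
+        # roughly 10% relative error versus the bfloat16 reference.)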
+        self.assertGreater(SQNR(out_ref, out_smoothquant), 20.0)
 
     @common_utils.parametrize(
         "base_config",