@@ -35,7 +35,7 @@ def get_config(group_size):
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
-class Int4PlainInt32Tensor(TestCase):
+class Int4PlainInt32TensorXPU(TestCase):
     @parametrize(
         "sizes",
         [
@@ -98,8 +98,75 @@ def test_activation_prescaling(self):
         self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
 
 
-instantiate_parametrized_tests(Int4PlainInt32Tensor)
+@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
+@unittest.skipIf(
+    not torch.accelerator.is_available()
+    or torch.accelerator.current_accelerator().type != "npu",
+    "NPU not available",
+)
+class Int4PlainInt32TensorNPU(TestCase):
+
+    @parametrize("device", ["npu"])
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 128),
+        ],
+    )
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    @parametrize("group_size", [32, 64])
+    def test_linear(self, device, sizes, dtype, group_size):
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        orig_output = linear(input)
+        quantize_(linear, get_config(group_size))
+        quantized_output = linear(input)
+        self.assertTrue(compute_error(orig_output, quantized_output) > 10)
+
+    @parametrize("device", ["npu"])
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_module_path(self, device, dtype):
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear, get_config(group_size=64))
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+            )
+
+    @parametrize("device", ["npu"])
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_activation_prescaling(self, device, dtype):
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(64))
+        qw = linear.weight
+        assert isinstance(
+            qw, SupportsActivationPreScaling
+        ), "Expected int4 tensor to support activation pre-scaling"
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
+        _ACT_PRE_SCALE = 2
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        quantized = linear(input)
+
+        # make sure activation pre-scaling was applied to the activation
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)
+
 
+instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
+instantiate_parametrized_tests(Int4PlainInt32TensorNPU)
 
 if __name__ == "__main__":
     run_tests()
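
Both test classes call a `get_config(group_size)` helper whose body lies outside these hunks (it is only named in the first hunk header). As a rough sketch of what such a helper likely returns, assuming the `Int4WeightOnlyConfig` API from `torchao.quantization` with plain-int32 packing (the exact keyword is an assumption, not shown in this diff):

```python
# Hypothetical reconstruction of the helper named in the first hunk header;
# the real get_config is defined elsewhere in this test file, so the
# int4_packing_format keyword below is an assumption.
from torchao.quantization import Int4WeightOnlyConfig


def get_config(group_size):
    # Weight-only int4 quantization using the plain-int32 packing layout
    # that Int4PlainInt32Tensor implements.
    return Int4WeightOnlyConfig(
        group_size=group_size,
        int4_packing_format="plain_int32",
    )
```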