Commit c01169a

Add int8 and fpx test to TensorParallel
1 parent 6ffe236

File tree

1 file changed (+26, -0)

test/dtypes/test_affine_quantized_tensor_parallel.py

@@ -14,6 +14,8 @@
     float8_weight_only,
     int4_weight_only,
     int8_weight_only,
+    int8_dynamic_activation_int8_weight,
+    fpx_weight_only,
 )
 from torchao.quantization.observer import PerRow, PerTensor
 from torchao.quantization.quant_api import quantize_
@@ -166,9 +168,33 @@ def test_tp_gemlite(self, dtype):
         return self._test_tp(dtype)
 
 
+class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+
+class TestFpxwoAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(fpx_weight_only)
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+
 common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
+common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)
+common_utils.instantiate_parametrized_tests(TestFpxwoAffineQuantizedTensorParallel)
 
 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
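
For context, the quantization methods exercised by the new tests are applied through torchao's quantize_ API, which this file already imports. A minimal sketch of standalone usage, assuming a hypothetical toy model (the tests themselves require CUDA and run the model under tensor parallelism via the base class's _test_tp helper):

import torch
import torch.nn as nn
from torchao.quantization import int8_dynamic_activation_int8_weight, fpx_weight_only
from torchao.quantization.quant_api import quantize_

# Hypothetical toy model for illustration; the tests shard a real model across ranks.
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16)

# int8 dynamic-activation / int8-weight quantization, applied in place.
quantize_(model, int8_dynamic_activation_int8_weight())

# fpx weight-only quantization takes explicit exponent/mantissa bit counts,
# e.g. fp6 with 3 exponent bits and 2 mantissa bits:
# quantize_(model, fpx_weight_only(3, 2))

Each new test class only swaps in a different QUANT_METHOD_FN and reuses the inherited _test_tp machinery, following the same pattern as the existing TestInt8woAffineQuantizedTensorParallel and TestInt4woAffineQuantizedTensorParallel classes.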
