diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 3859d8039b..280558e636 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -798,6 +798,14 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): _int8wo_api, device, 40, test_dtype=dtype ) + @parameterized.expand(COMMON_DEVICE_DTYPE) + @torch._inductor.config.patch({"freezing": True}) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "freeze requires torch 2.4 and after.") + def test_int8_weight_only_quant_with_freeze(self, device, dtype): + self._test_lin_weight_subclass_api_impl( + _int8wo_api, device, 40, test_dtype=dtype + ) + @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass_api(self, device, dtype):