
Commit b7d6069

Update
[ghstack-poisoned]
2 parents: 1c8adb4 + 4463b79


5 files changed: +10 lines, -6 lines


test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.sparsity.sparse_api import apply_fake_sparsity
+from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
 )
@@ -38,6 +39,7 @@ class TestInt4MarlinSparseTensor(TestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("config", [BF16_ACT_CONFIG])
     @parametrize(
         "sizes",
@@ -65,6 +67,7 @@ def test_linear(self, config, sizes):
         quantized_and_compiled = compiled_linear(input)
         self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

+    @skip_if_rocm("ROCm enablement in progress")
     @unittest.skip("Fix later")
     @parametrize("config", [BF16_ACT_CONFIG])
     def test_to_device(self, config):
@@ -81,6 +84,7 @@ def test_to_device(self, config):
             quantize_(linear, config)
             linear.to(device)

+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("config", [BF16_ACT_CONFIG])
     def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
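
For reference, a minimal sketch of how the newly gated tests are expected to behave (not part of the commit). It assumes skip_if_rocm(msg) from torchao.testing.utils returns a decorator that skips the test on ROCm builds; the class and test body below are illustrative only.

# Hedged sketch, not part of the commit: gating a CUDA-only test on ROCm.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.testing.utils import skip_if_rocm  # import added by this commit's diff


class ExampleMarlinSparseTest(TestCase):  # illustrative name, not from the diff
    @skip_if_rocm("ROCm enablement in progress")  # assumption: skips only on ROCm builds
    def test_linear_bf16(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        self.assertEqual(linear.weight.dtype, torch.bfloat16)


if __name__ == "__main__":
    run_tests()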

test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -33,8 +33,8 @@
     version=2,
 )

+# only 128 group_size is supported
 FP8_ACT_CONFIG = Float8DynamicActivationInt4WeightConfig(
-    group_size=128,
     packing_format="preshuffled",
 )
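
As a consequence of the removed argument above, group_size can no longer be passed to the config. A small sketch (illustrative, assuming the config is a plain dataclass, as the field-style defaults in the quant_api.py hunk later in this diff suggest):

# Hedged sketch, not part of the commit: the config no longer accepts group_size.
from torchao.quantization.quant_api import Float8DynamicActivationInt4WeightConfig

# Still valid: only the packing format is configurable.
config = Float8DynamicActivationInt4WeightConfig(packing_format="preshuffled")

# Assumption: since the group_size field was removed, passing it now fails.
try:
    Float8DynamicActivationInt4WeightConfig(group_size=128)
except TypeError as e:
    print(f"group_size is no longer a config field: {e}")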

test/quantization/test_qat.py

Lines changed: 1 addition & 1 deletion
@@ -1927,7 +1927,7 @@ def test_quantize_api_fp8_int4(self):
         quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
         """
         self._test_quantize_api_against_ptq(
-            Float8DynamicActivationInt4WeightConfig(group_size=128),
+            Float8DynamicActivationInt4WeightConfig(),
             target_prepare_sqnr=15,
             target_convert_sqnr=float("inf"),
         )

torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh

Lines changed: 1 addition & 1 deletion
@@ -451,7 +451,7 @@ __device__ __forceinline__ OType torchao_quantize_value(float input_value,
  * Template parameters ensure compile-time array size checking for safety
  */
 template <typename OType, int NUM_VALUES, ScaleCalculationMode ScalingMode>
-__device__ __forceinline__ float
+__device__ __forceinline__ void
 quantize_block(float amax, e8m0_t &out_scale,
                const float (&input_values)[NUM_VALUES],
                OType (&output_values)[NUM_VALUES]) {

torchao/quantization/quant_api.py

Lines changed: 3 additions & 3 deletions
@@ -1156,13 +1156,13 @@ def _int4_weight_only_transform(
 class Float8DynamicActivationInt4WeightConfig(AOBaseConfig):
     """Configuration for apply float8 dynamic per row quantization and int4
     per group weight quantization to linear
+    (only group_size 128 is supported right now since underlying kernel used only supports 128
+    and above and no benefits of making it bigger)

     Args:
-        `group_size`: group size for groupwise quantization for weight
         `packing_format`: how the weight is packed, only preshuffled is supported
     """

-    group_size: int = 128
     packing_format: PackingFormat = "preshuffled"


@@ -1174,13 +1174,13 @@ def _float8_dynamic_activation_int4_weight_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
-    group_size = config.group_size
     packing_format = config.packing_format

     assert packing_format == "preshuffled", (
         f"only preshuffled packing_format supported right now, got: {packing_format}"
     )
     weight = module.weight
+    group_size = 128
     block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
     new_weight = Int4PreshuffledTensor.from_hp(
         module.weight,
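
Taken together, the quant_api.py change above hardcodes the group size inside the transform instead of reading it from the config. A hedged sketch (not part of the commit) of the net effect on callers, following the block_size construction shown in the diff:

# Hedged sketch, not part of the commit: group_size is fixed at 128 internally.
import torch

from torchao.quantization.quant_api import Float8DynamicActivationInt4WeightConfig

# Only the packing format remains configurable on the config itself.
config = Float8DynamicActivationInt4WeightConfig(packing_format="preshuffled")

# Inside _float8_dynamic_activation_int4_weight_transform, group_size is now 128,
# so a 2-D weight of shape (out_features, in_features) gets block_size (1, 128):
weight = torch.randn(256, 512, dtype=torch.bfloat16)
group_size = 128
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
assert block_size == (1, 128)

# Applying the config is unchanged, e.g. quantize_(linear, config) as in the tests
# above, but it needs a GPU with the preshuffled int4 kernels, so it is not run here.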
