Skip to content

Commit 3afa566

Browse files
authored
Revert "mx: add ceil and RNE rounding modes to the cast from fp32 to e8m0 (#1…"
This reverts commit 6a0f490.
1 parent 6a0f490 commit 3afa566

File tree

1 file changed

+6
-54
lines changed

1 file changed

+6
-54
lines changed

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
* Zeros: N/A
1717
"""
1818

19-
from enum import Enum, auto
2019
from typing import Dict, Union
2120

2221
import torch
@@ -54,38 +53,11 @@
5453
unpack_uint4,
5554
)
5655

57-
# TODO(later): read from somewhere else?
58-
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
59-
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
60-
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
61-
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
62-
EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
63-
EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2
64-
65-
66-
class ScaleCalculationMode(Enum):
67-
"""
68-
Enum representing the different methods for calculating MX block scaling.
69-
There are three methods available:
70-
FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
71-
It results in overflow issues for large values and is bad for gradient quantization.
72-
CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
73-
It uses X = 2^ceil(log2(max_abs(v))-max_exp).
74-
EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
75-
It provides better accuracy for MX4 training compared to FLOOR and CEIL.
76-
By default, we use the EVEN method for better accuracy.
77-
"""
78-
79-
FLOOR = auto()
80-
CEIL = auto()
81-
EVEN = auto()
82-
8356

8457
def to_mx(
8558
data_hp: torch.Tensor,
8659
elem_dtype: Union[torch.dtype, str],
8760
block_size: int,
88-
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
8961
):
9062
"""
9163
Takes a high precision tensor and converts to MX scale and raw data, in
@@ -116,45 +88,25 @@ def to_mx(
11688
# where the values are zero.
11789
eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
11890

91+
# Find largest power of 2 less than or equal to max_abs.
92+
largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps))
93+
11994
# Set X to be the largest power-of-two less than or equal to
12095
# max_abs(v), divided by the largest power of two representable
121-
# in the element data type, and get the mbits at the same time
96+
# in the element data type
12297
if elem_dtype == torch.float8_e4m3fn:
12398
target_max_pow2 = F8E4M3_MAX_POW2
124-
mbits = MBITS_F8_E4M3
12599
elif elem_dtype == torch.float8_e5m2:
126100
target_max_pow2 = F8E5M2_MAX_POW2
127-
mbits = MBITS_F8_E5M2
128101
elif elem_dtype == DTYPE_FP6_E2M3:
129102
target_max_pow2 = F6_E2M3_MAX_POW2
130-
mbits = MBITS_F6_E2M3
131103
elif elem_dtype == DTYPE_FP6_E3M2:
132104
target_max_pow2 = F6_E3M2_MAX_POW2
133-
mbits = MBITS_F6_E3M2
134105
elif elem_dtype == DTYPE_FP4:
135106
target_max_pow2 = F4_E2M1_MAX_POW2
136-
mbits = MBITS_F4_E2M1
137107
else:
138-
raise AssertionError("unsupported element dtype")
139-
140-
# rounding before calculating the largest power of 2
141-
# X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
142-
if scaling_mode == ScaleCalculationMode.EVEN:
143-
nan_mask = torch.isnan(max_abs)
144-
max_abs = max_abs.to(torch.float32).view(torch.int32)
145-
val_to_add = 1 << (MBITS_F32 - mbits - 1)
146-
mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
147-
max_abs = (max_abs + val_to_add) & mask
148-
max_abs = max_abs.view(torch.float32)
149-
max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)
150-
151-
# Calculate the scale for different modes
152-
if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
153-
scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
154-
elif scaling_mode == ScaleCalculationMode.CEIL:
155-
scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
156-
else:
157-
raise AssertionError("unsupported scaling calculation mode")
108+
raise AssertionError("unsupported")
109+
scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2
158110

159111
# Clamp to exponents that can be represented in e8m0
160112
scale_e8m0_unbiased = torch.clamp(

0 commit comments

Comments
 (0)