 16 |  16 |  * Zeros: N/A
 17 |  17 |  """
 18 |  18 |
 19 |     | -from enum import Enum, auto
 20 |  19 |  from typing import Dict, Union
 21 |  20 |
 22 |  21 |  import torch

 54 |  53 |      unpack_uint4,
 55 |  54 |  )
 56 |  55 |
 57 |     | -# TODO(later): read from somewhere else?
 58 |     | -SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
 59 |     | -EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
 60 |     | -EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 61 |     | -EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
 62 |     | -EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
 63 |     | -EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2
 64 |     | -
 65 |     | -
 66 |     | -class ScaleCalculationMode(Enum):
 67 |     | -    """
 68 |     | -    Enum representing the different methods for calculating MX block scaling.
 69 |     | -    There are three methods available:
 70 |     | -    FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
 71 |     | -        It can result in overflow issues for large values and is bad for gradient quantization.
 72 |     | -    CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
 73 |     | -        It uses X = 2^ceil(log2(max_abs(v))-max_exp).
 74 |     | -    EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
 75 |     | -        It provides better accuracy for MX4 training compared to FLOOR and CEIL.
 76 |     | -    By default, we use the EVEN method for better accuracy.
 77 |     | -    """
 78 |     | -
 79 |     | -    FLOOR = auto()
 80 |     | -    CEIL = auto()
 81 |     | -    EVEN = auto()
 82 |     | -
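For reference, the following standalone sketch (not part of this diff or the torchao code) illustrates how the three removed scaling modes differ on a single block. The function name, the fp4 e2m1 defaults (mbits=1, target_max_pow2=2), and the simplified mantissa rounding used for EVEN in place of the removed bit-twiddling are illustrative assumptions.

import torch

def scale_exponent(block, mode, target_max_pow2=2, mbits=1):
    # Unbiased shared exponent for one block under FLOOR / CEIL / EVEN.
    max_abs = block.abs().amax().to(torch.float32)
    if mode == "EVEN":
        # Simplified stand-in for the removed bit trick: round max_abs to
        # `mbits` mantissa bits before taking log2.
        exp = torch.floor(torch.log2(max_abs))
        ulp = torch.exp2(exp - mbits)
        max_abs = torch.round(max_abs / ulp) * ulp
    log2_max = torch.log2(max_abs)
    if mode == "CEIL":
        return torch.ceil(log2_max) - target_max_pow2
    # FLOOR, and EVEN after the rounding above, both take the floor.
    return torch.floor(log2_max) - target_max_pow2

block = torch.tensor([7.0, 0.5, -1.0])
print(scale_exponent(block, "FLOOR"))  # 0 -> scale 2^0, but 7/1 exceeds fp4 e2m1's max of 6.0
print(scale_exponent(block, "CEIL"))   # 1 -> scale 2^1, 7/2 = 3.5 fits; small values may flush to 0
print(scale_exponent(block, "EVEN"))   # 1 -> same as CEIL here; floors after rounding max_abs up to 8.0

With max_abs = 7.0, FLOOR keeps a scale of 1 and overflows the fp4 maximum of 6.0, which is the overflow issue the removed docstring refers to.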
 83 |  56 |
 84 |  57 |  def to_mx(
 85 |  58 |      data_hp: torch.Tensor,
 86 |  59 |      elem_dtype: Union[torch.dtype, str],
 87 |  60 |      block_size: int,
 88 |     | -    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
 89 |  61 |  ):
 90 |  62 |      """
 91 |  63 |      Takes a high precision tensor and converts to MX scale and raw data, in
@@ -116,45 +88,25 @@ def to_mx(
116 |  88 |      # where the values are zero.
117 |  89 |      eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
118 |  90 |
    |  91 | +    # Find largest power of 2 less than or equal to max_abs.
    |  92 | +    largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps))
    |  93 | +
119 |  94 |      # Set X to be the largest power-of-two less than or equal to
120 |  95 |      # max_abs(v), divided by the largest power of two representable
121 |     | -    # in the element data type, and get the mbits at the same time
    |  96 | +    # in the element data type
122 |  97 |      if elem_dtype == torch.float8_e4m3fn:
123 |  98 |          target_max_pow2 = F8E4M3_MAX_POW2
124 |     | -        mbits = MBITS_F8_E4M3
125 |  99 |      elif elem_dtype == torch.float8_e5m2:
126 | 100 |          target_max_pow2 = F8E5M2_MAX_POW2
127 |     | -        mbits = MBITS_F8_E5M2
128 | 101 |      elif elem_dtype == DTYPE_FP6_E2M3:
129 | 102 |          target_max_pow2 = F6_E2M3_MAX_POW2
130 |     | -        mbits = MBITS_F6_E2M3
131 | 103 |      elif elem_dtype == DTYPE_FP6_E3M2:
132 | 104 |          target_max_pow2 = F6_E3M2_MAX_POW2
133 |     | -        mbits = MBITS_F6_E3M2
134 | 105 |      elif elem_dtype == DTYPE_FP4:
135 | 106 |          target_max_pow2 = F4_E2M1_MAX_POW2
136 |     | -        mbits = MBITS_F4_E2M1
137 | 107 |      else:
138 |     | -        raise AssertionError("unsupported element dtype")
139 |     | -
140 |     | -    # rounding before calculating the largest power of 2
141 |     | -    # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
142 |     | -    if scaling_mode == ScaleCalculationMode.EVEN:
143 |     | -        nan_mask = torch.isnan(max_abs)
144 |     | -        max_abs = max_abs.to(torch.float32).view(torch.int32)
145 |     | -        val_to_add = 1 << (MBITS_F32 - mbits - 1)
146 |     | -        mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
147 |     | -        max_abs = (max_abs + val_to_add) & mask
148 |     | -        max_abs = max_abs.view(torch.float32)
149 |     | -        max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)
150 |     | -
151 |     | -    # Calculate the scale for different modes
152 |     | -    if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
153 |     | -        scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
154 |     | -    elif scaling_mode == ScaleCalculationMode.CEIL:
155 |     | -        scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
156 |     | -    else:
157 |     | -        raise AssertionError("unsupported scaling calculation mode")
    | 108 | +        raise AssertionError("unsupported")
    | 109 | +    scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2
158 | 110 |
159 | 111 |      # Clamp to exponents that can be represented in e8m0
160 | 112 |      scale_e8m0_unbiased = torch.clamp(