
Commit 063846a

dominicsymes authored and eric-k256 committed

Microscaling format support

Adds support for Open Compute Project (OCP) floating-point Microscaling (MX) formats. Provides cast and matrix multiply operators that work with the microscaling formats. CONST supports constants of the MXFP data types. CAST supports casting the MXFP data types to and from bf16.

Co-Authored-By: Eric Kunze <[email protected]>
Signed-off-by: Dominic Symes <[email protected]>
Signed-off-by: Eric Kunze <[email protected]>
Change-Id: Ifb05503937f3d5c74cebe106156c60bff9af21dc

1 parent e32e56c commit 063846a

20 files changed, +743 -50 lines

chapters/appendix_a.adoc

Lines changed: 20 additions & 0 deletions

@@ -223,6 +223,26 @@ for (0 <= n < N, 0 <= c < C, 0 <= x < W) {
}
----

+==== MATMUL_T_BLOCK_SCALED
+
+The following generates input test data for test set S.
+For a compliant implementation, the test must pass whenever the attributes satisfy:
+`N*H*W >= MIN_DOT_PRODUCTS`
+
+[source,c++]
+----
+KS = C;
+for (0 <= n < N, 0 <= y < H, 0 <= c < C) {
+    A[n, y, c] = tosa_pro_fp_data(S, KS, 0, c, (n*H+y)*C+c);
+}
+A_scale, A_values = CAST_TO_BLOCK_SCALED(A);
+for (0 <= n < N, 0 <= c < C, 0 <= x < W) {
+    B[n, x, c] = tosa_pro_fp_data(S, KS, 1, c, (n*W+x)*C+c);
+}
+B_scale, B_values = CAST_TO_BLOCK_SCALED(B);
+----
+
==== TRANSPOSE_CONV2D

The following generates input test data for test set S.

chapters/introduction.adoc

Lines changed: 48 additions & 6 deletions

@@ -271,9 +271,33 @@ Number formats not required for any operators in a profile do not need to be implemented.
| (1<<47)-1
|Signed 48-bit two's-complement value.

+|fp4e2m1_t
+| -6.0
+| +6.0
+| 4-bit floating-point defined by <<OCP-MX,OCP-MX>> with two bits of exponent and one bit of mantissa. +
+Normal values must be supported. +
+Subnormal values must be supported. +
+Signed zero must be supported.
+
+|fp6e3m2_t
+| -28.0
+| +28.0
+| 6-bit floating-point defined by <<OCP-MX,OCP-MX>> with three bits of exponent and two bits of mantissa. +
+Normal values must be supported. +
+Subnormal values must be supported. +
+Signed zero must be supported.
+
+|fp6e2m3_t
+| -7.5
+| +7.5
+| 6-bit floating-point defined by <<OCP-MX,OCP-MX>> with two bits of exponent and three bits of mantissa. +
+Normal values must be supported. +
+Subnormal values must be supported. +
+Signed zero must be supported.
+
|fp8e4m3_t
| -448
-| 448
+| +448
| 8-bit floating-point defined by <<OCP-OFP8,OCP-OFP8>> with four bits of exponent and three bits of mantissa. +
Normal values must be supported. +
Subnormal values must be supported. +

@@ -292,6 +316,12 @@ Positive and negative infinity must be supported. +
NaN encodings must be supported. +
Signed zero must be supported.

+|fp8ue8m0_t
+| exp2(-127)
+| exp2(+127)
+| 8-bit floating-point value defined by <<OCP-MX,OCP-MX>> with no sign bit, eight bits of exponent, and no mantissa bits. +
+The NaN encoding must be supported.
+
|fp16_t
| -infinity
| +infinity

@@ -331,6 +361,11 @@ Subnormal values must either be supported or flushed to sign-preserved zero. +
Positive and negative infinity must be supported. +
At least one NaN encoding must be supported. +
Signed zero must be supported.
+
+|mxint8_t
+| -2
+| +1 + 63/64
+| 8-bit integer format with an implicit 1/64 scale defined by <<OCP-MX,OCP-MX>>.
|===

Note: In this specification, minimum<type> and maximum<type> will denote the minimum and maximum values of the data as stored in memory (ignoring the zero point).
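The range columns for the new MX element formats follow directly from their field widths. As a sanity check, a minimal sketch in the specification's pseudocode style (`mx_normal_max` is a hypothetical helper, not spec pseudocode; the parameters match the exponent/mantissa/bias values tabulated in numeric_accuracy_helpers.tosac below):

[source,c++]
----
// Largest finite value of an OCP-MX element format with e exponent bits,
// m mantissa bits, and the given exponent bias. These formats reserve no
// infinity/NaN encodings, so the top exponent code is a normal number.
fp64_t mx_normal_max(int32_t e, int32_t m, int32_t bias) {
    int32_t emax = (1 << e) - 1 - bias;
    return (2.0 - exp2(-m)) * exp2(emax);   // (2 - 2^-m) * 2^emax
}
// mx_normal_max(2, 1, 1) == 6.0   matches the fp4e2m1_t row
// mx_normal_max(2, 3, 1) == 7.5   matches the fp6e2m3_t row
// mx_normal_max(3, 2, 3) == 28.0  matches the fp6e3m2_t row
----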
@@ -450,15 +485,21 @@ This section assumes an operation acting on tensors named 'input', 'weight' and
Each output tensor element can be expressed as a dot product of elements between the 'input' and 'weight' tensors with optional bias addition.
The dot product has length KS, the kernel size.
If the operation does not specify a bias then 'bias' is taken to be zero in this section.
+If the dot product is of a block-scaled tensor, then 'input_scale' and 'weight_scale' are inputs to the dot product.
+
Note: KS is defined for each relevant operator in the appendix section <<Floating-Point Operator Test Data>>.

-In other words, each output element `out` can be expressed as a dot product between input elements `in[k]`, weight elements `w[k]`, bias `b`:
+Each output element `out` can be expressed as a dot product between input elements `in[k]`, weight elements `w[k]`, and bias `b`:

`out = in[0] * w[0] + in[1] * w[1] + ... + in[KS-1] * w[KS-1] + b`

The positions of `in[k]`, `w[k]`, `b` in the input, weight and bias tensors depend on the operation being performed.
This may be, for example, a convolution.

+In a block-scaled dot product, each input element `in[k]` and weight element `w[k]` is scaled by the corresponding scale value: +
+`in[k] = in_data[k] * in_scale[k/block_size]` +
+`w[k] = w_data[k] * w_scale[k/block_size]`
+
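For example, one block-scaled dot product can be sketched in the specification's pseudocode style (variable names assumed; each run of `block_size` data elements shares one scale element):

[source,c++]
----
acc_t out = b;
for (0 <= k < KS) {
    // one shared scale per block_size consecutive data elements
    acc_t in_k = in_data[k] * in_scale[k / block_size];
    acc_t w_k  = w_data[k]  * w_scale[k / block_size];
    out = out + in_k * w_k;
}
----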
This section defines the accuracy required for these operations.
In this section:
@@ -480,9 +521,9 @@ ABS_BOUND is the maximum allowed absolute error when NaN or overflow is not present.
|===
|Condition|ABS_BOUND|Notes

-|`(is_same<in_t,fp8e5m2_t>() \|\| is_same<in_t,fp8e4m3_t>()) && is_same<acc_t,fp32_t>`
+|`(is_same<in_t,fp8e5m2_t>() \|\| is_same<in_t,fp8e4m3_t>() \|\| is_same<in_t,fp6e3m2_t>() \|\| is_same<in_t,fp6e2m3_t>() \|\| is_same<in_t,fp4e2m1_t>() \|\| is_same<in_t,mxint8_t>()) && is_same<acc_t,fp32_t>`
|`2 * max(ksb, min(ksb,64) * (1 << 10))`
-| The FP8 dot product with FP32 accumulator is allowed a relaxed absolute error bound. +
+| FP8 and block-scaled dot products with an FP32 accumulator are allowed a relaxed absolute error bound. +
The 2 factor allows for different rounding modes. +
The second term in the maximum allows accumulating intermediates at lower precision. +
If the operator does not use an accumulator type acc_t, the final comparison should be is_same<out_t,fp32_t>.

@@ -499,9 +540,9 @@ The squared error for each result is summed, and the result must be less than the VARIANCE_ERROR_BOUND.
|===
|Condition|VARIANCE_ERROR_BOUND|Notes

-|`(is_same<in_t,fp8e5m2_t>() \|\| is_same<in_t,fp8e4m3_t>()) && is_same<acc_t,fp32_t>`
+|`(is_same<in_t,fp8e5m2_t>() \|\| is_same<in_t,fp8e4m3_t>() \|\| is_same<in_t,fp6e3m2_t>() \|\| is_same<in_t,fp6e2m3_t>() \|\| is_same<in_t,fp4e2m1_t>() \|\| is_same<in_t,mxint8_t>()) && is_same<acc_t,fp32_t>`
|`4 * 0.4 * max(ksb, min(ksb,64) * (1 << 20))`
-| The FP8 dot product with FP32 accumulator is allowed a relaxed variance error bound. +
+| FP8 and block-scaled dot products with an FP32 accumulator are allowed a relaxed variance error bound. +
The factors are similar to the absolute bound with precision factors squared for the variance bound. +
The 0.4 factor is derived from the uniform [-1,1] distribution variance of 1/3 by rounding up. +
The 4 factor is the square of the 2 factor in the absolute bound to allow for different rounding modes. +

@@ -678,3 +719,4 @@ The following publications are referred to in this specification, or provide more information.
. [[IEEE-754]]IEEE Std 754-2008, _IEEE Standard for Floating-point Arithmetic_, August 2008.
. [[OCP-OFP8]]Open Compute Project OCP 8-bit Floating Point Specification (OFP8), Revision 1.0.
+. [[OCP-MX]]Open Compute Project OCP Microscaling Formats (MX) Specification, Version 1.0.

chapters/tensor_ops.adoc

Lines changed: 19 additions & 0 deletions

@@ -187,6 +187,25 @@ include::{generated}/operators/MATMUL.adoc[]
include::{pseudocode}/operators/MATMUL.tosac[lines=10..-1]
----

+==== MATMUL_T_BLOCK_SCALED
+
+Performs two-dimensional matrix multiplications using block-scaled tensors.
+The block dimension is always the last dimension of the tensor, so the result is effectively a matrix multiply of A by the transposed B matrix.
+If the D dimension of input B is of size 1, the B matrix will be broadcast.
+A shape-level sketch follows the pseudocode below.
+
+*Precision Requirements*
+
+* Each output can be expressed as a dot product of two input vectors multiplied by the scale factors for the A and B tensors.
+* The dot product must meet the <<Dot product accuracy requirements>>.
+* When generating the data sets for the Dot product accuracy requirements, the data should be generated as fp32 and converted to a scale/value tensor pair using the scale calculation defined in CAST_TO_BLOCK_SCALED.
+
+include::{generated}/operators/MATMUL_T_BLOCK_SCALED.adoc[]
+
+[source,c++]
+----
+include::{pseudocode}/operators/MATMUL_T_BLOCK_SCALED.tosac[lines=10..-1]
+----
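The shape-level sketch referenced above, in the specification's pseudocode style (an illustration under assumed names: `A_values`/`A_scale` with data shape [N, H, C] and `B_values`/`B_scale` with data shape [D, W, C]; the normative definition is the included pseudocode):

[source,c++]
----
// C is the block (contraction) dimension, so each output element is a dot
// product along the last axis of A and of B -- A multiplied by B transposed.
for (0 <= n < N, 0 <= y < H, 0 <= x < W) {
    int32_t bn = (D == 1) ? 0 : n;    // broadcast B when its D dimension is 1
    acc_t acc = 0;
    for (0 <= c < C) {
        acc_t a = A_values[n, y, c] * A_scale[n, y, c / block_size];
        acc_t b = B_values[bn, x, c] * B_scale[bn, x, c / block_size];
        acc = acc + a * b;
    }
    output[n, y, x] = acc;
}
----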

==== MAX_POOL2D

This performs a max pooling over the given input tensor.

chapters/type_conversion.adoc

Lines changed: 43 additions & 0 deletions

@@ -56,6 +56,49 @@ Rules when casting between different types:
include::{pseudocode}/operators/CAST.tosac[lines=10..-1]
----

+==== CAST_FROM_BLOCK_SCALED
+
+Apply the scales from a scale tensor to the values in a value tensor, casting the result to the output type.
+The block dimension must be the last dimension of the tensor.
+
+include::{generated}/operators/CAST_FROM_BLOCK_SCALED.adoc[]
+
+*Precision Requirements*
+
+* Subnormal values must be supported on the output type.
+* Let `x` be a value from the `input_data` tensor.
+* Let `s` be the corresponding scale value from the `input_scale` tensor.
+* Let `out_ref = x * s` calculated using fp64_t arithmetic.
+* Let `out_imp` be the result of the implementation.
+* Then `tosa_reference_check_from_block<in_t, out_t>(out_imp, out_ref, s)` must be true.
+
+[source,c++]
+----
+include::{pseudocode}/operators/CAST_FROM_BLOCK_SCALED.tosac[lines=10..-1]
+----
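Element-wise, the cast reduces to one multiply per value. A minimal sketch with flattened indexing (an illustration; the names `output`, `block_size`, and `tensor_size` are assumed, the last dimension is a whole number of blocks, and the normative definition is the pseudocode included above):

[source,c++]
----
// The block dimension is last, so consecutive elements share a scale in
// runs of block_size.
for (0 <= i < tensor_size(input_data)) {
    output[i] = static_cast<out_t>(input_data[i] * input_scale[i / block_size]);
}
----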

+==== CAST_TO_BLOCK_SCALED
+
+Calculate a scale value per block of input values and use it to calculate scaled data values from an input tensor.
+The output tensors are cast to the specified scale and value types.
+The block dimension is the last dimension of the tensor.
+A sketch of the per-block scale selection follows the pseudocode below.
+
+include::{generated}/operators/CAST_TO_BLOCK_SCALED.adoc[]
+
+*Precision Requirements*
+
+* Subnormal values must be supported on the output type.
+* Let `x` be a value from the `input_data` tensor.
+* Let `out_ref_scale` be the result of calculating the block scale for the block containing `x` using fp64_t arithmetic.
+* Let `out_ref_value = x / out_ref_scale` calculated using fp64_t arithmetic.
+* Let `out_imp_scale, out_imp_value` be the results of the implementation for input `x`.
+* Then `tosa_reference_check_scale<scale_t, out_t>(out_imp_scale, out_imp_value, out_ref_scale, out_ref_value)` must be true.
+
+[source,c++]
+----
+include::{pseudocode}/operators/CAST_TO_BLOCK_SCALED.tosac[lines=10..-1]
+----
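For intuition, the OCP-MX convention derives each block's shared scale as a power of two from the block's largest magnitude, so the largest element lands within the element format's range. A sketch under that assumption (the normative calculation is the pseudocode included above; `floor_log2`, `block_start`, `block_index`, and `emax_elem`, the largest exponent of `out_t`, are assumed names):

[source,c++]
----
// One block of block_size input values -> one power-of-two scale
// (representable in fp8ue8m0_t) plus block_size narrow element values.
fp64_t max_abs = 0.0;
for (0 <= k < block_size) {
    max_abs = max(max_abs, abs(input[block_start + k]));
}
int32_t shared_exp = floor_log2(max_abs) - emax_elem;
output_scale[block_index] = static_cast<scale_t>(exp2(shared_exp));
for (0 <= k < block_size) {
    output_data[block_start + k] = static_cast<out_t>(input[block_start + k] / exp2(shared_exp));
}
----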

==== RESCALE

RESCALE is defined using an integer multiply, add, and shift.

pseudocode/library/generic_helpers.tosac

Lines changed: 7 additions & 3 deletions

@@ -1,14 +1,15 @@
//
// This confidential and proprietary software may be used only as
// authorised by a licensing agreement from ARM Limited
-// (C) COPYRIGHT 2020-2024 ARM Limited
+// (C) COPYRIGHT 2020-2025 ARM Limited
// ALL RIGHTS RESERVED
// The entire notice above must be reproduced on all authorised
// copies and copies may only be made to the extent permitted
// by a licensing agreement from ARM Limited.

bool_t is_floating_point<type>() {
-    if (is_same<type,fp16_t>() || is_same<type,fp32_t>() || is_same<type,bf16_t>() || is_same<type,fp8e4m3_t>() || is_same<type,fp8e5m2_t>()) {
+    if (is_same<type,fp16_t>() || is_same<type,fp32_t>() || is_same<type,bf16_t>() || is_same<type,fp8e4m3_t>() || is_same<type,fp8e5m2_t>() ||
+        is_same<type,fp4e2m1_t>() || is_same<type,fp6e3m2_t>() || is_same<type,fp6e2m3_t>() || is_same<type,fp8ue8m0_t>()) {
        return true;
    }
    return false;

@@ -69,7 +70,10 @@ in_out_t maximum_u<in_out_t>();
in_out_t minimum_u<in_out_t>();

// return true if the given value is a NaN. Only valid for floating-point types
-bool_t is_a_NaN(fp64_t value);
+bool_t is_a_NaN(in_t value);
+
+// return true if the given value is an Infinity. Only valid for floating-point types with defined Infinity values.
+bool_t is_an_Inf(in_t value);

// return true if value is a normal fp64 value (not zero, subnormal, infinite or NaN)
bool_t is_normal_fp64(fp64_t value);

pseudocode/library/numeric_accuracy_helpers.tosac

Lines changed: 85 additions & 6 deletions

@@ -37,25 +37,44 @@ fp64_t normal_min<in_t>() {
        return exp2(-6);
    } else if (is_same<in_t,fp8e5m2_t>()) {
        return exp2(-14);
-    }
+    } else if (is_same<in_t,fp6e2m3_t>()) {
+        return 1.0;
+    } else if (is_same<in_t,fp6e3m2_t>()) {
+        return 0.25;
+    } else if (is_same<in_t,fp4e2m1_t>()) {
+        return 1.0;
+    } else if (is_same<in_t,mxint8_t>()) {
+        return 1/64.0;
+    } else if (is_same<in_t,fp8ue8m0_t>()) {
+        return exp2(-127);
+    }
}

fp64_t normal_max<in_t>() {
    if (is_same<in_t,fp32_t>()) {
        return exp2(128) - exp2(127-23);
    } else if (is_same<in_t,bf16_t>()) {
-        return exp2(128) - exp2(127- 7);
+        return exp2(128) - exp2(127-7);
    } else if (is_same<in_t,fp16_t>()) {
-        return exp2( 16) - exp2( 15-10);
+        return exp2(16) - exp2(15-10);
    } else if (is_same<in_t,fp8e4m3_t>()) {
-        return exp2( 9) - exp2( 8-2);
+        return exp2(9) - exp2(8-2);
    } else if (is_same<in_t,fp8e5m2_t>()) {
-        return exp2( 16) - exp2( 15-2);
+        return exp2(16) - exp2(15-2);
+    } else if (is_same<in_t,fp6e2m3_t>()) {
+        return 7.5;
+    } else if (is_same<in_t,fp6e3m2_t>()) {
+        return 28.0;
+    } else if (is_same<in_t,fp4e2m1_t>()) {
+        return 6.0;
+    } else if (is_same<in_t,mxint8_t>()) {
+        return 1.0 + 63.0/64.0;
+    } else if (is_same<in_t,fp8ue8m0_t>()) {
+        return exp2(127);
    }
}

// Number of fractional (mantissa) bits
-int normal_frac<in_t> () {
+int normal_frac<in_t>() {
    if (is_same<in_t,fp32_t>()) {
        return 23;
    } else if (is_same<in_t,bf16_t>()) {

@@ -66,9 +85,69 @@ int normal_frac<in_t> () {
        return 3;
    } else if (is_same<in_t,fp8e5m2_t>()) {
        return 2;
+    } else if (is_same<in_t,fp6e2m3_t>()) {
+        return 3;
+    } else if (is_same<in_t,fp6e3m2_t>()) {
+        return 2;
+    } else if (is_same<in_t,fp4e2m1_t>()) {
+        return 1;
+    } else if (is_same<in_t,mxint8_t>()) {
+        return 0;
    }
}

+// Exponent width
+int exponent_bits<in_t>() {
+    if (is_same<in_t,fp32_t>()) {
+        return 8;
+    } else if (is_same<in_t,fp16_t>()) {
+        return 5;
+    } else if (is_same<in_t,bf16_t>()) {
+        return 8;
+    } else if (is_same<in_t,fp8e4m3_t>()) {
+        return 4;
+    } else if (is_same<in_t,fp8e5m2_t>()) {
+        return 5;
+    } else if (is_same<in_t,fp6e2m3_t>()) {
+        return 2;
+    } else if (is_same<in_t,fp6e3m2_t>()) {
+        return 3;
+    } else if (is_same<in_t,fp4e2m1_t>()) {
+        return 2;
+    } else if (is_same<in_t,mxint8_t>()) {
+        return 0;
+    }
+}
+
+int exponent_bias<in_t>() {
+    if (is_same<in_t,fp32_t>()) {
+        return 127;
+    } else if (is_same<in_t,fp16_t>()) {
+        return 15;
+    } else if (is_same<in_t,bf16_t>()) {
+        return 127;
+    } else if (is_same<in_t,fp8e4m3_t>()) {
+        return 7;
+    } else if (is_same<in_t,fp8e5m2_t>()) {
+        return 15;
+    } else if (is_same<in_t,fp6e2m3_t>()) {
+        return 1;
+    } else if (is_same<in_t,fp6e3m2_t>()) {
+        return 3;
+    } else if (is_same<in_t,fp4e2m1_t>()) {
+        return 1;
+    } else if (is_same<in_t,mxint8_t>()) {
+        return 6;
+    } else if (is_same<in_t,fp8ue8m0_t>()) {
+        return 127;
+    }
+}
+
+// Returns a mask for the low N bits of a value
+uint32_t get_low_bitmask(int32_t bits) {
+    return (1 << bits) - 1;
+}
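A hypothetical usage sketch for the helpers above (illustrative only; `decode_normal` is an assumed name, the layout places the exponent field immediately above the mantissa, the sign bit is ignored, and mxint8_t, which has no exponent field, is out of scope):

[source,c++]
----
// Decode a normal value from its bit fields using the helpers above:
// value = (1 + mantissa * 2^-frac) * 2^(exponent_field - bias)
fp64_t decode_normal<in_t>(uint32_t bits) {
    int32_t frac = normal_frac<in_t>();
    uint32_t mantissa = bits & get_low_bitmask(frac);
    uint32_t exp_field = (bits >> frac) & get_low_bitmask(exponent_bits<in_t>());
    return (1.0 + mantissa * exp2(-frac)) * exp2((int32_t)exp_field - exponent_bias<in_t>());
}
----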
fp64_t calcAbsErrorBound<in_t>(fp64_t bound_magnitude, fp64_t bounds_value,
                               fp64_t lower_bound, fp64_t normal_divisor) {
    fp64_t error_bound = 0.0;
