Skip to content

Commit fa539fb

Browse files
committed
Add F4E2M1FN type: import mxfloat.h
1 parent b86b937 commit fa539fb

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

third_party/tsl/tsl/platform/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,7 @@ cc_library(
10661066
deps = [
10671067
"@ml_dtypes//:float8",
10681068
"@ml_dtypes//:intn",
1069+
"@ml_dtypes//:mxfloat",
10691070
],
10701071
)
10711072

third_party/tsl/tsl/platform/ml_dtypes.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ limitations under the License.
1616
#ifndef TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_
1717
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_
1818

19-
#include "ml_dtypes/include/float8.h" // from @ml_dtypes
20-
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
19+
#include "ml_dtypes/include/float8.h" // from @ml_dtypes
20+
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
21+
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes
2122

2223
namespace tsl {
24+
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
2325
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
2426
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
2527
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;

0 commit comments

Comments
 (0)