Skip to content

Commit afd3929

Browse files
committed
Fix fused fp8 <-> fp8 conversions
1 parent b4ff71b commit afd3929

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

xla/backends/gpu/codegen/triton/emitter_helpers.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,13 @@ bool IsFp8Type(Type t) {
126126
Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
127127
Type src_ty = value.getType();
128128
Type src_element_ty = src_ty;
129+
Type fp16_ty = b.getF16Type();
129130
Type fp32_ty = b.getF32Type();
130131
Type dst_ty = dst_element_ty;
131132
if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
132133
src_element_ty = src_shaped_ty.getElementType();
133134
dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty);
135+
fp16_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF16Type());
134136
fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type());
135137
}
136138
if (src_ty == dst_ty) {
@@ -156,14 +158,21 @@ Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
156158
// because LLVM doesn't support casts from/to FP8.
157159
// TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
158160
// we can't test the code below without patching the feature.
159-
if (IsFp8Type(src_element_ty)) {
161+
if (IsFp8Type(src_element_ty) && !IsFp8Type(dst_element_ty)) {
160162
return b.create<mt::FpToFpOp>(dst_ty, value);
161163
}
162-
if (IsFp8Type(dst_element_ty)) {
164+
if (IsFp8Type(dst_element_ty) && !IsFp8Type(src_element_ty)) {
163165
return b.create<mt::FpToFpOp>(
164166
dst_ty, value,
165167
mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
166168
}
169+
if (IsFp8Type(src_element_ty) && IsFp8Type(dst_element_ty)) {
170+
// FP8 <-> FP8 conversion needs to go through FP16
171+
auto fp16_value = b.create<mt::FpToFpOp>(fp16_ty, value);
172+
return b.create<mt::FpToFpOp>(
173+
dst_ty, fp16_value,
174+
mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
175+
}
167176

168177
if (src_fp_element_ty.getFPMantissaWidth() >
169178
dst_fp_element_ty.getFPMantissaWidth()) {

xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,13 @@ bool IsFp8Type(Type t) {
230230
Value Cast(EmitterLocOpBuilder b, Value value, Type dst_element_ty) {
231231
Type src_ty = value.getType();
232232
Type src_element_ty = src_ty;
233+
Type fp16_ty = b.getF16Type();
233234
Type fp32_ty = b.getF32Type();
234235
Type dst_ty = dst_element_ty;
235236
if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
236237
src_element_ty = src_shaped_ty.getElementType();
237238
dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty);
239+
fp16_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF16Type());
238240
fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type());
239241
}
240242
if (src_ty == dst_ty) {
@@ -260,14 +262,21 @@ Value Cast(EmitterLocOpBuilder b, Value value, Type dst_element_ty) {
260262
// because LLVM doesn't support casts from/to FP8.
261263
// TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
262264
// we can't test the code below without patching the feature.
263-
if (IsFp8Type(src_element_ty)) {
265+
if (IsFp8Type(src_element_ty) && !IsFp8Type(dst_element_ty)) {
264266
return b.create<mt::FpToFpOp>(dst_ty, value);
265267
}
266-
if (IsFp8Type(dst_element_ty)) {
268+
if (IsFp8Type(dst_element_ty) && !IsFp8Type(src_element_ty)) {
267269
return b.create<mt::FpToFpOp>(
268270
dst_ty, value,
269271
mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
270272
}
273+
if (IsFp8Type(src_element_ty) && IsFp8Type(dst_element_ty)) {
274+
// FP8 <-> FP8 conversion needs to go through FP16
275+
auto fp16_value = b.create<mt::FpToFpOp>(fp16_ty, value);
276+
return b.create<mt::FpToFpOp>(
277+
dst_ty, fp16_value,
278+
mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
279+
}
271280

272281
if (src_fp_element_ty.getFPMantissaWidth() >
273282
dst_fp_element_ty.getFPMantissaWidth()) {

0 commit comments

Comments
 (0)