@@ -126,11 +126,13 @@ bool IsFp8Type(Type t) {
126126Value Cast (EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
127127 Type src_ty = value.getType ();
128128 Type src_element_ty = src_ty;
129+ Type fp16_ty = b.getF16Type ();
129130 Type fp32_ty = b.getF32Type ();
130131 Type dst_ty = dst_element_ty;
131132 if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
132133 src_element_ty = src_shaped_ty.getElementType ();
133134 dst_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), dst_element_ty);
135+ fp16_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), b.getF16Type ());
134136 fp32_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), b.getF32Type ());
135137 }
136138 if (src_ty == dst_ty) {
@@ -156,14 +158,21 @@ Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
156158 // because LLVM doesn't support casts from/to FP8.
157159 // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
158160 // we can't test the code below without patching the feature.
159- if (IsFp8Type (src_element_ty)) {
161+ if (IsFp8Type (src_element_ty) && ! IsFp8Type (dst_element_ty) ) {
160162 return b.create <mt::FpToFpOp>(dst_ty, value);
161163 }
162- if (IsFp8Type (dst_element_ty)) {
164+ if (IsFp8Type (dst_element_ty) && ! IsFp8Type (src_element_ty) ) {
163165 return b.create <mt::FpToFpOp>(
164166 dst_ty, value,
165167 mt::RoundingModeAttr::get (b.getContext (), mt::RoundingMode::RTNE));
166168 }
169+ if (IsFp8Type (src_element_ty) && IsFp8Type (dst_element_ty)) {
170+ // FP8 <-> FP8 conversion needs to go through FP16
171+ auto fp16_value = b.create <mt::FpToFpOp>(fp16_ty, value);
172+ return b.create <mt::FpToFpOp>(
173+ dst_ty, fp16_value,
174+ mt::RoundingModeAttr::get (b.getContext (), mt::RoundingMode::RTNE));
175+ }
167176
168177 if (src_fp_element_ty.getFPMantissaWidth () >
169178 dst_fp_element_ty.getFPMantissaWidth ()) {
0 commit comments