@@ -227,121 +227,6 @@ bool IsFp8Type(Type t) {
227227                   mlir::Float8E4M3B11FNUZType>(t);
228228}
229229
230- Value Cast (EmitterLocOpBuilder b, Value value, Type dst_element_ty) {
231-   Type src_ty = value.getType ();
232-   Type src_element_ty = src_ty;
233-   Type fp32_ty = b.getF32Type ();
234-   Type dst_ty = dst_element_ty;
235-   if  (auto  src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
236-     src_element_ty = src_shaped_ty.getElementType ();
237-     dst_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), dst_element_ty);
238-     fp32_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), b.getF32Type ());
239-   }
240-   if  (src_ty == dst_ty) {
241-     return  value;
242-   }
243- 
244-   //  All operations on bf16 are done through f32.
245-   if  (src_element_ty.isBF16 ()) {
246-     return  Cast (b, b.create <ma::ExtFOp>(fp32_ty, value), dst_element_ty);
247-   }
248-   if  (dst_element_ty.isBF16 ()) {
249-     //  S8 -> BF16 is directly supported and doesn't need to go through f32.
250-     if  (!src_element_ty.isInteger (8 )) {
251-       return  b.create <ma::TruncFOp>(dst_ty, Cast (b, value, b.getF32Type ()));
252-     }
253-   }
254- 
255-   //  float => float
256-   auto  src_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(src_element_ty);
257-   auto  dst_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(dst_element_ty);
258-   if  (src_fp_element_ty && dst_fp_element_ty) {
259-     //  F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp
260-     //  because LLVM doesn't support casts from/to FP8.
261-     //  TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
262-     //  we can't test the code below without patching the feature.
263-     if  (IsFp8Type (src_element_ty)) {
264-       return  b.create <mt::FpToFpOp>(dst_ty, value);
265-     }
266-     if  (IsFp8Type (dst_element_ty)) {
267-       return  b.create <mt::FpToFpOp>(
268-           dst_ty, value,
269-           mt::RoundingModeAttr::get (b.getContext (), mt::RoundingMode::RTNE));
270-     }
271- 
272-     if  (src_fp_element_ty.getFPMantissaWidth () >
273-         dst_fp_element_ty.getFPMantissaWidth ()) {
274-       return  b.create <ma::TruncFOp>(dst_ty, value);
275-     } else  {
276-       return  b.create <ma::ExtFOp>(dst_ty, value);
277-     }
278-   }
279-   //  int => int
280-   if  (mlir::isa<mlir::IntegerType>(src_element_ty) &&
281-       mlir::isa<mlir::IntegerType>(dst_element_ty)) {
282-     if  (src_element_ty.getIntOrFloatBitWidth () <
283-         dst_element_ty.getIntOrFloatBitWidth ()) {
284-       if  (src_element_ty.isInteger (1 )) {
285-         return  b.create <ma::ExtUIOp>(dst_ty, value);
286-       }
287-       return  b.create <ma::ExtSIOp>(dst_ty, value);
288-     }
289-     return  b.create <ma::TruncIOp>(dst_ty, value);
290-   }
291-   //  int => float
292-   if  (mlir::isa<mlir::IntegerType>(src_element_ty) && dst_fp_element_ty) {
293-     //  TODO(b/266862493): Support unsigned integer types.
294-     if  (src_element_ty.isInteger (1 )) {
295-       return  b.create <ma::UIToFPOp>(dst_ty, value);
296-     }
297-     return  b.create <ma::SIToFPOp>(dst_ty, value);
298-   }
299-   //  float => int
300-   if  (src_fp_element_ty && mlir::isa<mlir::IntegerType>(dst_element_ty)) {
301-     if  (dst_element_ty.isInteger (1 )) {
302-       return  b.create <ma::CmpFOp>(ma::CmpFPredicate::UNE, value,
303-                                   ZerosLike (b, value));
304-     }
305-     //  TODO(b/266862493): Support unsigned integer types.
306-     //  The current logic handles signed integer types only. Additional handling
307-     //  is needed for unsigned integer types.
308-     auto  cst_int = [&](EmitterLocOpBuilder b, int64_t  x) {
309-       if  (auto  src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
310-         return  CreateConst (b, dst_element_ty, x, src_shaped_ty.getShape ());
311-       } else  {
312-         return  CreateConst (b, dst_element_ty, x);
313-       }
314-     };
315-     auto  cst_float = [&](EmitterLocOpBuilder b, int64_t  x) {
316-       if  (auto  src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
317-         return  CreateConst (b, src_fp_element_ty, x, src_shaped_ty.getShape ());
318-       } else  {
319-         return  CreateConst (b, src_fp_element_ty, x);
320-       }
321-     };
322-     auto  fptosi = b.create <ma::FPToSIOp>(dst_ty, value);
323-     int64_t  min = llvm::minIntN (dst_element_ty.getIntOrFloatBitWidth ());
324-     int64_t  max = llvm::maxIntN (dst_element_ty.getIntOrFloatBitWidth ());
325- 
326-     //  value <= static_cast<float>(INT_MIN) ? INT_MIN : ...
327-     auto  clamped = b.create <ma::SelectOp>(
328-         b.create <ma::CmpFOp>(ma::CmpFPredicate::OLE, value, cst_float (b, min)),
329-         cst_int (b, min), fptosi);
330-     //  value >= static_cast<float>(INT_MAX) ? INT_MAX : ...
331-     clamped = b.create <ma::SelectOp>(
332-         b.create <ma::CmpFOp>(ma::CmpFPredicate::OGE, value, cst_float (b, max)),
333-         cst_int (b, max), clamped);
334-     //  isnan(value) ? 0 : ...
335-     return  b.create <ma::SelectOp>(
336-         b.create <ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value),
337-         cst_int (b, 0 ), clamped);
338-   }
339- 
340-   LOG (FATAL) << " Type conversion not supported: " 
341-              << llvm_ir::DumpToString (src_element_ty) << "  -> " 
342-              << llvm_ir::DumpToString (dst_element_ty);
343- }
344- 
345230Value Subtract (EmitterLocOpBuilder b, ValueRange values) {
346231  if  (mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf (values[0 ]))) {
347232    return  b.create <ma::SubIOp>(values[0 ], values[1 ]);
@@ -448,7 +333,7 @@ absl::StatusOr<Value> EmitElementwise(EmitterLocOpBuilder b,
448333    case  HloOpcode::kConvert : {
449334      TF_ASSIGN_OR_RETURN (Type dst_ty,
450335                          TritonType (b, hlo.shape ().element_type ()));
451-       return  Cast (b, inputs[0 ], dst_ty);
336+       return  triton:: Cast0 ], dst_ty);
452337    }
453338    case  HloOpcode::kAdd :
454339      if  (is_integer) {
@@ -661,7 +546,7 @@ absl::StatusOr<Value> EmitScope(
661546    if  (hlo->opcode () == HloOpcode::kConvert  &&
662547        hlo->operand (0 )->shape ().element_type () == S4) {
663548      Value unpacked;
664-       unpacked = Cast (b, values[hlo->operand (0 )], b.getI8Type ());
549+       unpacked = triton:: Castoperand (0 )], b.getI8Type ());
665550      std::vector<Value> operands ({unpacked});
666551      TF_ASSIGN_OR_RETURN (result, EmitElementwise (b, libdevice_path,
667552                                                  device_info, *hlo, operands));
@@ -817,7 +702,7 @@ ma::ConstantOp Cst64(EmitterLocOpBuilder b, int64_t v) {
817702}
818703
819704Value RoundToBF16 (EmitterLocOpBuilder b, Value input) {
820-   return  Cast (b, input, b.getBF16Type ());
705+   return  triton:: CastgetBF16Type ());
821706};
822707
823708/* static*/ MatMulDims::Create (
@@ -1480,7 +1365,7 @@ class MatMulEmitterHelper {
14801365          " 64 bit dynamic-slice indices are not supported yet." 
14811366    }
14821367    majormost_dim_start_index_val =
1483-         Cast (b, majormost_dim_start_index_val, b.getI32Type ());
1368+         triton:: CastgetI32Type ());
14841369    majormost_dim_start_index_val =
14851370        b.create <ma::MaxSIOp>(majormost_dim_start_index_val, Cst32 (b, 0 ));
14861371    majormost_dim_start_index_val =
@@ -2041,7 +1926,7 @@ class IterableInput {
20411926    Value param_value = EmitParameterLoad (b, args.front (), boundary_checks_);
20421927    if  (type_ != storage_type_) {
20431928      //  For example cast i8 to i1.
2044-       param_value = Cast (b, param_value, type_);
1929+       param_value = triton:: Cast
20451930    }
20461931    return  param_value;
20471932  }
@@ -2167,10 +2052,10 @@ Value EmitRegularMatmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc,
21672052  if  (dot_instr->precision_config ().algorithm () ==
21682053      PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
21692054    if  (dot_instr->operand (0 )->shape ().element_type () == F32) {
2170-       lhs = Cast (b, lhs, b.getBF16Type ());
2055+       lhs = triton:: CastgetBF16Type ());
21712056    }
21722057    if  (dot_instr->operand (1 )->shape ().element_type () == F32) {
2173-       rhs = Cast (b, rhs, b.getBF16Type ());
2058+       rhs = triton:: CastgetBF16Type ());
21742059    }
21752060  }
21762061
@@ -2364,7 +2249,7 @@ absl::StatusOr<std::optional<stream_executor::gpu::TmaMetadata>> EmitMatMul(
23642249  absl::flat_hash_map<const  HloInstruction*, Value> values_out;
23652250  TF_ASSIGN_OR_RETURN (Type acc_final_ty,
23662251                      TritonType (b, dot_instr->shape ().element_type ()));
2367-   values_out[dot_instr] = Cast (b, acc_final, acc_final_ty);
2252+   values_out[dot_instr] = triton:: Cast
23682253
23692254  //  Emit the output scope.
23702255  if  (std::vector<const  HloInstruction*> to_emit =
0 commit comments