diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 74f0f61d04a1a..9214bc5b2c13e 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern { bool needsUnsigned = needsUnsignedCmp(op.getPredicate()); emitc::CmpPredicate pred = toEmitCPred(op.getPredicate()); - Type arithmeticType = type; - if (type.isUnsignedInteger() != needsUnsigned) { - arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(), - /*isSigned=*/!needsUnsigned); - } - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - if (arithmeticType != type) { - lhs = rewriter.template create(op.getLoc(), arithmeticType, - lhs); - rhs = rewriter.template create(op.getLoc(), arithmeticType, - rhs); - } + + Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned); + Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); + Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); + rewriter.replaceOpWithNewOp(op, op.getType(), pred, lhs, rhs); return success(); } @@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern { return success(); } - bool isTruncation = operandType.getIntOrFloatBitWidth() > - opReturnType.getIntOrFloatBitWidth(); + bool isTruncation = + (isa(operandType) && isa(opReturnType) && + operandType.getIntOrFloatBitWidth() > + opReturnType.getIntOrFloatBitWidth()); bool doUnsigned = castToUnsigned || isTruncation; - Type castType = opReturnType; - // If the op is a ui variant and the type wanted as - // return type isn't unsigned, we need to issue an unsigned type to do - // the conversion. - if (castType.isUnsignedInteger() != doUnsigned) { - castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(), - /*isSigned=*/!doUnsigned); - } + // Adapt the signedness of the result (bitwidth-preserving cast) + // This is needed e.g., if the return type is signless. + Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned); - Value actualOp = adaptor.getIn(); - // Adapt the signedness of the operand if necessary - if (operandType.isUnsignedInteger() != doUnsigned) { - Type correctSignednessType = - rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), - /*isSigned=*/!doUnsigned); - actualOp = rewriter.template create( - op.getLoc(), correctSignednessType, actualOp); - } + // Adapt the signedness of the operand (bitwidth-preserving cast) + Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned); + Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); - auto result = rewriter.template create(op.getLoc(), castType, - actualOp); + // Actual cast (may change bitwidth) + auto cast = rewriter.template create(op.getLoc(), + castDestType, actualOp); // Cast to the expected output type - if (castType != opReturnType) { - result = rewriter.template create(op.getLoc(), - opReturnType, result); - } + auto result = adaptValueType(cast, rewriter, opReturnType); rewriter.replaceOp(op, result); return success(); @@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); } - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); Type arithmeticType = type; if ((type.isSignlessInteger() || type.isSignedInteger()) && !bitEnumContainsAll(op.getOverflowFlags(), @@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern { arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(), /*isSigned=*/false); } - if (arithmeticType != type) { - lhs = rewriter.template create(op.getLoc(), arithmeticType, - lhs); - rhs = rewriter.template create(op.getLoc(), arithmeticType, - rhs); - } - Value result = rewriter.template create(op.getLoc(), - arithmeticType, lhs, rhs); + Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); + Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); + + Value arithmeticResult = rewriter.template create( + op.getLoc(), arithmeticType, lhs, rhs); + + Value result = adaptValueType(arithmeticResult, rewriter, type); - if (arithmeticType != type) { - result = - rewriter.template create(op.getLoc(), type, result); - } rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 71f1a6abd913b..607e5bf9b1a3b 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 { // CHECK: emitc.cast %[[Trunc]] : ui8 to i8 %truncd = arith.trunci %arg0 : i32 to i8 + // CHECK: %[[Const:.*]] = "emitc.constant" + // CHECK-SAME: value = 1 + // CHECK-SAME: () -> i32 + // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32 + // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1 + %bool = arith.trunci %arg0 : i32 to i1 + return %truncd : i8 }