Skip to content

Commit d235369

Browse files
author
Peiming Liu
authored
[mlir][NFC] update code to use mlir::dyn_cast/cast/isa (#90633)
Fix compiler warning caused by using deprecated interface (#90413)
1 parent 49bb993 commit d235369

File tree

9 files changed

+60
-55
lines changed

9 files changed

+60
-55
lines changed

mlir/include/mlir/Tools/PDLL/AST/Nodes.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
597597
}
598598

599599
/// Return the range result type of this expression.
600-
RangeType getType() const { return Base::getType().cast<RangeType>(); }
600+
RangeType getType() const { return mlir::cast<RangeType>(Base::getType()); }
601601

602602
private:
603603
RangeExpr(SMRange loc, RangeType type, unsigned numElements)
@@ -630,7 +630,7 @@ class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
630630
}
631631

632632
/// Return the tuple result type of this expression.
633-
TupleType getType() const { return Base::getType().cast<TupleType>(); }
633+
TupleType getType() const { return mlir::cast<TupleType>(Base::getType()); }
634634

635635
private:
636636
TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
757757
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
758758
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
759759

760-
auto expandTy = expandOp.getType().dyn_cast<RankedTensorType>();
760+
auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
761761
if (!expandTy)
762762
return failure();
763763
ArrayRef<int64_t> dstShape = expandTy.getShape();

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,7 +2256,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
22562256
ArrayRef<OpFoldResult> outputShape) {
22572257
auto [staticOutputShape, dynamicOutputShape] =
22582258
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2259-
build(builder, result, resultType.cast<MemRefType>(), src,
2259+
build(builder, result, llvm::cast<MemRefType>(resultType), src,
22602260
getReassociationIndicesAttribute(builder, reassociation),
22612261
dynamicOutputShape, staticOutputShape);
22622262
}
@@ -2266,7 +2266,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
22662266
ArrayRef<ReassociationIndices> reassociation) {
22672267
SmallVector<OpFoldResult> inputShape =
22682268
getMixedSizes(builder, result.location, src);
2269-
MemRefType memrefResultTy = resultType.cast<MemRefType>();
2269+
MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
22702270
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
22712271
builder, result.location, memrefResultTy, reassociation, inputShape);
22722272
// Failure of this assertion usually indicates presence of multiple
@@ -2867,7 +2867,8 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
28672867
/// marked as dropped in `droppedDims`.
28682868
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
28692869
const llvm::SmallBitVector &droppedDims) {
2870-
assert(size_t(t1.getRank()) == droppedDims.size() && "incorrect number of bits");
2870+
assert(size_t(t1.getRank()) == droppedDims.size() &&
2871+
"incorrect number of bits");
28712872
assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
28722873
"incorrect number of dropped dims");
28732874
int64_t t1Offset, t2Offset;

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,7 +1663,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
16631663
ArrayRef<OpFoldResult> outputShape) {
16641664
auto [staticOutputShape, dynamicOutputShape] =
16651665
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1666-
build(builder, result, resultType.cast<RankedTensorType>(), src,
1666+
build(builder, result, cast<RankedTensorType>(resultType), src,
16671667
getReassociationIndicesAttribute(builder, reassociation),
16681668
dynamicOutputShape, staticOutputShape);
16691669
}
@@ -1673,7 +1673,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
16731673
ArrayRef<ReassociationIndices> reassociation) {
16741674
SmallVector<OpFoldResult> inputShape =
16751675
getMixedSizes(builder, result.location, src);
1676-
auto tensorResultTy = resultType.cast<RankedTensorType>();
1676+
auto tensorResultTy = cast<RankedTensorType>(resultType);
16771677
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
16781678
builder, result.location, tensorResultTy, reassociation, inputShape);
16791679
// Failure of this assertion usually indicates presence of multiple

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ static bool hasZeroDimension(ShapedType shapedType) {
220220
return false;
221221
}
222222

223-
template <typename T> static LogicalResult verifyConvOp(T op) {
223+
template <typename T>
224+
static LogicalResult verifyConvOp(T op) {
224225
// All TOSA conv ops have an input() and weight().
225226
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
226227
auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
@@ -962,7 +963,7 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
962963
return emitOpError() << "tensor has a dimension with size zero. Each "
963964
"dimension of a tensor must have size >= 1";
964965

965-
if ((int64_t) getNewShape().size() != outputType.getRank())
966+
if ((int64_t)getNewShape().size() != outputType.getRank())
966967
return emitOpError() << "new shape does not match result rank";
967968

968969
for (auto [newShapeDim, outputShapeDim] :
@@ -1127,7 +1128,7 @@ LogicalResult TransposeOp::reifyResultShapes(
11271128
return failure();
11281129

11291130
Value input = getInput1();
1130-
auto inputType = input.getType().cast<TensorType>();
1131+
auto inputType = cast<TensorType>(input.getType());
11311132

11321133
SmallVector<OpFoldResult> returnedDims(inputType.getRank());
11331134
for (auto dim : transposePerms) {

mlir/lib/Tools/PDLL/AST/Types.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ Type Type::refineWith(Type other) const {
3535
return *this;
3636

3737
// Operation types are compatible if the operation names don't conflict.
38-
if (auto opTy = dyn_cast<OperationType>()) {
39-
auto otherOpTy = other.dyn_cast<ast::OperationType>();
38+
if (auto opTy = mlir::dyn_cast<OperationType>(*this)) {
39+
auto otherOpTy = mlir::dyn_cast<ast::OperationType>(other);
4040
if (!otherOpTy)
4141
return nullptr;
4242
if (!otherOpTy.getName())
@@ -105,25 +105,26 @@ Type RangeType::getElementType() const {
105105
// TypeRangeType
106106

107107
bool TypeRangeType::classof(Type type) {
108-
RangeType range = type.dyn_cast<RangeType>();
109-
return range && range.getElementType().isa<TypeType>();
108+
RangeType range = mlir::dyn_cast<RangeType>(type);
109+
return range && mlir::isa<TypeType>(range.getElementType());
110110
}
111111

112112
TypeRangeType TypeRangeType::get(Context &context) {
113-
return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
113+
return mlir::cast<TypeRangeType>(
114+
RangeType::get(context, TypeType::get(context)));
114115
}
115116

116117
//===----------------------------------------------------------------------===//
117118
// ValueRangeType
118119

119120
bool ValueRangeType::classof(Type type) {
120-
RangeType range = type.dyn_cast<RangeType>();
121-
return range && range.getElementType().isa<ValueType>();
121+
RangeType range = mlir::dyn_cast<RangeType>(type);
122+
return range && mlir::isa<ValueType>(range.getElementType());
122123
}
123124

124125
ValueRangeType ValueRangeType::get(Context &context) {
125-
return RangeType::get(context, ValueType::get(context))
126-
.cast<ValueRangeType>();
126+
return mlir::cast<ValueRangeType>(
127+
RangeType::get(context, ValueType::get(context)));
127128
}
128129

129130
//===----------------------------------------------------------------------===//

mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,13 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
337337
// Generate a value based on the type of the variable.
338338
ast::Type type = varDecl->getType();
339339
Type mlirType = genType(type);
340-
if (type.isa<ast::ValueType>())
340+
if (isa<ast::ValueType>(type))
341341
return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
342-
if (type.isa<ast::TypeType>())
342+
if (isa<ast::TypeType>(type))
343343
return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
344-
if (type.isa<ast::AttributeType>())
344+
if (isa<ast::AttributeType>(type))
345345
return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
346-
if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
346+
if (ast::OperationType opType = dyn_cast<ast::OperationType>(type)) {
347347
Value operands = builder.create<pdl::OperandsOp>(
348348
loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
349349
/*type=*/Value());
@@ -354,12 +354,12 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
354354
loc, opType.getName(), operands, std::nullopt, ValueRange(), results);
355355
}
356356

357-
if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
357+
if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(type)) {
358358
ast::Type eleTy = rangeTy.getElementType();
359-
if (eleTy.isa<ast::ValueType>())
359+
if (isa<ast::ValueType>(eleTy))
360360
return builder.create<pdl::OperandsOp>(loc, mlirType,
361361
getTypeConstraint());
362-
if (eleTy.isa<ast::TypeType>())
362+
if (isa<ast::TypeType>(eleTy))
363363
return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
364364
}
365365

@@ -440,7 +440,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
440440
ast::Type parentType = expr->getParentExpr()->getType();
441441

442442
// Handle operation based member access.
443-
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
443+
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
444444
if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
445445
Type mlirType = genType(expr->getType());
446446
if (isa<pdl::ValueType>(mlirType))
@@ -480,7 +480,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
480480
}
481481

482482
// Handle tuple based member access.
483-
if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
483+
if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
484484
auto elementNames = tupleType.getElementNames();
485485

486486
// The index is either a numeric index, or a name.
@@ -581,14 +581,14 @@ CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
581581
if (!cstBody) {
582582
ast::Type declResultType = decl->getResultType();
583583
SmallVector<Type> resultTypes;
584-
if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
584+
if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
585585
for (ast::Type type : tupleType.getElementTypes())
586586
resultTypes.push_back(genType(type));
587587
} else {
588588
resultTypes.push_back(genType(declResultType));
589589
}
590-
PDLOpT pdlOp = builder.create<PDLOpT>(
591-
loc, resultTypes, decl->getName().getName(), inputs);
590+
PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes,
591+
decl->getName().getName(), inputs);
592592
if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
593593
cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
594594
return pdlOp->getResults();

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ LogicalResult Parser::convertExpressionTo(
623623
return diag;
624624
};
625625

626-
if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
626+
if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
627627
return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
628628

629629
// FIXME: Decide how to allow/support converting a single result to multiple,
@@ -638,7 +638,7 @@ LogicalResult Parser::convertExpressionTo(
638638
return success();
639639

640640
// Handle tuple types.
641-
if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
641+
if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
642642
return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
643643
noteAttachFn);
644644

@@ -650,7 +650,7 @@ LogicalResult Parser::convertOpExpressionTo(
650650
function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
651651
// Two operation types are compatible if they have the same name, or if the
652652
// expected type is more general.
653-
if (auto opType = type.dyn_cast<ast::OperationType>()) {
653+
if (auto opType = dyn_cast<ast::OperationType>(type)) {
654654
if (opType.getName())
655655
return emitErrorFn();
656656
return success();
@@ -702,7 +702,7 @@ LogicalResult Parser::convertTupleExpressionTo(
702702
function_ref<ast::InFlightDiagnostic()> emitErrorFn,
703703
function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
704704
// Handle conversions between tuples.
705-
if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
705+
if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
706706
if (tupleType.size() != exprType.size())
707707
return emitErrorFn();
708708

@@ -2568,7 +2568,7 @@ Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
25682568
}
25692569

25702570
// Constraint types cannot be used when defining variables.
2571-
if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2571+
if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
25722572
return emitError(
25732573
loc, llvm::formatv("unable to define variable of `{0}` type", type));
25742574
}
@@ -2782,7 +2782,7 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
27822782
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
27832783
StringRef name, SMRange loc) {
27842784
ast::Type parentType = parentExpr->getType();
2785-
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2785+
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
27862786
if (name == ast::AllResultsMemberAccessExpr::getMemberName())
27872787
return valueRangeTy;
27882788

@@ -2808,7 +2808,7 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
28082808
// operations. It returns a single value.
28092809
return valueTy;
28102810
}
2811-
} else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2811+
} else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
28122812
// Handle indexed results.
28132813
unsigned index = 0;
28142814
if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
@@ -2845,7 +2845,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
28452845
for (ast::NamedAttributeDecl *attr : attributes) {
28462846
// Check for an attribute type, or a type awaiting resolution.
28472847
ast::Type attrType = attr->getValue()->getType();
2848-
if (!attrType.isa<ast::AttributeType>()) {
2848+
if (!isa<ast::AttributeType>(attrType)) {
28492849
return emitError(
28502850
attr->getValue()->getLoc(),
28512851
llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
@@ -3024,7 +3024,7 @@ LogicalResult Parser::validateOperationOperandsOrResults(
30243024
// ValueRange. This situations arises quite often with nested operation
30253025
// expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
30263026
if (singleTy == valueTy) {
3027-
if (valueExprType.isa<ast::OperationType>()) {
3027+
if (isa<ast::OperationType>(valueExprType)) {
30283028
valueExpr = convertOpToValue(valueExpr);
30293029
continue;
30303030
}
@@ -3048,7 +3048,7 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
30483048
ArrayRef<StringRef> elementNames) {
30493049
for (const ast::Expr *element : elements) {
30503050
ast::Type eleTy = element->getType();
3051-
if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
3051+
if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
30523052
return emitError(
30533053
element->getLoc(),
30543054
llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
@@ -3064,7 +3064,7 @@ FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
30643064
ast::Expr *rootOp) {
30653065
// Check that root is an Operation.
30663066
ast::Type rootType = rootOp->getType();
3067-
if (!rootType.isa<ast::OperationType>())
3067+
if (!isa<ast::OperationType>(rootType))
30683068
return emitError(rootOp->getLoc(), "expected `Op` expression");
30693069

30703070
return ast::EraseStmt::create(ctx, loc, rootOp);
@@ -3075,7 +3075,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
30753075
MutableArrayRef<ast::Expr *> replValues) {
30763076
// Check that root is an Operation.
30773077
ast::Type rootType = rootOp->getType();
3078-
if (!rootType.isa<ast::OperationType>()) {
3078+
if (!isa<ast::OperationType>(rootType)) {
30793079
return emitError(
30803080
rootOp->getLoc(),
30813081
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
@@ -3088,7 +3088,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
30883088
ast::Type replType = replExpr->getType();
30893089

30903090
// Check that replExpr is an Operation, Value, or ValueRange.
3091-
if (replType.isa<ast::OperationType>()) {
3091+
if (isa<ast::OperationType>(replType)) {
30923092
if (shouldConvertOpToValues)
30933093
replExpr = convertOpToValue(replExpr);
30943094
continue;
@@ -3110,7 +3110,7 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
31103110
ast::CompoundStmt *rewriteBody) {
31113111
// Check that root is an Operation.
31123112
ast::Type rootType = rootOp->getType();
3113-
if (!rootType.isa<ast::OperationType>()) {
3113+
if (!isa<ast::OperationType>(rootType)) {
31143114
return emitError(
31153115
rootOp->getLoc(),
31163116
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
@@ -3125,9 +3125,9 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
31253125

31263126
LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
31273127
ast::Type parentType = parentExpr->getType();
3128-
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3128+
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
31293129
codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3130-
else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3130+
else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
31313131
codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
31323132
return failure();
31333133
}

0 commit comments

Comments
 (0)