diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h index 5515ee7548b5a..aed2562e4d30d 100644 --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -597,7 +597,7 @@ class RangeExpr final : public Node::NodeBase, } /// Return the range result type of this expression. - RangeType getType() const { return Base::getType().cast(); } + RangeType getType() const { return mlir::cast(Base::getType()); } private: RangeExpr(SMRange loc, RangeType type, unsigned numElements) @@ -630,7 +630,7 @@ class TupleExpr final : public Node::NodeBase, } /// Return the tuple result type of this expression. - TupleType getType() const { return Base::getType().cast(); } + TupleType getType() const { return mlir::cast(Base::getType()); } private: TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {} diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 9a2493a59e019..2bea083ac2d78 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -757,7 +757,7 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp, ArrayRef innerDimsPos = unPackOp.getInnerDimsPos(); ArrayRef outerDimsPerm = unPackOp.getOuterDimsPerm(); - auto expandTy = expandOp.getType().dyn_cast(); + auto expandTy = dyn_cast(expandOp.getType()); if (!expandTy) return failure(); ArrayRef dstShape = expandTy.getShape(); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index ced7fdd0a90f0..b969d41d934d4 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2256,7 +2256,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef outputShape) { auto [staticOutputShape, dynamicOutputShape] = decomposeMixedValues(SmallVector(outputShape)); - build(builder, result, resultType.cast(), src, + build(builder, result, llvm::cast(resultType), src, getReassociationIndicesAttribute(builder, reassociation), dynamicOutputShape, staticOutputShape); } @@ -2266,7 +2266,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef reassociation) { SmallVector inputShape = getMixedSizes(builder, result.location, src); - MemRefType memrefResultTy = resultType.cast(); + MemRefType memrefResultTy = llvm::cast(resultType); FailureOr> outputShape = inferOutputShape( builder, result.location, memrefResultTy, reassociation, inputShape); // Failure of this assertion usually indicates presence of multiple @@ -2867,7 +2867,8 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { /// marked as dropped in `droppedDims`. static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims) { - assert(size_t(t1.getRank()) == droppedDims.size() && "incorrect number of bits"); + assert(size_t(t1.getRank()) == droppedDims.size() && + "incorrect number of bits"); assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() && "incorrect number of dropped dims"); int64_t t1Offset, t2Offset; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 7a5546bf13757..4c65045084dc5 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1663,7 +1663,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef outputShape) { auto [staticOutputShape, dynamicOutputShape] = decomposeMixedValues(SmallVector(outputShape)); - build(builder, result, resultType.cast(), src, + build(builder, result, cast(resultType), src, getReassociationIndicesAttribute(builder, reassociation), dynamicOutputShape, staticOutputShape); } @@ -1673,7 +1673,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef reassociation) { SmallVector inputShape = getMixedSizes(builder, result.location, src); - auto tensorResultTy = resultType.cast(); + auto tensorResultTy = cast(resultType); FailureOr> outputShape = inferOutputShape( builder, result.location, tensorResultTy, reassociation, inputShape); // Failure of this assertion usually indicates presence of multiple diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 99b0db14c1427..b7394ad4c4bd9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -220,7 +220,8 @@ static bool hasZeroDimension(ShapedType shapedType) { return false; } -template static LogicalResult verifyConvOp(T op) { +template +static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = llvm::dyn_cast(op.getInput().getType()); auto weightType = llvm::dyn_cast(op.getWeight().getType()); @@ -962,7 +963,7 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { return emitOpError() << "tensor has a dimension with size zero. Each " "dimension of a tensor must have size >= 1"; - if ((int64_t) getNewShape().size() != outputType.getRank()) + if ((int64_t)getNewShape().size() != outputType.getRank()) return emitOpError() << "new shape does not match result rank"; for (auto [newShapeDim, outputShapeDim] : @@ -1127,7 +1128,7 @@ LogicalResult TransposeOp::reifyResultShapes( return failure(); Value input = getInput1(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); SmallVector returnedDims(inputType.getRank()); for (auto dim : transposePerms) { diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp index fc4cb613dd22a..081f85d69a2f6 100644 --- a/mlir/lib/Tools/PDLL/AST/Types.cpp +++ b/mlir/lib/Tools/PDLL/AST/Types.cpp @@ -35,8 +35,8 @@ Type Type::refineWith(Type other) const { return *this; // Operation types are compatible if the operation names don't conflict. - if (auto opTy = dyn_cast()) { - auto otherOpTy = other.dyn_cast(); + if (auto opTy = mlir::dyn_cast(*this)) { + auto otherOpTy = mlir::dyn_cast(other); if (!otherOpTy) return nullptr; if (!otherOpTy.getName()) @@ -105,25 +105,26 @@ Type RangeType::getElementType() const { // TypeRangeType bool TypeRangeType::classof(Type type) { - RangeType range = type.dyn_cast(); - return range && range.getElementType().isa(); + RangeType range = mlir::dyn_cast(type); + return range && mlir::isa(range.getElementType()); } TypeRangeType TypeRangeType::get(Context &context) { - return RangeType::get(context, TypeType::get(context)).cast(); + return mlir::cast( + RangeType::get(context, TypeType::get(context))); } //===----------------------------------------------------------------------===// // ValueRangeType bool ValueRangeType::classof(Type type) { - RangeType range = type.dyn_cast(); - return range && range.getElementType().isa(); + RangeType range = mlir::dyn_cast(type); + return range && mlir::isa(range.getElementType()); } ValueRangeType ValueRangeType::get(Context &context) { - return RangeType::get(context, ValueType::get(context)) - .cast(); + return mlir::cast( + RangeType::get(context, ValueType::get(context))); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp index 16c3ccf0de269..964d94c9c0a46 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -337,13 +337,13 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, // Generate a value based on the type of the variable. ast::Type type = varDecl->getType(); Type mlirType = genType(type); - if (type.isa()) + if (isa(type)) return builder.create(loc, mlirType, getTypeConstraint()); - if (type.isa()) + if (isa(type)) return builder.create(loc, mlirType, /*type=*/TypeAttr()); - if (type.isa()) + if (isa(type)) return builder.create(loc, getTypeConstraint()); - if (ast::OperationType opType = type.dyn_cast()) { + if (ast::OperationType opType = dyn_cast(type)) { Value operands = builder.create( loc, pdl::RangeType::get(builder.getType()), /*type=*/Value()); @@ -354,12 +354,12 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, loc, opType.getName(), operands, std::nullopt, ValueRange(), results); } - if (ast::RangeType rangeTy = type.dyn_cast()) { + if (ast::RangeType rangeTy = dyn_cast(type)) { ast::Type eleTy = rangeTy.getElementType(); - if (eleTy.isa()) + if (isa(eleTy)) return builder.create(loc, mlirType, getTypeConstraint()); - if (eleTy.isa()) + if (isa(eleTy)) return builder.create(loc, mlirType, /*types=*/ArrayAttr()); } @@ -440,7 +440,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { ast::Type parentType = expr->getParentExpr()->getType(); // Handle operation based member access. - if (ast::OperationType opType = parentType.dyn_cast()) { + if (ast::OperationType opType = dyn_cast(parentType)) { if (isa(expr)) { Type mlirType = genType(expr->getType()); if (isa(mlirType)) @@ -480,7 +480,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { } // Handle tuple based member access. - if (auto tupleType = parentType.dyn_cast()) { + if (auto tupleType = dyn_cast(parentType)) { auto elementNames = tupleType.getElementNames(); // The index is either a numeric index, or a name. @@ -581,14 +581,14 @@ CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc, if (!cstBody) { ast::Type declResultType = decl->getResultType(); SmallVector resultTypes; - if (ast::TupleType tupleType = declResultType.dyn_cast()) { + if (ast::TupleType tupleType = dyn_cast(declResultType)) { for (ast::Type type : tupleType.getElementTypes()) resultTypes.push_back(genType(type)); } else { resultTypes.push_back(genType(declResultType)); } - PDLOpT pdlOp = builder.create( - loc, resultTypes, decl->getName().getName(), inputs); + PDLOpT pdlOp = builder.create(loc, resultTypes, + decl->getName().getName(), inputs); if (isNegated && std::is_same_v) cast(pdlOp).setIsNegated(true); return pdlOp->getResults(); diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 9f931f4fce001..45f9e2f899a77 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -623,7 +623,7 @@ LogicalResult Parser::convertExpressionTo( return diag; }; - if (auto exprOpType = exprType.dyn_cast()) + if (auto exprOpType = dyn_cast(exprType)) return convertOpExpressionTo(expr, exprOpType, type, emitConvertError); // FIXME: Decide how to allow/support converting a single result to multiple, @@ -638,7 +638,7 @@ LogicalResult Parser::convertExpressionTo( return success(); // Handle tuple types. - if (auto exprTupleType = exprType.dyn_cast()) + if (auto exprTupleType = dyn_cast(exprType)) return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError, noteAttachFn); @@ -650,7 +650,7 @@ LogicalResult Parser::convertOpExpressionTo( function_ref emitErrorFn) { // Two operation types are compatible if they have the same name, or if the // expected type is more general. - if (auto opType = type.dyn_cast()) { + if (auto opType = dyn_cast(type)) { if (opType.getName()) return emitErrorFn(); return success(); @@ -702,7 +702,7 @@ LogicalResult Parser::convertTupleExpressionTo( function_ref emitErrorFn, function_ref noteAttachFn) { // Handle conversions between tuples. - if (auto tupleType = type.dyn_cast()) { + if (auto tupleType = dyn_cast(type)) { if (tupleType.size() != exprType.size()) return emitErrorFn(); @@ -2568,7 +2568,7 @@ Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, } // Constraint types cannot be used when defining variables. - if (type.isa()) { + if (isa(type)) { return emitError( loc, llvm::formatv("unable to define variable of `{0}` type", type)); } @@ -2782,7 +2782,7 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, StringRef name, SMRange loc) { ast::Type parentType = parentExpr->getType(); - if (ast::OperationType opType = parentType.dyn_cast()) { + if (ast::OperationType opType = dyn_cast(parentType)) { if (name == ast::AllResultsMemberAccessExpr::getMemberName()) return valueRangeTy; @@ -2808,7 +2808,7 @@ FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, // operations. It returns a single value. return valueTy; } - } else if (auto tupleType = parentType.dyn_cast()) { + } else if (auto tupleType = dyn_cast(parentType)) { // Handle indexed results. unsigned index = 0; if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && @@ -2845,7 +2845,7 @@ FailureOr Parser::createOperationExpr( for (ast::NamedAttributeDecl *attr : attributes) { // Check for an attribute type, or a type awaiting resolution. ast::Type attrType = attr->getValue()->getType(); - if (!attrType.isa()) { + if (!isa(attrType)) { return emitError( attr->getValue()->getLoc(), llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); @@ -3024,7 +3024,7 @@ LogicalResult Parser::validateOperationOperandsOrResults( // ValueRange. This situations arises quite often with nested operation // expressions: `op(op)` if (singleTy == valueTy) { - if (valueExprType.isa()) { + if (isa(valueExprType)) { valueExpr = convertOpToValue(valueExpr); continue; } @@ -3048,7 +3048,7 @@ Parser::createTupleExpr(SMRange loc, ArrayRef elements, ArrayRef elementNames) { for (const ast::Expr *element : elements) { ast::Type eleTy = element->getType(); - if (eleTy.isa()) { + if (isa(eleTy)) { return emitError( element->getLoc(), llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); @@ -3064,7 +3064,7 @@ FailureOr Parser::createEraseStmt(SMRange loc, ast::Expr *rootOp) { // Check that root is an Operation. ast::Type rootType = rootOp->getType(); - if (!rootType.isa()) + if (!isa(rootType)) return emitError(rootOp->getLoc(), "expected `Op` expression"); return ast::EraseStmt::create(ctx, loc, rootOp); @@ -3075,7 +3075,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, MutableArrayRef replValues) { // Check that root is an Operation. ast::Type rootType = rootOp->getType(); - if (!rootType.isa()) { + if (!isa(rootType)) { return emitError( rootOp->getLoc(), llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); @@ -3088,7 +3088,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, ast::Type replType = replExpr->getType(); // Check that replExpr is an Operation, Value, or ValueRange. - if (replType.isa()) { + if (isa(replType)) { if (shouldConvertOpToValues) replExpr = convertOpToValue(replExpr); continue; @@ -3110,7 +3110,7 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, ast::CompoundStmt *rewriteBody) { // Check that root is an Operation. ast::Type rootType = rootOp->getType(); - if (!rootType.isa()) { + if (!isa(rootType)) { return emitError( rootOp->getLoc(), llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); @@ -3125,9 +3125,9 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { ast::Type parentType = parentExpr->getType(); - if (ast::OperationType opType = parentType.dyn_cast()) + if (ast::OperationType opType = dyn_cast(parentType)) codeCompleteContext->codeCompleteOperationMemberAccess(opType); - else if (ast::TupleType tupleType = parentType.dyn_cast()) + else if (ast::TupleType tupleType = dyn_cast(parentType)) codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); return failure(); } diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index d282ee8f61d8f..ae0961c62abad 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -137,7 +137,8 @@ struct PDLIndexSymbol { /// Return the location of the definition of this symbol. SMRange getDefLoc() const { - if (const ast::Decl *decl = llvm::dyn_cast_if_present(definition)) { + if (const ast::Decl *decl = + llvm::dyn_cast_if_present(definition)) { const ast::Name *declName = decl->getName(); return declName ? declName->getLoc() : decl->getLoc(); } @@ -466,7 +467,8 @@ PDLDocument::findHover(const lsp::URIForFile &uri, return std::nullopt; // Add hover for operation names. - if (const auto *op = llvm::dyn_cast_if_present(symbol->definition)) + if (const auto *op = + llvm::dyn_cast_if_present(symbol->definition)) return buildHoverForOpName(op, hoverRange); const auto *decl = symbol->definition.get(); return findHover(decl, hoverRange); @@ -587,7 +589,7 @@ lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( hoverOS << "***\n"; } ast::Type resultType = decl->getResultType(); - if (auto resultTupleTy = resultType.dyn_cast()) { + if (auto resultTupleTy = dyn_cast(resultType)) { if (!resultTupleTy.empty()) { hoverOS << "Results:\n"; for (auto it : llvm::zip(resultTupleTy.getElementNames(), @@ -795,13 +797,13 @@ class LSPCodeCompleteContext : public CodeCompleteContext { } if (allowInlineTypeConstraints) { /// Attr. - if (!currentType || currentType.isa()) + if (!currentType || isa(currentType)) addCoreConstraint("Attr", "mlir::Attribute", "Attr<$1>"); /// Value. - if (!currentType || currentType.isa()) + if (!currentType || isa(currentType)) addCoreConstraint("Value", "mlir::Value", "Value<$1>"); /// ValueRange. - if (!currentType || currentType.isa()) + if (!currentType || isa(currentType)) addCoreConstraint("ValueRange", "mlir::ValueRange", "ValueRange<$1>"); } @@ -1242,7 +1244,7 @@ void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr, const lsp::URIForFile &uri, std::vector &inlayHints) { // Check for ODS information. - ast::OperationType opType = expr->getType().dyn_cast(); + ast::OperationType opType = dyn_cast(expr->getType()); const auto *odsOp = opType ? opType.getODSOperation() : nullptr; auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {