Skip to content

[CIR] Upstream cir-canonicalize pass #131891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clang/include/clang/Basic/DiagnosticFrontendKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -386,4 +386,12 @@ def warn_hlsl_langstd_minimal :
Warning<"support for HLSL language version %0 is incomplete, "
"recommend using %1 instead">,
InGroup<HLSLDXCCompat>;

// ClangIR frontend errors
def err_cir_to_cir_transform_failed : Error<
"CIR-to-CIR transformation failed">, DefaultFatal;

def err_cir_verification_failed_pre_passes : Error<
"CIR module verification error before running CIR-to-CIR passes">,
DefaultFatal;
}
4 changes: 4 additions & 0 deletions clang/include/clang/CIR/CIRGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class CIRGenerator : public clang::ASTConsumer {
void Initialize(clang::ASTContext &astContext) override;
bool HandleTopLevelDecl(clang::DeclGroupRef group) override;
mlir::ModuleOp getModule() const;
mlir::MLIRContext &getMLIRContext() { return *mlirContext; };
const mlir::MLIRContext &getMLIRContext() const { return *mlirContext; };

bool verifyModule() const;
};

} // namespace cir
Expand Down
39 changes: 39 additions & 0 deletions clang/include/clang/CIR/CIRToCIRPasses.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares an interface for running CIR-to-CIR passes.
//
//===----------------------------------------------------------------------===//

#ifndef CLANG_CIR_CIRTOCIRPASSES_H
#define CLANG_CIR_CIRTOCIRPASSES_H

#include "mlir/Pass/Pass.h"

#include <memory>

namespace clang {
class ASTContext;
}

namespace mlir {
class MLIRContext;
class ModuleOp;
} // namespace mlir

namespace cir {

// Run set of cleanup/prepare/etc passes CIR <-> CIR.
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirCtx,
clang::ASTContext &astCtx,
bool enableVerifier);

} // namespace cir

#endif // CLANG_CIR_CIRTOCIRPASSES_H_
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class ASTContext;
}
namespace mlir {

std::unique_ptr<Pass> createCIRCanonicalizePass();
std::unique_ptr<Pass> createCIRFlattenCFGPass();

void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
Expand Down
18 changes: 18 additions & 0 deletions clang/include/clang/CIR/Dialect/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@

include "mlir/Pass/PassBase.td"

def CIRCanonicalize : Pass<"cir-canonicalize"> {
let summary = "Performs CIR canonicalization";
let description = [{
Perform canonicalizations on CIR and removes some redundant operations.

This pass performs basic cleanup and canonicalization transformations that
are not intended to affect CIR-to-source fidelity and high-level code
analysis passes. Example transformations performed in this pass include
empty scope cleanup, trivial `try` cleanup, redundant branch cleanup, etc.
Those more "heavyweight" transformations and those transformations that
could significantly affect CIR-to-source fidelity are performed in the
`cir-simplify` pass.
}];

let constructor = "mlir::createCIRCanonicalizePass()";
let dependentDialects = ["cir::CIRDialect"];
}

def CIRFlattenCFG : Pass<"cir-flatten-cfg"> {
let summary = "Produces flatten CFG";
let description = [{
Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ struct MissingFeatures {
static bool scalableVectors() { return false; }
static bool unsizedTypes() { return false; }
static bool vectorType() { return false; }

// Future CIR operations
static bool labelOp() { return false; }
static bool brCondOp() { return false; }
static bool switchOp() { return false; }
static bool tryOp() { return false; }
static bool unaryOp() { return false; }
static bool selectOp() { return false; }
static bool complexCreateOp() { return false; }
static bool complexRealOp() { return false; }
static bool complexImagOp() { return false; }
static bool callOp() { return false; }
};

} // namespace cir
Expand Down
14 changes: 12 additions & 2 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -2978,6 +2978,15 @@ def fapple_link_rtlib : Flag<["-"], "fapple-link-rtlib">, Group<f_Group>,
HelpText<"Force linking the clang builtins runtime library">;

/// ClangIR-specific options - BEGIN
def clangir_disable_passes : Flag<["-"], "clangir-disable-passes">,
Visibility<[ClangOption, CC1Option]>,
HelpText<"Disable CIR transformations pipeline">,
MarshallingInfoFlag<FrontendOpts<"ClangIRDisablePasses">>;
def clangir_disable_verifier : Flag<["-"], "clangir-disable-verifier">,
Visibility<[ClangOption, CC1Option]>,
HelpText<"ClangIR: Disable MLIR module verifier">,
MarshallingInfoFlag<FrontendOpts<"ClangIRDisableCIRVerifier">>;

defm clangir : BoolFOption<"clangir",
FrontendOpts<"UseClangIRPipeline">, DefaultFalse,
PosFlag<SetTrue, [], [ClangOption, CC1Option], "Use the ClangIR pipeline to compile">,
Expand Down Expand Up @@ -4822,8 +4831,9 @@ def : Joined<["-"], "mllvm=">,
Visibility<[ClangOption, CLOption, DXCOption, FlangOption]>, Alias<mllvm>,
HelpText<"Alias for -mllvm">, MetaVarName<"<arg>">;
def mmlir : Separate<["-"], "mmlir">,
Visibility<[ClangOption, CLOption, FC1Option, FlangOption]>,
HelpText<"Additional arguments to forward to MLIR's option processing">;
Visibility<[ClangOption, CC1Option, FC1Option, FlangOption]>,
HelpText<"Additional arguments to forward to MLIR's option processing">,
MarshallingInfoStringVector<FrontendOpts<"MLIRArgs">>;
def ffuchsia_api_level_EQ : Joined<["-"], "ffuchsia-api-level=">,
Group<m_Group>, Visibility<[ClangOption, CC1Option]>,
HelpText<"Set Fuchsia API level">,
Expand Down
15 changes: 14 additions & 1 deletion clang/include/clang/Frontend/FrontendOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,14 @@ class FrontendOptions {
LLVM_PREFERRED_TYPE(bool)
unsigned UseClangIRPipeline : 1;

/// Disable Clang IR specific (CIR) passes
LLVM_PREFERRED_TYPE(bool)
unsigned ClangIRDisablePasses : 1;

/// Disable Clang IR (CIR) verifier
LLVM_PREFERRED_TYPE(bool)
unsigned ClangIRDisableCIRVerifier : 1;

CodeCompleteOptions CodeCompleteOpts;

/// Specifies the output format of the AST.
Expand Down Expand Up @@ -488,6 +496,10 @@ class FrontendOptions {
/// should only be used for debugging and experimental features.
std::vector<std::string> LLVMArgs;

/// A list of arguments to forward to MLIR's option processing; this
/// should only be used for debugging and experimental features.
std::vector<std::string> MLIRArgs;

/// File name of the file that will provide record layouts
/// (in the format produced by -fdump-record-layouts).
std::string OverrideRecordLayoutsFile;
Expand Down Expand Up @@ -533,7 +545,8 @@ class FrontendOptions {
EmitExtensionSymbolGraphs(false),
EmitSymbolGraphSymbolLabelsForTesting(false),
EmitPrettySymbolGraphs(false), GenReducedBMI(false),
UseClangIRPipeline(false), TimeTraceGranularity(500),
UseClangIRPipeline(false), ClangIRDisablePasses(false),
ClangIRDisableCIRVerifier(false), TimeTraceGranularity(500),
TimeTraceVerbose(false) {}

/// getInputKindForExtension - Return the appropriate input kind for a file
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"

using namespace clang;
using namespace clang::CIRGen;
Expand Down Expand Up @@ -488,6 +489,13 @@ mlir::Type CIRGenModule::convertType(QualType type) {
return genTypes.convertType(type);
}

bool CIRGenModule::verifyModule() const {
// Verify the module after we have finished constructing it, this will
// check the structural properties of the IR and invoke any specific
// verifiers we have on the CIR operations.
return mlir::verify(theModule).succeeded();
}

DiagnosticBuilder CIRGenModule::errorNYI(SourceLocation loc,
llvm::StringRef feature) {
unsigned diagID = diags.getCustomDiagID(
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class CIRGenModule : public CIRGenTypeCache {

void emitTopLevelDecl(clang::Decl *decl);

bool verifyModule() const;

/// Return the address of the given function. If funcType is non-null, then
/// this function will use the specified type if it has to create it.
// TODO: this is a bit weird as `GetAddr` given we give back a FuncOp?
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ void CIRGenerator::Initialize(ASTContext &astContext) {
*mlirContext.get(), astContext, codeGenOpts, diags);
}

bool CIRGenerator::verifyModule() const { return cgm->verifyModule(); }

mlir::ModuleOp CIRGenerator::getModule() const { return cgm->getModule(); }

bool CIRGenerator::HandleTopLevelDecl(DeclGroupRef group) {
Expand Down
145 changes: 145 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements pass that canonicalizes CIR operations, eliminating
// redundant branches, empty scopes, and other unnecessary operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/MissingFeatures.h"

using namespace mlir;
using namespace cir;

namespace {

/// Removes branches between two blocks if it is the only branch.
///
/// From:
/// ^bb0:
/// cir.br ^bb1
/// ^bb1: // pred: ^bb0
/// cir.return
///
/// To:
/// ^bb0:
/// cir.return
struct RemoveRedundantBranches : public OpRewritePattern<BrOp> {
using OpRewritePattern<BrOp>::OpRewritePattern;

LogicalResult matchAndRewrite(BrOp op,
PatternRewriter &rewriter) const final {
Block *block = op.getOperation()->getBlock();
Block *dest = op.getDest();

assert(!cir::MissingFeatures::labelOp());

// Single edge between blocks: merge it.
if (block->getNumSuccessors() == 1 &&
dest->getSinglePredecessor() == block) {
rewriter.eraseOp(op);
rewriter.mergeBlocks(dest, block);
return success();
}

return failure();
}
};

struct RemoveEmptyScope
: public OpRewritePattern<ScopeOp>::SplitMatchAndRewrite {
using SplitMatchAndRewrite::SplitMatchAndRewrite;

LogicalResult match(ScopeOp op) const final {
// TODO: Remove this logic once CIR uses MLIR infrastructure to remove
// trivially dead operations
if (op.isEmpty())
return success();

Region &region = op.getScopeRegion();
if (region.getBlocks().front().getOperations().size() == 1)
return success(isa<YieldOp>(region.getBlocks().front().front()));

return failure();
}

void rewrite(ScopeOp op, PatternRewriter &rewriter) const final {
rewriter.eraseOp(op);
}
};

//===----------------------------------------------------------------------===//
// CIRCanonicalizePass
//===----------------------------------------------------------------------===//

struct CIRCanonicalizePass : public CIRCanonicalizeBase<CIRCanonicalizePass> {
using CIRCanonicalizeBase::CIRCanonicalizeBase;

// The same operation rewriting done here could have been performed
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and
// implementing the same from above in CIRDialects.cpp). However, it's
// currently too aggressive for static analysis purposes, since it might
// remove things where a diagnostic can be generated.
//
// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
// disable this behavior.
void runOnOperation() override;
};

void populateCIRCanonicalizePatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
RemoveRedundantBranches,
RemoveEmptyScope
>(patterns.getContext());
// clang-format on
}

void CIRCanonicalizePass::runOnOperation() {
// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateCIRCanonicalizePatterns(patterns);

// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
assert(!cir::MissingFeatures::brCondOp());
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
assert(!cir::MissingFeatures::unaryOp());
assert(!cir::MissingFeatures::selectOp());
assert(!cir::MissingFeatures::complexCreateOp());
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
// CastOp here is to perform a manual `fold` in
// applyOpPatternsGreedily
if (isa<BrOp, ScopeOp, CastOp>(op))
ops.push_back(op);
});

// Apply patterns.
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> mlir::createCIRCanonicalizePass() {
return std::make_unique<CIRCanonicalizePass>();
}
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_clang_library(MLIRCIRTransforms
CIRCanonicalize.cpp
FlattenCFG.cpp

DEPENDS
Expand Down
Loading