|
9 | 9 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
|
10 | 10 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
|
11 | 11 | #include "mlir/Dialect/Affine/LoopUtils.h"
|
| 12 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 13 | +#include "mlir/Dialect/Arith/Utils/Utils.h" |
12 | 14 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
13 | 15 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
14 | 16 | #include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
|
17 | 19 | #include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
18 | 20 | #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
19 | 21 | #include "mlir/Dialect/Transform/IR/TransformOps.h"
|
| 22 | +#include "mlir/Dialect/Utils/StaticValueUtils.h" |
20 | 23 | #include "mlir/Dialect/Vector/IR/VectorOps.h"
|
| 24 | +#include "mlir/IR/BuiltinAttributes.h" |
21 | 25 | #include "mlir/IR/Dominance.h"
|
| 26 | +#include "mlir/IR/OpDefinition.h" |
22 | 27 |
|
23 | 28 | using namespace mlir;
|
24 | 29 | using namespace mlir::affine;
|
@@ -47,6 +52,7 @@ void transform::ApplySCFStructuralConversionPatternsOp::
|
47 | 52 | //===----------------------------------------------------------------------===//
|
48 | 53 | // GetParentForOp
|
49 | 54 | //===----------------------------------------------------------------------===//
|
| 55 | + |
50 | 56 | DiagnosedSilenceableFailure
|
51 | 57 | transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
|
52 | 58 | transform::TransformResults &results,
|
@@ -76,6 +82,72 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
|
76 | 82 | return DiagnosedSilenceableFailure::success();
|
77 | 83 | }
|
78 | 84 |
|
| 85 | +//===----------------------------------------------------------------------===// |
| 86 | +// ForallToForOp |
| 87 | +//===----------------------------------------------------------------------===// |
| 88 | + |
| 89 | +DiagnosedSilenceableFailure |
| 90 | +transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, |
| 91 | + transform::TransformResults &results, |
| 92 | + transform::TransformState &state) { |
| 93 | + auto payload = state.getPayloadOps(getTarget()); |
| 94 | + if (!llvm::hasSingleElement(payload)) |
| 95 | + return emitSilenceableError() << "expected a single payload op"; |
| 96 | + |
| 97 | + auto target = dyn_cast<scf::ForallOp>(*payload.begin()); |
| 98 | + if (!target) { |
| 99 | + DiagnosedSilenceableFailure diag = |
| 100 | + emitSilenceableError() << "expected the payload to be scf.forall"; |
| 101 | + diag.attachNote((*payload.begin())->getLoc()) << "payload op"; |
| 102 | + return diag; |
| 103 | + } |
| 104 | + |
| 105 | + rewriter.setInsertionPoint(target); |
| 106 | + |
| 107 | + if (!target.getOutputs().empty()) { |
| 108 | + return emitSilenceableError() |
| 109 | + << "unsupported shared outputs (didn't bufferize?)"; |
| 110 | + } |
| 111 | + |
| 112 | + SmallVector<OpFoldResult> lbs = target.getMixedLowerBound(); |
| 113 | + SmallVector<OpFoldResult> ubs = target.getMixedUpperBound(); |
| 114 | + SmallVector<OpFoldResult> steps = target.getMixedStep(); |
| 115 | + |
| 116 | + if (getNumResults() != lbs.size()) { |
| 117 | + DiagnosedSilenceableFailure diag = |
| 118 | + emitSilenceableError() |
| 119 | + << "op expects as many results (" << getNumResults() |
| 120 | + << ") as payload has induction variables (" << lbs.size() << ")"; |
| 121 | + diag.attachNote(target.getLoc()) << "payload op"; |
| 122 | + return diag; |
| 123 | + } |
| 124 | + |
| 125 | + auto loc = target.getLoc(); |
| 126 | + SmallVector<Value> ivs; |
| 127 | + for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) { |
| 128 | + Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb); |
| 129 | + Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub); |
| 130 | + Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step); |
| 131 | + auto loop = rewriter.create<scf::ForOp>( |
| 132 | + loc, lbValue, ubValue, stepValue, ValueRange(), |
| 133 | + [](OpBuilder &, Location, Value, ValueRange) {}); |
| 134 | + ivs.push_back(loop.getInductionVar()); |
| 135 | + rewriter.setInsertionPointToStart(loop.getBody()); |
| 136 | + rewriter.create<scf::YieldOp>(loc); |
| 137 | + rewriter.setInsertionPointToStart(loop.getBody()); |
| 138 | + } |
| 139 | + rewriter.eraseOp(target.getBody()->getTerminator()); |
| 140 | + rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(), |
| 141 | + ivs); |
| 142 | + rewriter.eraseOp(target); |
| 143 | + |
| 144 | + for (auto &&[i, iv] : llvm::enumerate(ivs)) { |
| 145 | + results.set(cast<OpResult>(getTransformed()[i]), |
| 146 | + {iv.getParentBlock()->getParentOp()}); |
| 147 | + } |
| 148 | + return DiagnosedSilenceableFailure::success(); |
| 149 | +} |
| 150 | + |
79 | 151 | //===----------------------------------------------------------------------===//
|
80 | 152 | // LoopOutlineOp
|
81 | 153 | //===----------------------------------------------------------------------===//
|
|
0 commit comments