|
| 1 | +//===-- LLVMInsertChainFolder.cpp -----------------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" |
| 10 | +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
| 11 | +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 12 | +#include "mlir/IR/Builders.h" |
| 13 | +#include "llvm/Support/Debug.h" |
| 14 | + |
| 15 | +#define DEBUG_TYPE "flang-insert-folder" |
| 16 | + |
| 17 | +#include <deque> |
| 18 | + |
| 19 | +namespace { |
| 20 | +// Helper class to construct the attribute elements of an aggregate value being |
| 21 | +// folded without creating a full mlir::Attribute representation for each step |
| 22 | +// of the insert value chain, which would both be expensive in terms of |
| 23 | +// compilation time and memory (since the intermediate Attribute would survive, |
| 24 | +// unused, inside the mlir context). |
| 25 | +class InsertChainBackwardFolder { |
| 26 | + // Type for the current value of an element of the aggregate value being |
| 27 | + // constructed by the insert chain. |
| 28 | + // At any point of the insert chain, the value of an element is either: |
| 29 | + // - nullptr: not yet known, the insert has not yet been seen. |
| 30 | + // - an mlir::Attribute: the element is fully defined. |
| 31 | + // - a nested InsertChainBackwardFolder: the element is itself an aggregate |
| 32 | + // and its sub-elements have been partially defined (insert with mutliple |
| 33 | + // indices have been seen). |
| 34 | + |
| 35 | + // The insertion folder assumes backward walk of the insert chain. Once an |
| 36 | + // element or sub-element has been defined, it is not overriden by new |
| 37 | + // insertions (last insert wins). |
| 38 | + using InFlightValue = |
| 39 | + llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>; |
| 40 | + |
| 41 | +public: |
| 42 | + InsertChainBackwardFolder( |
| 43 | + mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage) |
| 44 | + : values(getNumElements(type), mlir::Attribute{}), |
| 45 | + folderStorage{folderStorage}, type{type} {} |
| 46 | + |
| 47 | + /// Push |
| 48 | + bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at); |
| 49 | + |
| 50 | + mlir::Attribute finalize(mlir::Attribute defaultFieldValue); |
| 51 | + |
| 52 | +private: |
| 53 | + static int64_t getNumElements(mlir::Type type) { |
| 54 | + if (auto structTy = |
| 55 | + llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) |
| 56 | + return structTy.getBody().size(); |
| 57 | + if (auto arrayTy = |
| 58 | + llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) |
| 59 | + return arrayTy.getNumElements(); |
| 60 | + return 0; |
| 61 | + } |
| 62 | + |
| 63 | + static mlir::Type getSubElementType(mlir::Type type, int64_t field) { |
| 64 | + if (auto arrayTy = |
| 65 | + llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) |
| 66 | + return arrayTy.getElementType(); |
| 67 | + if (auto structTy = |
| 68 | + llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) |
| 69 | + return structTy.getBody()[field]; |
| 70 | + return nullptr; |
| 71 | + } |
| 72 | + |
| 73 | + // Current element value of the aggregate value being built. |
| 74 | + llvm::SmallVector<InFlightValue> values; |
| 75 | + // std::deque is used to allocate storage for nested list and guarantee the |
| 76 | + // stability of the InsertChainBackwardFolder* used as element value. |
| 77 | + std::deque<InsertChainBackwardFolder> *folderStorage; |
| 78 | + // Type of the aggregate value being built. |
| 79 | + mlir::Type type; |
| 80 | +}; |
| 81 | +} // namespace |
| 82 | + |
| 83 | +// Helper to fold the value being inserted by an llvm.insert_value. |
| 84 | +// This may call tryFoldingLLVMInsertChain if the value is an aggregate and |
| 85 | +// was itself constructed by a different insert chain. |
| 86 | +// Returns a nullptr Attribute if the value could not be folded. |
| 87 | +static mlir::Attribute getAttrIfConstant(mlir::Value val, |
| 88 | + mlir::OpBuilder &rewriter) { |
| 89 | + if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) |
| 90 | + return cst.getValue(); |
| 91 | + if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { |
| 92 | + llvm::FailureOr<mlir::Attribute> attr = |
| 93 | + fir::tryFoldingLLVMInsertChain(val, rewriter); |
| 94 | + if (succeeded(attr)) |
| 95 | + return *attr; |
| 96 | + return nullptr; |
| 97 | + } |
| 98 | + if (val.getDefiningOp<mlir::LLVM::ZeroOp>()) |
| 99 | + return mlir::LLVM::ZeroAttr::get(val.getContext()); |
| 100 | + if (val.getDefiningOp<mlir::LLVM::UndefOp>()) |
| 101 | + return mlir::LLVM::UndefAttr::get(val.getContext()); |
| 102 | + if (mlir::Operation *op = val.getDefiningOp()) { |
| 103 | + unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber(); |
| 104 | + llvm::SmallVector<mlir::Value> results; |
| 105 | + if (mlir::succeeded(rewriter.tryFold(op, results)) && |
| 106 | + results.size() > resNum) { |
| 107 | + if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>()) |
| 108 | + return cst.getValue(); |
| 109 | + } |
| 110 | + } |
| 111 | + if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>()) |
| 112 | + if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter)) |
| 113 | + if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr)) |
| 114 | + return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt()); |
| 115 | + LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val |
| 116 | + << "\n"); |
| 117 | + return nullptr; |
| 118 | +} |
| 119 | + |
| 120 | +mlir::Attribute |
| 121 | +InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) { |
| 122 | + llvm::SmallVector<mlir::Attribute> attrs = llvm::map_to_vector( |
| 123 | + values, [&](InFlightValue inFlight) -> mlir::Attribute { |
| 124 | + if (!inFlight) |
| 125 | + return defaultFieldValue; |
| 126 | + if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) |
| 127 | + return attr; |
| 128 | + return llvm::cast<InsertChainBackwardFolder *>(inFlight)->finalize( |
| 129 | + defaultFieldValue); |
| 130 | + }); |
| 131 | + return mlir::ArrayAttr::get(type.getContext(), attrs); |
| 132 | +} |
| 133 | + |
| 134 | +bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, |
| 135 | + llvm::ArrayRef<int64_t> at) { |
| 136 | + if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size())) |
| 137 | + return false; |
| 138 | + InFlightValue &inFlight = values[at[0]]; |
| 139 | + if (!inFlight) { |
| 140 | + if (at.size() == 1) { |
| 141 | + inFlight = val; |
| 142 | + return true; |
| 143 | + } |
| 144 | + // This is the first insert to a nested field. Create a |
| 145 | + // InsertChainBackwardFolder for the current element value. |
| 146 | + mlir::Type subType = getSubElementType(type, at[0]); |
| 147 | + if (!subType) |
| 148 | + return false; |
| 149 | + InsertChainBackwardFolder &inFlightList = |
| 150 | + folderStorage->emplace_back(subType, folderStorage); |
| 151 | + inFlight = &inFlightList; |
| 152 | + return inFlightList.pushValue(val, at.drop_front()); |
| 153 | + } |
| 154 | + // Keep last inserted value if already set. |
| 155 | + if (llvm::isa<mlir::Attribute>(inFlight)) |
| 156 | + return true; |
| 157 | + auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); |
| 158 | + if (at.size() == 1) { |
| 159 | + if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) { |
| 160 | + LLVM_DEBUG(llvm::dbgs() |
| 161 | + << "insert chain sub-element partially overwritten initial " |
| 162 | + "value is not zero or undef\n"); |
| 163 | + return false; |
| 164 | + } |
| 165 | + inFlight = inFlightList->finalize(val); |
| 166 | + return true; |
| 167 | + } |
| 168 | + return inFlightList->pushValue(val, at.drop_front()); |
| 169 | +} |
| 170 | + |
| 171 | +llvm::FailureOr<mlir::Attribute> |
| 172 | +fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) { |
| 173 | + if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) |
| 174 | + return cst.getValue(); |
| 175 | + if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { |
| 176 | + LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n"); |
| 177 | + if (auto structTy = |
| 178 | + llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) { |
| 179 | + mlir::LLVM::InsertValueOp currentInsert = insert; |
| 180 | + mlir::LLVM::InsertValueOp lastInsert; |
| 181 | + std::deque<InsertChainBackwardFolder> folderStorage; |
| 182 | + InsertChainBackwardFolder inFlightList(structTy, &folderStorage); |
| 183 | + while (currentInsert) { |
| 184 | + mlir::Attribute attr = |
| 185 | + getAttrIfConstant(currentInsert.getValue(), rewriter); |
| 186 | + if (!attr) |
| 187 | + return llvm::failure(); |
| 188 | + if (!inFlightList.pushValue(attr, currentInsert.getPosition())) |
| 189 | + return llvm::failure(); |
| 190 | + lastInsert = currentInsert; |
| 191 | + currentInsert = currentInsert.getContainer() |
| 192 | + .getDefiningOp<mlir::LLVM::InsertValueOp>(); |
| 193 | + } |
| 194 | + mlir::Attribute defaultVal; |
| 195 | + if (lastInsert) { |
| 196 | + if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>()) |
| 197 | + defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext()); |
| 198 | + else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>()) |
| 199 | + defaultVal = mlir::LLVM::UndefAttr::get(val.getContext()); |
| 200 | + } |
| 201 | + if (!defaultVal) { |
| 202 | + LLVM_DEBUG(llvm::dbgs() |
| 203 | + << "insert chain initial value is not Zero or Undef\n"); |
| 204 | + return llvm::failure(); |
| 205 | + } |
| 206 | + return inFlightList.finalize(defaultVal); |
| 207 | + } |
| 208 | + } |
| 209 | + return llvm::failure(); |
| 210 | +} |
0 commit comments