diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td
index bd36149a3866..053b69f751e0 100644
--- a/include/npcomp/Dialect/Torch/IR/TorchOps.td
+++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td
@@ -461,6 +461,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
 }

 def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
+    DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
     Terminator,
     HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
   let summary = "yield-like terminator for torch.prim.Loop";
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 2e47cb443239..4dc2801cf748 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -252,6 +252,17 @@ void PrimLoopOp::getSuccessorRegions(
   regions.emplace_back(getResults());
 }

+//===----------------------------------------------------------------------===//
+// PrimLoopConditionOp
+//===----------------------------------------------------------------------===//
+
+MutableOperandRange
+PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
+  // Pass all operands except the condition to the successor, which is the
+  // parent loop op.
+  return iterArgsMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // PrimIfOp
 //===----------------------------------------------------------------------===//
@@ -357,10 +368,11 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
 void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
   patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
-    // TODO: Extend RefineTypes for this case and delete this canonicalization,
-    // since we don't want control flow or calls to randomly block this fold
-    // (this canonicalization pattern makes the compiler brittle to control flow
-    // and calls).
+    // TODO: This pattern should be removed because type refinement does a
+    // better job of dealing with control flow. However, removing it would
+    // expose an issue with ReduceOpVariants: DerefineOp doesn't have value
+    // semantics, and if it is not removed eagerly by the canonicalizer, it
+    // prevents ReduceOpVariants from converting certain tensors to value
+    // semantics.
     bool allAllowRefinement =
         llvm::all_of(op.getResult().getUsers(), allowsTypeRefinement);
     if (!allAllowRefinement)
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 0c85164b0b99..063f6abf7b75 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -37,20 +37,38 @@ static Type joinElementTypes(Type lhs, Type rhs) {
   return Type();
 }

+static Type getTypeFromDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
+  // TODO: include c10/core/ScalarType.h to make this cleaner.
+  switch (dtypeInt) {
+  case 6:
+    return Float32Type::get(context);
+  default:
+    return Type();
+  }
+}
+
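For context on the magic number in `getTypeFromDTypeInteger`: the integer follows PyTorch's `c10::ScalarType` enumeration, in which 6 is `float32`. A minimal sketch of how the mapping could be extended once more dtypes are needed; the extra cases are an assumption based on the c10 enum ordering, not part of this patch:

    // Hypothetical extension of getTypeFromDTypeInteger, mirroring the
    // c10/core/ScalarType.h ordering: 0=uint8, 1=int8, 2=int16, 3=int32,
    // 4=int64, 5=float16, 6=float32, 7=float64, 11=bool.
    static Type getTypeFromDTypeIntegerSketch(MLIRContext *context,
                                              int64_t dtypeInt) {
      switch (dtypeInt) {
      case 4: // torch.int64
        return IntegerType::get(context, 64, IntegerType::Signed);
      case 6: // torch.float32
        return Float32Type::get(context);
      case 7: // torch.float64
        return Float64Type::get(context);
      default:
        return Type();
      }
    }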
 namespace {
 // Statically known information for a particular Value.
 //
-// This struct currently tracks only information relevant for tensor/array-like
-// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
-// type as long as it is in the default "no knowledge" state returned by
-// `getPessimisticValueState`. The important invariant is that we cannot
-// claim to know something about a value which is false.
-//
+// This struct currently tracks information relevant for tensor/array-like
+// shaped types, as well as whether an object is None or not, namely
+// !torch.optional. It is fine to associate a `ValueKnowledge` with a
+// non-shaped type or non-OptionalType as long as it is in the default
+// "no knowledge" state returned by `getPessimisticValueState`. The important
+// invariant is that we cannot claim to know something about a value which is
+// false.
 // This class could also be called "dataflow facts", "lattice value", etc.
 struct ValueKnowledge {
+  enum class OptionalKnowledge {
+    unKnown,
+    isNone,
+    notNone,
+  };
   ValueKnowledge() = delete;
-  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
-      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
+  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype,
+                 OptionalKnowledge optionalKnowledge)
+      : hasSizes(hasSizes), sizes(sizes), dtype(dtype),
+        optional(optionalKnowledge) {
     assert(sizes.size() == 0 || hasSizes);
   }

@@ -63,6 +81,11 @@ struct ValueKnowledge {
       result.sizes = tensorType.getSizes().vec();
     }
     result.dtype = tensorType.getOptionalDtype();
+    result.optional = OptionalKnowledge::notNone;
+  } else if (type.isa<Torch::NoneType>()) {
+    result.optional = OptionalKnowledge::isNone;
+  } else if (!type.isa<OptionalType>()) {
+    result.optional = OptionalKnowledge::notNone;
   }
   return result;
 }

@@ -70,7 +93,7 @@ struct ValueKnowledge {
   // Return a pessimistic/conservative value state without assuming any
   // knowledge about the IR.
   static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(false, {}, Type());
+    return ValueKnowledge(false, {}, Type(), OptionalKnowledge::unKnown);
   }
   // Return a pessimistic/conservative value state only using knowledge
   // already recorded in the IR.
@@ -79,8 +102,8 @@ struct ValueKnowledge {
   }

   bool operator==(const ValueKnowledge &rhs) const {
-    return std::make_tuple(hasSizes, sizes, dtype) ==
-           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
+    return std::make_tuple(hasSizes, sizes, dtype, optional) ==
+           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype, rhs.optional);
   }

   // Given two pieces of static knowledge, calculate conservatively the
@@ -92,6 +115,11 @@ struct ValueKnowledge {
     // consistent with lhs and rhs.
     ValueKnowledge result = getPessimisticValueState(nullptr);

+    // Propagate the optional-ness only when lhs and rhs agree on it;
+    // otherwise the result stays at the pessimistic state (unKnown).
+    if (lhs.optional == rhs.optional)
+      result.optional = lhs.optional;
+
     if (lhs.hasSizes && !rhs.hasSizes) {
       result.hasSizes = true;
       result.sizes = lhs.sizes;
@@ -129,6 +157,7 @@ struct ValueKnowledge {
   // This is equal to nullptr if we don't know that it is a specific concrete
   // type.
   Type dtype;
+  OptionalKnowledge optional;
 };

 // Forward intraprocedural dataflow for type information.
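Viewed as a lattice, `ValueKnowledge` is now a product of the shape/dtype facts and the new three-point `OptionalKnowledge` component (`unKnown` on top of `isNone`/`notNone`), and `join` combines the components independently. A standalone sketch of just the optional component's join, equivalent to the `lhs.optional == rhs.optional` check above (illustrative only, not part of the patch):

    enum class OptionalKnowledge { unKnown, isNone, notNone };

    // Join for the three-point lattice: two agreeing facts survive the merge;
    // any disagreement falls back to the top element, unKnown.
    static OptionalKnowledge joinOptionalKnowledge(OptionalKnowledge lhs,
                                                   OptionalKnowledge rhs) {
      return lhs == rhs ? lhs : OptionalKnowledge::unKnown;
    }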
@@ -142,184 +171,455 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   visitOperation(Operation *op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
-    if (isa<AtenTanhOp, AtenBatchNormOp, AtenReluOp>(op)) {
+    if (isa<AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenAddScalarOp,
+            AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, AtenFmodScalarOp,
+            AtenFloorDivideScalarOp, AtenEqScalarOp, AtenGeScalarOp,
+            AtenNeScalarOp, AtenBitwiseNotOp, AtenToDtypeOp, AtenExpOp,
+            AtenSinOp, AtenCosOp, DerefineOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
-    if (isa<AtenMmOp>(op)) {
-      auto &lhs = operands[0]->getValue();
-      auto &rhs = operands[1]->getValue();
-      auto knowledge =
-          ValueKnowledge::getPessimisticValueState(op->getContext());
-      knowledge.hasSizes = true;
-      // WARNING: We could be more precise here by calculating the output
-      // shape as "(lhs.shape[0], rhs.shape[1])". However, that is really
-      // tricky at this stage in the compiler because we don't really have
-      // many static guarantees about the input ranks because `aten` ops do
-      // dynamic error checking and safely abort the program. There is nothing
-      // preventing us from (correctly!) statically inferring the shapes of
-      // the operands to shapes that are guaranteed to cause an error at
-      // runtime.
-      //
-      // Example: Suppose a user program calls `aten.mm` with two rank-0
-      // operands. The program emits an error when invoked, but when running
-      // this pass, we will (correctly!) infer `lhs.hasSizes &&
-      // lhs.sizes.size() == 0 && rhs.hasSizes && rhs.sizes.size() == 0` --
-      // it's not safe to access `lhs.sizes[0]` / `rhs.sizes[1]`! So when
-      // writing this transfer function, it's not as simple as taking
-      // `lhs.sizes[0]` and `rhs.sizes[1]`, as both of those might read out of
-      // bounds of the array. It would require more complicated logic.
-      //
-      // Just knowing dtypes and ranks is sufficient at this stage
-      // in the compiler. The precise per-dimension size propagation is best
-      // done lower in the stack, such as at the linalg level, where we have
-      // more static guarantees and more structure.
-      knowledge.sizes.resize(2, kUnknownSize);
-      // TODO: Investigate promotion rules if element types mismatch.
-      // This is conservatively correct, assuming that if both element types
-      // are the same, then the result is of that same element type.
-      knowledge.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
-      return getLatticeElement(op->getResult(0)).join(knowledge);
-    } else if (isa<AtenLinearOp>(op)) {
-      // The output shape is the input shape with the last dimension changed
-      // to the weight's output dimension.
-      auto knowledge = operands[0]->getValue();
-      if (knowledge.hasSizes && knowledge.sizes.size() > 0)
-        knowledge.sizes[knowledge.sizes.size() - 1] = kUnknownSize;
-      // TODO: Handle case of bias being None gracefully. Requires a lattice
-      // that tracks "None" (torch.optional). See also
-      // DerefineOp::getCanonicalizationPatterns for more refinement that
-      // needs to be done in this pass.
-      knowledge.dtype = joinElementTypes(
-          knowledge.dtype, joinElementTypes(operands[1]->getValue().dtype,
-                                            operands[2]->getValue().dtype));
-      return getLatticeElement(op->getResult(0)).join(knowledge);
-    } else if (isa<AtenConv2dOp>(op)) {
-      auto knowledge =
-          ValueKnowledge::getPessimisticValueState(op->getContext());
-      knowledge.hasSizes = true;
-      knowledge.sizes.resize(4, kUnknownSize);
-      // Running some experiments in PyTorch, the bias doesn't seem to
-      // contribute to the final element type.
-      knowledge.dtype = joinElementTypes(operands[0]->getValue().dtype,
-                                         operands[1]->getValue().dtype);
-      return getLatticeElement(op->getResult(0)).join(knowledge);
-    } else if (isa<AtenMaxPool2dOp>(op)) {
-      auto knowledge =
-          ValueKnowledge::getPessimisticValueState(op->getContext());
-      knowledge.hasSizes = true;
-      knowledge.sizes.resize(4, kUnknownSize);
-      knowledge.dtype = operands[0]->getValue().dtype;
-      return getLatticeElement(op->getResult(0)).join(knowledge);
-    } else if (isa<AtenAdaptiveAvgPool2dOp>(op)) {
+
+    if (isa<AtenAnyOp>(op)) {
       auto input = operands[0]->getValue();
       auto knowledge =
           ValueKnowledge::getPessimisticValueState(op->getContext());
-      if (input.hasSizes) {
-        knowledge.hasSizes = true;
-        knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
-      }
-      knowledge.dtype = input.dtype;
+      knowledge.hasSizes = true;
+      knowledge.sizes.resize(1, 1);
+      knowledge.dtype = IntegerType::get(op->getContext(), 1);
       return getLatticeElement(op->getResult(0)).join(knowledge);
+    }
+
+    if (auto mm = llvm::dyn_cast<AtenMmOp>(op)) {
+      return visitAtenMmOp(mm, operands);
+    } else if (auto linear = llvm::dyn_cast<AtenLinearOp>(op)) {
+      return visitAtenLinearOp(linear, operands);
+    } else if (auto conv2d = llvm::dyn_cast<AtenConv2dOp>(op)) {
+      return visitAtenConv2dOp(conv2d, operands);
+    } else if (auto maxPool2d = llvm::dyn_cast<AtenMaxPool2dOp>(op)) {
+      return visitAtenMaxPool2dOp(maxPool2d, operands);
+    } else if (auto avgPool2d = llvm::dyn_cast<AtenAdaptiveAvgPool2dOp>(op)) {
+      return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
     } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
-               AtenDivTensorOp>(op)) {
-      // This is a general binary broadcasting shape transfer function.
-      // We currently don't track "size 1" in our lattice, but we might want
-      // to. We could make this more precise as well. But again, as with the
-      // other shape transfer functions, handling the statically-invalid case
-      // is tricky, so we defer that until we need it.
-      auto lhs = operands[0]->getValue();
-      auto rhs = operands[1]->getValue();
-      auto knowledge =
-          ValueKnowledge::getPessimisticValueState(op->getContext());
-      if (lhs.hasSizes && rhs.hasSizes) {
-        knowledge.hasSizes = true;
-        knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
-                               kUnknownSize);
-      }
-      knowledge.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
-      return getLatticeElement(op->getResult(0)).join(knowledge);
-    } else if (isa<AtenLerpTensorOp>(op)) {
-      // This is a general broadcasting shape transfer function.
-      // We currently don't track "size 1" in our lattice, but we might want
-      // to. We could make this more precise as well. But again, as with the
-      // other shape transfer functions, handling the statically-invalid case
-      // is tricky, so we defer that until we need it.
-      auto a = operands[0]->getValue();
-      auto b = operands[1]->getValue();
-      auto c = operands[1]->getValue();
-      auto knowledge =
-          ValueKnowledge::getPessimisticValueState(op->getContext());
-      if (a.hasSizes && b.hasSizes && c.hasSizes) {
-        knowledge.hasSizes = true;
-        knowledge.sizes.resize(
-            std::max(std::max(a.sizes.size(), b.sizes.size()), c.sizes.size()),
-            kUnknownSize);
-      }
-      knowledge.dtype =
-          joinElementTypes(joinElementTypes(a.dtype, b.dtype), c.dtype);
-      return getLatticeElement(op->getResult(0)).join(knowledge);
+               AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp>(op)) {
+      return visitBinaryBroadcastingOp(op, operands);
+    } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
+      return visitAtenLerpTensorOp(lerpTensor, operands);
     } else if (auto flatten = dyn_cast<AtenFlattenUsingIntsOp>(op)) {
-      int64_t startDim;
-      int64_t endDim;
-      auto operand = operands[0]->getValue();
-      auto knowledge =
-          ValueKnowledge::getPessimisticValueState(op->getContext());
-      knowledge.dtype = operand.dtype;
-      if (operand.hasSizes && operand.sizes.size() == 0) {
-        // Rank 0 is special and flattens to rank 1 with size 1.
-        knowledge.hasSizes = true;
-        knowledge.sizes.push_back(1);
-      } else if (operand.hasSizes &&
-                 matchPattern(flatten.start_dim(),
-                              m_TorchConstantInt(&startDim)) &&
-                 matchPattern(flatten.end_dim(), m_TorchConstantInt(&endDim))) {
-        int64_t inputRank = operand.sizes.size();
-        if (startDim < 0)
-          startDim += inputRank;
-        if (endDim < 0)
-          endDim += inputRank;
-        // Careful: dimension numbers might be out of bounds.
-        if (0 <= startDim && startDim <= (inputRank - 1) && 0 <= endDim &&
-            endDim <= (inputRank - 1) && startDim <= endDim) {
-          knowledge.hasSizes = true;
-          for (auto i = 0; i < startDim; i++)
-            knowledge.sizes.push_back(operand.sizes[i]);
-          knowledge.sizes.push_back(kUnknownSize);
-          for (auto i = endDim + 1; i < inputRank; i++)
-            knowledge.sizes.push_back(operand.sizes[i]);
-        }
-      }
-      return getLatticeElement(op->getResult(0)).join(knowledge);
+      return visitAtenFlattenUsingIntsOp(flatten, operands);
     } else if (auto unsqueeze = dyn_cast<AtenUnsqueezeOp>(op)) {
-      auto operand = operands[0]->getValue();
-      auto knowledge =
-          ValueKnowledge::getPessimisticValueState(op->getContext());
-      knowledge.dtype = operand.dtype;
-      int64_t dim;
-      if (operand.hasSizes &&
-          matchPattern(unsqueeze.dim(), m_TorchConstantInt(&dim))) {
-        int64_t inputRank = operand.sizes.size();
-        // Careful, it's easy to be off by one here for negative values.
-        // The dim value is allowed to be in the range
-        // `[-inputRank - 1, inputRank]`.
-        // And negative values have `inputRank + 1` added to them rather
-        // than the more typical `inputRank`.
-        if (dim < 0)
-          dim += inputRank + 1;
-        if (0 <= dim && dim <= inputRank) {
-          knowledge.hasSizes = true;
-          knowledge.sizes = operand.sizes;
-          knowledge.sizes.insert(knowledge.sizes.begin() + dim, 1);
-        }
-      }
-      return getLatticeElement(op->getResult(0)).join(knowledge);
+      return visitAtenUnsqueezeOp(unsqueeze, operands);
+    } else if (auto arange = dyn_cast<AtenArangeOp>(op)) {
+      return visitAtenArangeOp(arange);
+    } else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) {
+      return visitAtenArangeStartOp(arangeStart);
+    } else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
+      return visitCalculationAlongDimIntListOp(
+          sumDimIntList, sumDimIntList.dim(), sumDimIntList.keepdim(),
+          operands);
+    } else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
+      return visitCalculationAlongDimIntListOp(meanDim, meanDim.dim(),
+                                               meanDim.keepdim(), operands);
+    } else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
+      return visitAtenAnyDimOp(anyDim, operands);
+    } else if (auto view = dyn_cast<AtenViewOp>(op)) {
+      return visitAtenViewOp(view, operands);
+    } else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
+      return visitAtenTransposeIntOp(transposeInt, operands);
     }
+
     // Otherwise, this is an unknown operation. Just mark all results as having
     // reached a pessimistic fixpoint.
     return markAllPessimisticFixpoint(op->getResults());
   }
+
+private:
+  ChangeResult
+  visitAtenMmOp(AtenMmOp op,
+                ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenLinearOp(AtenLinearOp op,
+                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenConv2dOp(AtenConv2dOp op,
+                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenMaxPool2dOp(AtenMaxPool2dOp op,
+                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitAtenAdaptiveAvgPool2dOp(
+      AtenAdaptiveAvgPool2dOp op,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitBinaryBroadcastingOp(
+      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenLerpTensorOp(AtenLerpTensorOp op,
+                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitAtenFlattenUsingIntsOp(
+      AtenFlattenUsingIntsOp op,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenUnsqueezeOp(AtenUnsqueezeOp op,
+                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+
+  ChangeResult visitAtenArangeLikeOpHelper(Operation *op,
+                                           llvm::Optional<Value> start,
+                                           Value end, Value dtype);
+  ChangeResult visitAtenArangeStartOp(AtenArangeStartOp op);
+  ChangeResult visitAtenArangeOp(AtenArangeOp op);
+  ChangeResult visitCalculationAlongDimIntListOp(
+      Operation *op, Value dim, Value keepdim,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenAnyDimOp(AtenAnyDimOp op,
+                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenViewOp(AtenViewOp op,
+                  ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenTransposeIntOp(AtenTransposeIntOp op,
+                          ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace

+static int64_t toPositiveDim(int64_t dim, int64_t inputRank) {
+  return dim >= 0 ? dim : dim + inputRank;
+}
+
+static bool isValidDim(int64_t dim, int64_t inputRank) {
+  return dim >= 0 && dim < inputRank;
+}
+
+ChangeResult TypeAnalyzer::visitAtenMmOp(
+    AtenMmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto &lhs = operands[0]->getValue();
+  auto &rhs = operands[1]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  knowledge.hasSizes = true;
+  // WARNING: We could be more precise here by calculating the output
+  // shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
+  // at this stage in the compiler because we don't really have many static
+  // guarantees about the input ranks because `aten` ops do dynamic error
+  // checking and safely abort the program. There is nothing preventing us
+  // from (correctly!) statically inferring the shapes of the operands to
+  // shapes that are guaranteed to cause an error at runtime.
+  //
+  // Example: Suppose a user program calls `aten.mm` with two rank-0
+  // operands. The program emits an error when invoked, but when running
+  // this pass, we will (correctly!) infer `lhs.hasSizes && lhs.sizes.size()
+  // == 0 && rhs.hasSizes && rhs.sizes.size() == 0` -- it's not safe to
+  // access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
+  // function, it's not as simple as taking `lhs.sizes[0]` and
+  // `rhs.sizes[1]`, as both of those might read out of bounds of the array.
+  // It would require more complicated logic.
+  //
+  // Just knowing dtypes and ranks is sufficient at this stage
+  // in the compiler. The precise per-dimension size propagation is best
+  // done lower in the stack, such as at the linalg level, where we have
+  // more static guarantees and more structure.
+  knowledge.sizes.resize(2, kUnknownSize);
+  // TODO: Investigate promotion rules if element types mismatch.
+  // This is conservatively correct, assuming that if both element types are
+  // the same, then the result is of that same element type.
+  knowledge.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenLinearOp(
+    AtenLinearOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  // The output shape is the input shape with the last dimension changed
+  // to the weight's output dimension.
+  auto knowledge = operands[0]->getValue();
+  if (knowledge.hasSizes && knowledge.sizes.size() > 0)
+    knowledge.sizes[knowledge.sizes.size() - 1] = kUnknownSize;
+  // TODO: Handle case of bias being None gracefully. Requires a lattice
+  // that tracks "None" (torch.optional). See also
+  // DerefineOp::getCanonicalizationPatterns for more refinement that needs
+  // to be done in this pass.
+  knowledge.dtype = joinElementTypes(
+      knowledge.dtype, joinElementTypes(operands[1]->getValue().dtype,
+                                        operands[2]->getValue().dtype));
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenConv2dOp(
+    AtenConv2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  knowledge.hasSizes = true;
+  knowledge.sizes.resize(4, kUnknownSize);
+  // Running some experiments in PyTorch, the bias doesn't seem to
+  // contribute to the final element type.
+  knowledge.dtype = joinElementTypes(operands[0]->getValue().dtype,
+                                     operands[1]->getValue().dtype);
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
+    AtenMaxPool2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  knowledge.hasSizes = true;
+  knowledge.sizes.resize(4, kUnknownSize);
+  knowledge.dtype = operands[0]->getValue().dtype;
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenAdaptiveAvgPool2dOp(
+    AtenAdaptiveAvgPool2dOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  if (input.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
+  }
+  knowledge.dtype = input.dtype;
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
+    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  // This is a general binary broadcasting shape transfer function.
+  // We currently don't track "size 1" in our lattice, but we might want to.
+  // We could make this more precise as well. But again, as with the other
+  // shape transfer functions, handling the statically-invalid case is
+  // tricky, so we defer that until we need it.
+  auto lhs = operands[0]->getValue();
+  auto rhs = operands[1]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  if (lhs.hasSizes && rhs.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
+                           kUnknownSize);
+  }
+  knowledge.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenLerpTensorOp(
+    AtenLerpTensorOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  // This is a general broadcasting shape transfer function.
+  // We currently don't track "size 1" in our lattice, but we might want to.
+  // We could make this more precise as well. But again, as with the other
+  // shape transfer functions, handling the statically-invalid case is
+  // tricky, so we defer that until we need it.
+  auto a = operands[0]->getValue();
+  auto b = operands[1]->getValue();
+  auto c = operands[2]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  if (a.hasSizes && b.hasSizes && c.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(
+        std::max(std::max(a.sizes.size(), b.sizes.size()), c.sizes.size()),
+        kUnknownSize);
+  }
+  knowledge.dtype =
+      joinElementTypes(joinElementTypes(a.dtype, b.dtype), c.dtype);
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenFlattenUsingIntsOp(
+    AtenFlattenUsingIntsOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  int64_t startDim;
+  int64_t endDim;
+  auto operand = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
+  knowledge.dtype = operand.dtype;
+  if (operand.hasSizes && operand.sizes.size() == 0) {
+    // Rank 0 is special and flattens to rank 1 with size 1.
+    knowledge.hasSizes = true;
+    knowledge.sizes.push_back(1);
+  } else if (operand.hasSizes &&
+             matchPattern(op.start_dim(), m_TorchConstantInt(&startDim)) &&
+             matchPattern(op.end_dim(), m_TorchConstantInt(&endDim))) {
+    int64_t inputRank = operand.sizes.size();
+    if (startDim < 0)
+      startDim += inputRank;
+    if (endDim < 0)
+      endDim += inputRank;
+    // Careful: dimension numbers might be out of bounds.
+    if (0 <= startDim && startDim <= (inputRank - 1) && 0 <= endDim &&
+        endDim <= (inputRank - 1) && startDim <= endDim) {
+      knowledge.hasSizes = true;
+      for (auto i = 0; i < startDim; i++)
+        knowledge.sizes.push_back(operand.sizes[i]);
+      knowledge.sizes.push_back(kUnknownSize);
+      for (auto i = endDim + 1; i < inputRank; i++)
+        knowledge.sizes.push_back(operand.sizes[i]);
+    }
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
+    AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto operand = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
+  knowledge.dtype = operand.dtype;
+  int64_t dim;
+  if (operand.hasSizes && matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
+    int64_t inputRank = operand.sizes.size();
+    // Careful, it's easy to be off by one here for negative values.
+    // The dim value is allowed to be in the range
+    // `[-inputRank - 1, inputRank]`.
+    // And negative values have `inputRank + 1` added to them rather
+    // than the more typical `inputRank`.
+    if (dim < 0)
+      dim += inputRank + 1;
+    if (0 <= dim && dim <= inputRank) {
+      knowledge.hasSizes = true;
+      knowledge.sizes = operand.sizes;
+      knowledge.sizes.insert(knowledge.sizes.begin() + dim, 1);
+    }
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
+// Arange-like ops return a 1-D tensor of size ceil(end - start).
+ChangeResult TypeAnalyzer::visitAtenArangeLikeOpHelper(
+    Operation *op, llvm::Optional<Value> start, Value end, Value dtype) {
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  knowledge.sizes.resize(1, kUnknownSize);
+  knowledge.hasSizes = true;
+  int64_t dtypeInt;
+  if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) {
+    knowledge.dtype = getTypeFromDTypeInteger(op->getContext(), dtypeInt);
+  } else if (dtype.getType().isa<Torch::NoneType>()) {
+    // From torch/_torch_docs.py:
+    // If `dtype` is not given, infer the data type from the other input
+    // arguments. If any of `start`, `end`, or `step` are floating-point, the
+    // `dtype` is inferred to be the default dtype, see
+    // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
+    // be `torch.int64`.
+    if ((start.hasValue() && (*start).getType().isa<Torch::FloatType>()) ||
+        end.getType().isa<Torch::FloatType>()) {
+      // TODO: Should get the dtype from torch.get_default_dtype().
+      // For now, use float32 which is the initial default dtype.
+      knowledge.dtype = Float32Type::get(op->getContext());
+    } else
+      knowledge.dtype =
+          IntegerType::get(op->getContext(), 64, IntegerType::Signed);
+  }
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenArangeStartOp(AtenArangeStartOp op) {
+  return visitAtenArangeLikeOpHelper(op, op.start(), op.end(), op.dtype());
+}
+
+ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
+  return visitAtenArangeLikeOpHelper(op, {}, op.end(), op.dtype());
+}
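To make the dtype rule above concrete: with `dtype` left as `None`, an all-integer range yields `torch.int64`, while any floating-point endpoint yields the default dtype (hard-coded to `f32` here pending `torch.get_default_dtype` support). A minimal standalone mirror of that decision, with a hypothetical helper name that is not part of the patch:

    // Sketch of the arange dtype inference used by visitAtenArangeLikeOpHelper.
    static Type inferArangeDType(MLIRContext *context, bool anyEndpointIsFloat) {
      if (anyEndpointIsFloat)
        return Float32Type::get(context); // stand-in for torch.get_default_dtype()
      return IntegerType::get(context, 64, IntegerType::Signed); // torch.int64
    }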
+
+// These ops do calculation along the dims given by the integer list and
+// reduce each dim to size one. If \p keepdim is false, the dims are squeezed.
+ChangeResult TypeAnalyzer::visitCalculationAlongDimIntListOp(
+    Operation *op, Value dim, Value keepdim,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  knowledge.dtype = input.dtype;
+  llvm::SmallVector<int64_t> dimList;
+  bool keepdimBool;
+  if (matchPattern(keepdim, m_TorchConstantBool(&keepdimBool))) {
+    knowledge.hasSizes = true;
+    int64_t inputRank = input.sizes.size();
+    // TODO: This is not safe. Need to check the list users and use aliasing
+    // info to safely detect that the list is not modified.
+    if (matchPattern(dim, m_TorchConstantIntList(dimList))) {
+      llvm::for_each(
+          dimList, [&](int64_t &dim) { dim = toPositiveDim(dim, inputRank); });
+      DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
+      for (auto en : llvm::enumerate(input.sizes)) {
+        if (dimSet.contains(en.index())) {
+          if (keepdimBool)
+            knowledge.sizes.push_back(1);
+        } else {
+          knowledge.sizes.push_back(en.value());
+        }
+      }
+    } else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
+      auto sizes = listConstruct.elements();
+      knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - sizes.size(),
+                             kUnknownSize);
+    }
+  }
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
+    AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  knowledge.dtype = input.dtype;
+  int64_t dim;
+  bool keepdimBool;
+  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepdimBool))) {
+    int64_t inputRank = input.sizes.size();
+    knowledge.hasSizes = true;
+    if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
+      knowledge.sizes = input.sizes;
+      dim = toPositiveDim(dim, inputRank);
+      if (isValidDim(dim, inputRank)) {
+        if (keepdimBool)
+          knowledge.sizes[dim] = 1;
+        else
+          knowledge.sizes.erase(knowledge.sizes.begin() + dim);
+      }
+    } else {
+      knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - 1,
+                             kUnknownSize);
+    }
+  }
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
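As a concrete instance of the keepdim rule implemented above: reducing a `[2,3,?]` tensor over dims `{0, -1}` gives `[3]` with `keepdim=false` and `[1,3,1]` with `keepdim=true`, which is exactly what the `aten.sum.dim_IntList` tests later in this patch check for. A small illustrative helper, not part of the patch, assuming dims are already positive and valid:

    // Sketch: compute the result shape of a dim-list reduction.
    static std::vector<int64_t> reducedShape(ArrayRef<int64_t> sizes,
                                             const DenseSet<int64_t> &dims,
                                             bool keepdim) {
      std::vector<int64_t> result;
      for (auto en : llvm::enumerate(sizes)) {
        if (dims.contains(en.index())) {
          if (keepdim)
            result.push_back(1); // reduced dims collapse to size 1
        } else {
          result.push_back(en.value()); // other dims pass through unchanged
        }
      }
      return result;
    }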
+
+ChangeResult TypeAnalyzer::visitAtenViewOp(
+    AtenViewOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
+  knowledge.dtype = input.dtype;
+
+  // TODO: This is not safe. Need to check the list users and use aliasing
+  // info to safely detect that the list is not modified.
+  if (auto listConstruct = op.size().getDefiningOp<PrimListConstructOp>()) {
+    knowledge.hasSizes = true;
+    auto sizes = listConstruct.elements();
+    int64_t size;
+    for (auto sizeValue : sizes) {
+      if (matchPattern(sizeValue, m_TorchConstantInt(&size)))
+        knowledge.sizes.push_back(size);
+      else
+        knowledge.sizes.push_back(kUnknownSize);
+    }
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
+ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
+    AtenTransposeIntOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op.getContext());
+  knowledge.dtype = input.dtype;
+  knowledge.hasSizes = input.hasSizes;
+  auto dim0 = op.dim0();
+  auto dim1 = op.dim1();
+  int64_t dim0Int;
+  int64_t dim1Int;
+  if (matchPattern(dim0, m_TorchConstantInt(&dim0Int)) &&
+      matchPattern(dim1, m_TorchConstantInt(&dim1Int))) {
+    knowledge.sizes = input.sizes;
+    int64_t inputRank = input.sizes.size();
+    dim0Int = toPositiveDim(dim0Int, inputRank);
+    dim1Int = toPositiveDim(dim1Int, inputRank);
+    if (isValidDim(dim0Int, inputRank) && isValidDim(dim1Int, inputRank)) {
+      std::swap(knowledge.sizes[dim0Int], knowledge.sizes[dim1Int]);
+      return getLatticeElement(op.getResult()).join(knowledge);
+    }
+  }
+
+  knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------
@@ -327,16 +627,36 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {

 // Get the most refined type compatible with ValueKnowledge, or null if that
 // is not possible.
 static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
+  auto getRefinedTensorType = [](BaseTensorType tensorType,
+                                 ValueKnowledge const &knowledge) {
+    return tensorType.getWithSizesAndDtype(
+        knowledge.hasSizes ? llvm::makeArrayRef(knowledge.sizes)
+                           : Optional<ArrayRef<int64_t>>(),
+        knowledge.dtype);
+  };
+
   if (auto tensorType = v.getType().dyn_cast<BaseTensorType>()) {
     LatticeElement<ValueKnowledge> *latticeElement =
         analyzer.lookupLatticeElement(v);
     if (!latticeElement)
       return nullptr;
     const ValueKnowledge &knowledge = latticeElement->getValue();
-    return tensorType.getWithSizesAndDtype(
-        knowledge.hasSizes ? llvm::makeArrayRef(knowledge.sizes)
-                           : Optional<ArrayRef<int64_t>>(),
-        knowledge.dtype);
+    return getRefinedTensorType(tensorType, knowledge);
+  } else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
+    LatticeElement<ValueKnowledge> *latticeElement =
+        analyzer.lookupLatticeElement(v);
+    if (!latticeElement)
+      return nullptr;
+    const ValueKnowledge &knowledge = latticeElement->getValue();
+    if (knowledge.optional == ValueKnowledge::OptionalKnowledge::isNone)
+      return Torch::NoneType::get(v.getContext());
+    else if (knowledge.optional ==
+             ValueKnowledge::OptionalKnowledge::notNone) {
+      auto containedType = optionalType.getContainedType();
+      if (auto tensorType = containedType.dyn_cast<BaseTensorType>())
+        return getRefinedTensorType(tensorType, knowledge);
+      else
+        return containedType;
+    }
   }
   return nullptr;
 }
@@ -351,74 +671,123 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
 // latter case, since their operand and result types must have the same shape
 // and dtype -- we know that our transfer functions and updating logic will do
 // the right thing for those ops.
-static bool allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(Operation *op) {
+static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
   return allowsTypeRefinement(op) ||
          isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
 }

+// Some operations have extra verification logic regarding the relationship
+// between the input types and output types. Adding more refined type info to
+// an operand might change a valid instruction into an invalid one.
+static bool operationIsValidWithRefinedType(OpOperand *use, Type newType) {
+  Operation *op = use->getOwner();
+  if (auto uncheckedCast = llvm::dyn_cast<PrimUncheckedCastOp>(op))
+    return uncheckedCast.areCastCompatible(newType, uncheckedCast.getType());
+  return true;
+}
+
+static bool isSafeToRefineOperandInPlace(OpOperand *use, Type newOperandType) {
+  Operation *op = use->getOwner();
+  if (!allowsTypeRefinementOrIsSafeToRefine(op))
+    return false;
+  return operationIsValidWithRefinedType(use, newOperandType);
+}
+
 void optimize(FuncOp func, TypeAnalyzer &analyzer) {
   func.walk([&](Operation *op) {
-    for (Value v : op->getResults()) {
-      Type refinedType = getMostRefinedStaticType(v, analyzer);
-      Type originalType = v.getType();
-      // No type? Nothing to do.
-      if (!refinedType)
-        continue;
-      // Type is same as existing one? Nothing to do.
-      if (refinedType == originalType)
-        continue;
-      // If we have an op that allows adding/removing static information from
-      // this type, then we can rewrite. We make sure to always embed the
-      // static information in the IR, and insert the minimal number of casts
-      // needed to do so.
-      // TODO: For some types, we will need 2 ops here: one to add static
-      // information, and the other to remove static information.
-      // (for example, torch.unchecked_cast / torch.derefine for
-      // torch.optional types).
-      std::function<Value(Location, Type, Value)> createStaticInfoCast;
-      OpBuilder b(op->getBlock(), std::next(op->getIterator()));
-      if (originalType.isa<BaseTensorType>()) {
-        createStaticInfoCast = [&](Location loc, Type newType,
-                                   Value v) -> Value {
-          return b.create<TensorStaticInfoCastOp>(loc, newType, v);
-        };
-      }
-      if (createStaticInfoCast) {
-        // Save off the original uses to avoid iterator invalidation issues
-        // or other unexpected behavior since we are creating new ops here
-        // that use the value.
-        auto originalUses = llvm::to_vector<6>(
-            llvm::map_range(v.getUses(), [](OpOperand &use) { return &use; }));
-        OpBuilder b(op->getBlock(), std::next(op->getIterator()));
-        Value newTypedValue;
-        // Always make sure that the new static information is reflected in
-        // the IR, either by updating the type in place, or inserting a
-        // static info cast.
-        if (allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(op)) {
-          newTypedValue = v;
-          v.setType(refinedType);
-        } else {
-          newTypedValue = createStaticInfoCast(op->getLoc(), refinedType, v);
+    auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) {
+      for (Value v : values) {
+        Type refinedType = getMostRefinedStaticType(v, analyzer);
+        Type originalType = v.getType();
+        // No type? Nothing to do.
+        if (!refinedType)
+          continue;
+        // Type is same as existing one? Nothing to do.
+        if (refinedType == originalType)
+          continue;
+        // If we have an op that allows adding/removing static information
+        // from this type, then we can rewrite. We make sure to always embed
+        // the static information in the IR, and insert the minimal number of
+        // casts needed to do so.
+        using CreateStaticInfoCastFn =
+            std::function<Value(Location, Type, Value)>;
+        CreateStaticInfoCastFn createStaticInfoDownCast;
+        CreateStaticInfoCastFn createStaticInfoUpCast;
+        if (originalType.isa<BaseTensorType>()) {
+          createStaticInfoDownCast = [&](Location loc, Type newType,
+                                         Value v) -> Value {
+            return b.create<TensorStaticInfoCastOp>(loc, newType, v);
+          };
+          createStaticInfoUpCast = createStaticInfoDownCast;
+        } else if (originalType.isa<OptionalType>()) {
+          createStaticInfoDownCast = [&](Location loc, Type newType,
+                                         Value v) -> Value {
+            return b.create<PrimUncheckedCastOp>(loc, newType, v);
+          };
+          createStaticInfoUpCast = [&](Location loc, Type newType,
+                                       Value v) -> Value {
+            return b.create<DerefineOp>(loc, newType, v);
+          };
         }
-        Value oldTypedValue;
-        for (OpOperand *use : originalUses) {
-          // If the use can be updated to the new type directly, do it!
-          if (allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(
-                  use->getOwner())) {
-            use->set(newTypedValue);
-            continue;
+        if (createStaticInfoUpCast) {
+          assert(createStaticInfoDownCast &&
+                 "createStaticInfoDownCast and createStaticInfoUpCast must be "
+                 "defined in pairs");
+          // Save off the original uses to avoid iterator invalidation issues
+          // or other unexpected behavior since we are creating new ops here
+          // that use the value.
+          auto originalUses = llvm::to_vector<6>(llvm::map_range(
+              v.getUses(), [](OpOperand &use) { return &use; }));
+          OpBuilder b(op->getBlock(), std::next(op->getIterator()));
+          Value newTypedValue;
+          // Always make sure that the new static information is reflected in
+          // the IR, either by updating the type in place, or inserting a
+          // static info cast.
+          if (allowsTypeRefinementOrIsSafeToRefine(op)) {
+            newTypedValue = v;
+            v.setType(refinedType);
+          } else {
+            if (auto derefineOp = llvm::dyn_cast<DerefineOp>(op)) {
+              newTypedValue = derefineOp.operand();
+            } else {
+              newTypedValue =
+                  createStaticInfoDownCast(op->getLoc(), refinedType, v);
+            }
           }
-          // If needed, create a value of the original type to appease users
-          // that cannot accept the new type.
-          if (!oldTypedValue) {
-            oldTypedValue =
-                createStaticInfoCast(op->getLoc(), originalType, newTypedValue);
+
+          Value oldTypedValue;
+          for (OpOperand *use : originalUses) {
+            // If the use can be updated to the new type directly, do it!
+            if (isSafeToRefineOperandInPlace(use, refinedType)) {
+              use->set(newTypedValue);
+              continue;
+            }
+            // If needed, create a value of the original type to appease
+            // users that cannot accept the new type.
+            if (!oldTypedValue) {
+              if (auto derefineOp = llvm::dyn_cast<DerefineOp>(op)) {
+                oldTypedValue = derefineOp.result();
+              } else {
+                oldTypedValue = createStaticInfoUpCast(
+                    op->getLoc(), originalType, newTypedValue);
+              }
+            }
+            use->set(oldTypedValue);
           }
-          use->set(oldTypedValue);
         }
       }
+    };
+
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+      for (auto &region : branch->getRegions()) {
+        OpBuilder b(region);
+        convertValuesToMostRefinedType(region.front().getArguments(), b);
+      }
     }
+    OpBuilder b(op->getBlock(), std::next(op->getIterator()));
+    convertValuesToMostRefinedType(op->getResults(), b);
   });
 }
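The structural change in `optimize()` above is that static-info casts now come in direction-aware pairs: for tensors both directions are `torch.tensor_static_info_cast`, while for optionals refining uses `torch.prim.unchecked_cast` and erasing uses `torch.derefine`. A reduced sketch of that pairing, using the op classes from the patch but otherwise illustrative rather than the patch's actual code:

    // Sketch: pick the cast-building callbacks for a type, as optimize() does.
    // Returns {downCast (adds info), upCast (removes info)}; null if uncastable.
    using CastFn = std::function<Value(OpBuilder &, Location, Type, Value)>;

    static std::pair<CastFn, CastFn> getStaticInfoCastPair(Type type) {
      if (type.isa<BaseTensorType>()) {
        CastFn cast = [](OpBuilder &b, Location loc, Type t, Value v) -> Value {
          return b.create<TensorStaticInfoCastOp>(loc, t, v);
        };
        return {cast, cast}; // one op handles both directions for tensors
      }
      if (type.isa<OptionalType>()) {
        CastFn down = [](OpBuilder &b, Location loc, Type t, Value v) -> Value {
          return b.create<PrimUncheckedCastOp>(loc, t, v); // refine
        };
        CastFn up = [](OpBuilder &b, Location loc, Type t, Value v) -> Value {
          return b.create<DerefineOp>(loc, t, v); // erase static info
        };
        return {down, up};
      }
      return {CastFn(), CastFn()};
    }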
diff --git a/test/Dialect/Torch/refine-types-branch.mlir b/test/Dialect/Torch/refine-types-branch.mlir
new file mode 100644
index 000000000000..962a946ce27f
--- /dev/null
+++ b/test/Dialect/Torch/refine-types-branch.mlir
@@ -0,0 +1,91 @@
+// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: builtin.func @prim.if$branch_merge_type_tensor(
+// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
+// CHECK-SAME: %[[T1:.*]]: !torch.tensor,
+// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool {
+// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional) {
+// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T1]] : !torch.tensor to !torch.optional
+// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional
+// CHECK: } else {
+// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T2]] : !torch.tensor to !torch.optional
+// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional
+// CHECK: }
+// CHECK: %[[REFINED:.*]] = torch.prim.unchecked_cast %[[MERGED:.*]] : !torch.optional -> !torch.tensor
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool
+// CHECK: return %[[RET]] : !torch.bool
+
+func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool {
+  %res = torch.prim.If %pred -> (!torch.optional) {
+    %optional0 = torch.derefine %t0: !torch.tensor to !torch.optional
+    torch.prim.If.yield %optional0: !torch.optional
+  } else {
+    %optional1 = torch.derefine %t1: !torch.tensor to !torch.optional
+    torch.prim.If.yield %optional1: !torch.optional
+  }
+  %none = torch.constant.none
+  %cmp = torch.aten.__isnot__ %res, %none : !torch.optional, !torch.none -> !torch.bool
+  return %cmp : !torch.bool
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @prim.if$branch_merge_type_optional(
+// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
+// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional {
+// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional) {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional
+// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional
+// CHECK: } else {
+// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T]] : !torch.tensor to !torch.optional
+// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional
+// CHECK: }
+// CHECK: return %[[MERGED:.*]] : !torch.optional
+
+func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional {
+  %res = torch.prim.If %pred -> (!torch.optional) {
+    %none = torch.constant.none
+    %optional0 = torch.derefine %none: !torch.none to !torch.optional
+    torch.prim.If.yield %optional0: !torch.optional
+  } else {
+    %optional1 = torch.derefine %t1: !torch.tensor to !torch.optional
+    torch.prim.If.yield %optional1: !torch.optional
+  }
+  return %res: !torch.optional
!torch.optional +} + +// ----- + +// CHECK-LABEL: builtin.func @prim.loop$region_arg_to_internal( +// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional { +// CHECK: %[[INT10:.*]] = torch.constant.int 10 +// CHECK: %[[INDV:.*]] = torch.constant.int 0 +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[ARG_NONE]] : !torch.none to !torch.optional +// CHECK: %[[LOOP_RET:.*]] = torch.prim.Loop %[[INT10]], %[[TRUE]], init(%[[OPTIONAL]]) { +// CHECK: ^bb0(%[[INDV:.*]]: !torch.int, %[[IT:.*]]: !torch.optional): +// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[IT]] : !torch.optional -> !torch.none +// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional +// CHECK: %[[COND:.*]] = torch.aten.__isnot__ %[[NONE]], %[[ARG_NONE]] : !torch.none, !torch.none -> !torch.bool +// CHECK: torch.prim.Loop.condition %[[COND]], iter(%[[OPTIONAL]] : !torch.optional) +// CHECK: } : (!torch.int, !torch.bool, !torch.optional) -> !torch.optional +// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[LOOP_RET:.*]] : !torch.optional -> !torch.none +// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional +// CHECK: return %[[OPTIONAL]] : !torch.optional + +func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional { + %int10 = torch.constant.int 10 + %int0 = torch.constant.int 0 + %true = torch.constant.bool true + %optional = torch.derefine %none: !torch.none to !torch.optional + %ret = torch.prim.Loop %int10, %true, init(%optional) { + ^bb0(%arg2: !torch.int, %arg3: !torch.optional): // no predecessors + %cond = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool + torch.prim.Loop.condition %cond, iter(%arg3: !torch.optional) + } : (!torch.int, !torch.bool, !torch.optional) -> (!torch.optional) + return %ret: !torch.optional +} diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index 10f3d2628b49..39213bccc6f1 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -292,3 +292,297 @@ func @f() { } return } + +// ----- + +// CHECK-LABEL: builtin.func @f( +// CHECK-SAME: %[[TENSOR:.*]]: !torch.tensor) -> !torch.bool { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[TENSOR]] : !torch.tensor to !torch.optional +// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[TENSOR]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool +// CHECK: return %[[RET]] : !torch.bool + +func @f(%arg : !torch.tensor) -> !torch.bool { + %none = torch.constant.none + %optional = "torch.derefine"(%arg) : (!torch.tensor) -> !torch.optional + %ret = "torch.aten.__isnot__"(%optional, %none) : (!torch.optional, !torch.none) -> !torch.bool + return %ret: !torch.bool +} + +// ----- + +// CHECK-LABEL: builtin.func @aten.arange.start$int64_dtype( +// CHECK-SAME: %[[START:.*]]: !torch.int, +// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[T:.*]] = torch.aten.arange.start +// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : +// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none +// CHECK-SAME: -> !torch.vtensor<[?],si64> +// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<[?],si64> to !torch.vtensor +// CHECK: return %[[RET]] : !torch.vtensor + +func @aten.arange.start$int64_dtype(%start: !torch.int, %end: 
+  %none = torch.constant.none
+  %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
+  return %ret : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @aten.arange.start$float32_dtype(
+// CHECK-SAME: %[[START:.*]]: !torch.float,
+// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[T:.*]] = torch.aten.arange.start
+// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
+// CHECK-SAME: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none
+// CHECK-SAME: -> !torch.vtensor<[?],f32>
+// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: return %[[RET]] : !torch.vtensor

+func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
+  %none = torch.constant.none
+  %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
+  return %ret : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @aten.arange.start$specified_dtype(
+// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
+// CHECK: %[[CST6:.*]] = torch.constant.int 6
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[T:.*]] = torch.aten.arange
+// CHECK-SAME: %[[END]], %[[CST6]], %[[NONE]], %[[NONE]], %[[NONE]] :
+// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none
+// CHECK-SAME: -> !torch.vtensor<[?],f32>
+// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: return %[[RET]] : !torch.vtensor
+
+func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
+  %int6 = torch.constant.int 6
+  %none = torch.constant.none
+  %ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor
+  return %ret : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @aten.sum.dim_IntList(
+// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
+// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]]
+// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]]
+// CHECK-SAME: : !torch.vtensor<[2,3,?],si64>, !torch.list, !torch.bool, !torch.none
+// CHECK-SAME: -> !torch.vtensor<[3],si64>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[3],si64> to !torch.vtensor
+// CHECK: return %[[CAST]] : !torch.vtensor
+
+func @aten.sum.dim_IntList(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %int0 = torch.constant.int 0
+  %int-1 = torch.constant.int -1
+  %dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list
+  %ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<[2,3,?],si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor
+  return %ret : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @aten.sum.dim_IntList$keepdim(
+// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
+// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
+// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]] :
+// CHECK-SAME: (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList
+// CHECK-SAME: %[[T]], %[[DIMLIST]], %[[KEEPDIM]], %[[NONE]]
+// CHECK-SAME: : !torch.vtensor<[2,3,?],si64>, !torch.list, !torch.bool, !torch.none
+// CHECK-SAME: -> !torch.vtensor<[1,3,1],si64>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] :
+// CHECK-SAME: !torch.vtensor<[1,3,1],si64> to !torch.vtensor
+// CHECK: return %[[CAST]] : !torch.vtensor
+
+func @aten.sum.dim_IntList$keepdim(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
+  %true = torch.constant.bool true
+  %none = torch.constant.none
+  %int0 = torch.constant.int 0
+  %int-1 = torch.constant.int -1
+  %dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list
+  %ret = torch.aten.sum.dim_IntList %t, %dimList, %true, %none : !torch.vtensor<[2,3,?],si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor
+  return %ret : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @aten.sum.dim_IntList$unknown_position(
+// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],si64>,
+// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.vtensor {
+// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
+// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[DIM]], %[[INT_NEG1]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[KEEPDIM]], %[[NONE]] : !torch.vtensor<[2,3,?],si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[?],si64>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[?],si64> to !torch.vtensor
+// CHECK: return %[[CAST]] : !torch.vtensor
+
+func @aten.sum.dim_IntList$unknown_position(%t: !torch.vtensor<[2,3,?],si64>, %dim0: !torch.int) -> !torch.vtensor {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %int-1 = torch.constant.int -1
+  %dimList = torch.prim.ListConstruct %dim0, %int-1 : (!torch.int, !torch.int) -> !torch.list
+  %ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<[2,3,?],si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor
+  return %ret : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @aten.any.dim(
+// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
+// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[2,3],i1>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[2,3],i1> to !torch.vtensor
+// CHECK: return %[[CAST]] : !torch.vtensor
+
+func @aten.any.dim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
+  %false = torch.constant.bool false
+  %int-1 = torch.constant.int -1
+  %ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor
+  return %ret : !torch.vtensor
!torch.vtensor +} + +// ----- + +// CHECK-LABEL: builtin.func @aten.any.dim$keepdim( +// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 +// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[TRUE]] : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],i1> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[2,3,1],i1> to !torch.vtensor +// CHECK: return %[[CAST]] : !torch.vtensor + +func @aten.any.dim$keepdim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor { + %true = torch.constant.bool true + %int-1 = torch.constant.int -1 + %ret = torch.aten.any.dim %t, %int-1, %true : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor + return %ret : !torch.vtensor +} + +// ----- + +// CHECK-LABEL: builtin.func @aten.any.dim$unknown_position( +// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>, +// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.vtensor { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[DIM]], %[[TRUE]] : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],i1> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[?,?,?],i1> to !torch.vtensor +// CHECK: return %[[CAST]] : !torch.vtensor + +func @aten.any.dim$unknown_position(%t: !torch.vtensor<[2,3,?],i1>, %dim: !torch.int) -> !torch.vtensor { + %true = torch.constant.bool true + %ret = torch.aten.any.dim %t, %dim, %true : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor + return %ret : !torch.vtensor +} + +// ----- + +// CHECK-LABEL: builtin.func @aten.any( +// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor { +// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<[2,3,?],i1> -> !torch.vtensor<[1],i1> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[1],i1> to !torch.vtensor +// CHECK: return %[[CAST]] : !torch.vtensor + +func @aten.any(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor { + %ret = torch.aten.any %t: !torch.vtensor<[2,3,?],i1> -> !torch.vtensor + return %ret : !torch.vtensor +} + +// ----- + +// CHECK-LABEL: builtin.func @aten.transpose.int( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 +// CHECK: %[[RET:.*]] = torch.aten.transpose.int %[[T]], %[[INT1]], %[[INT_NEG1]] : !torch.tensor<[2,3,4,5],si64>, !torch.int, !torch.int -> !torch.tensor<[2,5,4,3],si64> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,5,4,3],si64> to !torch.tensor +// CHECK: return %[[CAST]] : !torch.tensor + +func @aten.transpose.int(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor { + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %ret = torch.aten.transpose.int %t, %int1, %int-1 : !torch.tensor<[2,3,4,5],si64>, !torch.int, !torch.int -> !torch.tensor + return %ret: !torch.tensor +} + +// ----- + +// CHECK-LABEL: builtin.func @aten.transpose.int$unknown_position( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3,4,5],si64>, +// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor { +// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 +// CHECK: %[[RET:.*]] = torch.aten.transpose.int %[[T]], %[[DIM0]], %[[INT_NEG1]] : !torch.tensor<[2,3,4,5],si64>, !torch.int, 
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?,?,?,?],si64> to !torch.tensor
+// CHECK: return %[[CAST]] : !torch.tensor
+
+func @aten.transpose.int$unknown_position(%t: !torch.tensor<[2,3,4,5],si64>, %dim0: !torch.int) -> !torch.tensor {
+  %int-1 = torch.constant.int -1
+  %ret = torch.aten.transpose.int %t, %dim0, %int-1 : !torch.tensor<[2,3,4,5],si64>, !torch.int, !torch.int -> !torch.tensor
+  return %ret: !torch.tensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @aten.view(
+// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
+// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT_NEG1]]
+// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[RET:.*]] = torch.aten.view %[[T]], %[[SIZES]] :
+// CHECK-SAME: !torch.tensor<[2,3,4,5],si64>, !torch.list -> !torch.tensor<[2,?],si64>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,?],si64> to !torch.tensor
+// CHECK: return %[[CAST]] : !torch.tensor
+
+func @aten.view(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
+  %int2 = torch.constant.int 2
+  %int-1 = torch.constant.int -1
+  %sizes = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list
+  %ret = torch.aten.view %t, %sizes: !torch.tensor<[2,3,4,5],si64>, !torch.list -> !torch.tensor
+  return %ret: !torch.tensor
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @prim.if$refined_type_conflicting(
+// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor {
+// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional
+// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool
+// CHECK: %[[PRED:.*]] = torch.prim.If %[[NOT_NONE]] -> (!torch.tensor) {
+// CHECK: %[[T:.*]] = torch.prim.unchecked_cast %[[OPTIONAL]] : !torch.optional -> !torch.tensor
+// CHECK: torch.prim.If.yield %[[T]] : !torch.tensor
+// CHECK: } else {
+// CHECK: %[[LITERAL:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.tensor
+// CHECK: torch.prim.If.yield %[[LITERAL]] : !torch.tensor
+// CHECK: }
+// CHECK: return %[[PRED:.*]] : !torch.tensor
+
+func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
+  %optional = torch.derefine %none: !torch.none to !torch.optional
+  %pred = torch.aten.__isnot__ %optional, %none : !torch.optional, !torch.none -> !torch.bool
+  %res = torch.prim.If %pred -> (!torch.tensor) {
+    %t = torch.prim.unchecked_cast %optional: !torch.optional -> !torch.tensor
+    torch.prim.If.yield %t: !torch.tensor
+  } else {
+    %t_cst = torch.tensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
+    torch.prim.If.yield %t_cst: !torch.tensor
+  }
+  return %res: !torch.tensor
+}