Skip to content

Commit 57a7cd7

Browse files
committed
[shape] Add inferReturnTypes to a couple ops.
- ShapeOfOp - BroadcastOp Differential Revision: https://reviews.llvm.org/D78822
1 parent 5fff169 commit 57a7cd7

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
130130
let results = (outs Shape_SizeType:$result);
131131
}
132132

133-
def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
133+
def Shape_BroadcastOp : Shape_Op<"broadcast",
134+
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
134135
let summary = "Returns the broadcasted output shape of two inputs";
135136
let description = [{
136137
Computes the broadcasted output shape following:
@@ -317,7 +318,8 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
317318
let regions = (region SizedRegion<1>:$body);
318319
}
319320

320-
def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
321+
def Shape_ShapeOfOp : Shape_Op<"shape_of",
322+
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
321323
let summary = "Returns shape of a value or shaped type operand";
322324

323325
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
9292
// BroadcastOp
9393
//===----------------------------------------------------------------------===//
9494

95+
LogicalResult BroadcastOp::inferReturnTypes(
96+
MLIRContext *context, Optional<Location> location, ValueRange operands,
97+
ArrayRef<NamedAttribute> attributes, RegionRange regions,
98+
SmallVectorImpl<Type> &inferredReturnTypes) {
99+
inferredReturnTypes.push_back(ShapeType::get(context));
100+
return success();
101+
}
102+
95103
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
96104
if (!operands[0] || !operands[1])
97105
return nullptr;
@@ -175,6 +183,14 @@ LogicalResult ConstSizeOp::inferReturnTypes(
175183
// ShapeOfOp
176184
//===----------------------------------------------------------------------===//
177185

186+
LogicalResult ShapeOfOp::inferReturnTypes(
187+
MLIRContext *context, Optional<Location> location, ValueRange operands,
188+
ArrayRef<NamedAttribute> attributes, RegionRange regions,
189+
SmallVectorImpl<Type> &inferredReturnTypes) {
190+
inferredReturnTypes.push_back(ShapeType::get(context));
191+
return success();
192+
}
193+
178194
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
179195
auto type = getOperand().getType().dyn_cast<ShapedType>();
180196
if (!type || !type.hasStaticShape())

0 commit comments

Comments
 (0)