@@ -92,6 +92,14 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
92
92
// BroadcastOp
93
93
// ===----------------------------------------------------------------------===//
94
94
95
+ LogicalResult BroadcastOp::inferReturnTypes (
96
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
97
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
98
+ SmallVectorImpl<Type> &inferredReturnTypes) {
99
+ inferredReturnTypes.push_back (ShapeType::get (context));
100
+ return success ();
101
+ }
102
+
95
103
OpFoldResult BroadcastOp::fold (ArrayRef<Attribute> operands) {
96
104
if (!operands[0 ] || !operands[1 ])
97
105
return nullptr ;
@@ -175,6 +183,14 @@ LogicalResult ConstSizeOp::inferReturnTypes(
175
183
// ShapeOfOp
176
184
// ===----------------------------------------------------------------------===//
177
185
186
+ LogicalResult ShapeOfOp::inferReturnTypes (
187
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
188
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
189
+ SmallVectorImpl<Type> &inferredReturnTypes) {
190
+ inferredReturnTypes.push_back (ShapeType::get (context));
191
+ return success ();
192
+ }
193
+
178
194
OpFoldResult ShapeOfOp::fold (ArrayRef<Attribute>) {
179
195
auto type = getOperand ().getType ().dyn_cast <ShapedType>();
180
196
if (!type || !type.hasStaticShape ())
0 commit comments