@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
141
141
}
142
142
143
143
void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
144
- Type tdesc, TypedValue<MemRefType> source,
144
+ Type tdesc, Value source,
145
145
llvm::ArrayRef<OpFoldResult> offsets,
146
146
llvm::ArrayRef<OpFoldResult> shape,
147
147
llvm::ArrayRef<OpFoldResult> strides) {
148
148
assert (shape.size () && offsets.size () && strides.size () &&
149
149
shape.size () == strides.size () && shape.size () == offsets.size ());
150
150
151
- llvm::SmallVector<int64_t > staticOffsets;
152
- llvm::SmallVector<int64_t > staticShape;
153
- llvm::SmallVector<int64_t > staticStrides;
151
+ Type srcTy = source.getType ();
152
+ assert (isa<IntegerType>(srcTy) ||
153
+ isa<MemRefType>(srcTy) && " Source has to be either int or memref." );
154
+
154
155
llvm::SmallVector<Value> dynamicOffsets;
155
156
llvm::SmallVector<Value> dynamicShape;
156
157
llvm::SmallVector<Value> dynamicStrides;
157
158
158
- dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
159
- dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
160
- dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
161
-
162
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr (staticOffsets);
163
- auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
164
- auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
165
-
166
- build (builder, state, tdesc, source, dynamicOffsets, dynamicShape,
167
- dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
168
- }
169
-
170
- void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
171
- Type tdesc, TypedValue<IntegerType> source,
172
- llvm::ArrayRef<OpFoldResult> offsets,
173
- llvm::ArrayRef<OpFoldResult> shape,
174
- llvm::ArrayRef<OpFoldResult> strides) {
175
- assert (shape.size () && offsets.size () && strides.size () &&
176
- shape.size () == strides.size () && shape.size () == offsets.size ());
177
-
178
159
llvm::SmallVector<int64_t > staticOffsets;
179
160
llvm::SmallVector<int64_t > staticShape;
180
161
llvm::SmallVector<int64_t > staticStrides;
181
- llvm::SmallVector<Value> dynamicOffsets;
182
- llvm::SmallVector<Value> dynamicShape;
183
- llvm::SmallVector<Value> dynamicStrides;
184
162
185
163
dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
186
164
dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
@@ -190,6 +168,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
190
168
auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
191
169
auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
192
170
171
+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
172
+ auto memrefShape = memrefTy.getShape ();
173
+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset ();
174
+
175
+ // if shape and strides are from Memref, we don't need attributes for them
176
+ // to keep the IR print clean.
177
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
178
+ staticShapeAttr = DenseI64ArrayAttr ();
179
+ staticStridesAttr = DenseI64ArrayAttr ();
180
+ }
181
+ }
182
+
193
183
build (builder, state, tdesc, source, dynamicOffsets, dynamicShape,
194
184
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
195
185
}
0 commit comments