@@ -384,20 +384,45 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
384
384
// -----
385
385
386
386
// CHECK-LABEL: @reshape_with_non_identity_layout(
387
- // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>>,
388
- // CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>)
389
- func.func @reshape_with_non_identity_layout (%arg0: tensor <2 x2 xf32 >, %arg1: tensor <2 xi32 >) -> tensor <1 x2 xf32 > {
390
-
391
- // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2xf32, strided<[?], offset: ?>>
392
- %extracted_slice = tensor.extract_slice %arg0 [1 , 0 ] [1 , 2 ] [1 , 1 ] : tensor <2 x2 xf32 > to tensor <2 xf32 >
387
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
388
+ // CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,
389
+ func.func @reshape_with_non_identity_layout (%arg0: memref <2 x2 xf32 , strided <[?, ?], offset : ?>, 3 >, %arg1: tensor <2 xi32 >, %idx: index ) -> f32 {
390
+ %t = bufferization.to_tensor %arg0 restrict : memref <2 x2 xf32 , strided <[?, ?], offset : ?>, 3 >
391
+
392
+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>, 3> to memref<2xf32, strided<[?], offset: ?>, 3>
393
+ %extracted_slice = tensor.extract_slice %t [1 , 0 ] [1 , 2 ] [1 , 1 ] : tensor <2 x2 xf32 > to tensor <2 xf32 >
394
+
395
+ // To satisfy the constraints of memref.reshape, the subview must be
396
+ // copied into a newly allocated buffer with an identity layout.
397
+ // CHECK: %[[ALLOC:.+]] = memref.alloc() {{.*}} : memref<2xf32, 3>
398
+ // CHECK: memref.copy %[[SUBVIEW]], %[[ALLOC]]
399
+ // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[ALLOC]](%[[LAYOUT]]) : (memref<2xf32, 3>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32, 3>
400
+ %reshape = tensor.reshape %extracted_slice (%arg1 ) : (tensor <2 xf32 >, tensor <2 xi32 >) -> tensor <1 x2 xf32 >
393
401
394
- // To satisify the constraints of memref.reshape, the subview must be cloned into
395
- // a buffer with an identity layout.
396
- // CHECK: %[[CLONED:.+]] = bufferization.clone %[[SUBVIEW]] : memref<2xf32, strided<[?], offset: ?>> to memref<2xf32>
397
- // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[CLONED]](%[[LAYOUT]]) : (memref<2xf32>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32>
402
+ %r = tensor.extract %reshape [%idx , %idx ] : tensor <1 x2 xf32 >
403
+ return %r : f32
404
+ }
398
405
399
- %reshape = tensor.reshape %extracted_slice ( %arg1 ) : ( tensor < 2 x f32 >, tensor < 2 x i32 >) -> tensor < 1 x 2 x f32 >
406
+ // -----
400
407
401
- // CHECK: return %[[RESHAPED]] : memref<1x2xf32>
402
- return %reshape : tensor <1 x2 xf32 >
408
+ // CHECK-LABEL: func @collapse_shape_regression(
409
+ // CHECK-SAME: %[[t:.*]]: memref<10x20xf32,
410
+ func.func @collapse_shape_regression (
411
+ %t: tensor <10 x20 xf32 >, %f: f32 , %idx: index ) {
412
+ // CHECK: %[[subview:.*]] = memref.subview %[[t]]
413
+ %0 = tensor.extract_slice %t [2 , 3 ] [5 , 6 ] [1 , 1 ]
414
+ : tensor <10 x20 xf32 > to tensor <5 x6 xf32 >
415
+
416
+ // Insert a copy because the original %0 is read later.
417
+ // CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5x6xf32>
418
+ // CHECK: memref.copy %[[subview]], %[[alloc1]]
419
+ // CHECK: memref.store {{.*}}, %[[alloc1]]
420
+ tensor.insert %f into %0 [%idx , %idx ] : tensor <5 x6 xf32 >
421
+
422
+ // Insert a copy because the shape cannot be collapsed.
423
+ // CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5x6xf32>
424
+ // CHECK: memref.copy %[[subview]], %[[alloc2]]
425
+ // CHECK: memref.collapse_shape %[[alloc2]]
426
+ tensor.collapse_shape %0 [[0 , 1 ]] : tensor <5 x6 xf32 > into tensor <30 xf32 >
427
+ return
403
428
}
0 commit comments