@@ -266,6 +266,127 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
266
266
267
267
// -----
268
268
269
+ func.func @fold_masked_vector_transfer_read_with_subview (
270
+ %arg0 : memref <?x?xf32 , strided <[?, ?], offset : ?>>,
271
+ %arg1: index , %arg2 : index , %arg3 : index , %arg4: index , %arg5 : index ,
272
+ %arg6 : index , %mask : vector <4 xi1 >) -> vector <4 xf32 > {
273
+ %cst = arith.constant 0.0 : f32
274
+ %0 = memref.subview %arg0 [%arg1 , %arg2 ] [%arg3 , %arg4 ] [1 , 1 ]
275
+ : memref <?x?xf32 , strided <[?, ?], offset : ?>> to
276
+ memref <?x?xf32 , strided <[?, ?], offset : ?>>
277
+ %1 = vector.transfer_read %0 [%arg5 , %arg6 ], %cst , %mask {in_bounds = [true ]}
278
+ : memref <?x?xf32 , strided <[?, ?], offset : ?>>, vector <4 xf32 >
279
+ return %1 : vector <4 xf32 >
280
+ }
281
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
282
+ // CHECK: func @fold_masked_vector_transfer_read_with_subview
283
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
284
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
285
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
286
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
287
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
288
+ // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
289
+ // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
290
+ // CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
291
+ // CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
292
+ // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
293
+ // CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[MASK]] {{.*}} : memref<?x?xf32
294
+
295
+ // -----
296
+
297
+ func.func @fold_masked_vector_transfer_read_with_rank_reducing_subview (
298
+ %arg0 : memref <?x?x?x?xf32 , strided <[?, ?, ?, ?], offset : ?>>,
299
+ %arg1: index , %arg2 : index , %arg3 : index , %arg4: index , %arg5 : index ,
300
+ %arg6 : index , %mask : vector <4 x3 xi1 >) -> vector <3 x4 xf32 > {
301
+ %cst = arith.constant 0.0 : f32
302
+ %0 = memref.subview %arg0 [0 , %arg1 , 0 , %arg2 ] [1 , %arg3 , 1 , %arg4 ] [1 , 1 , 1 , 1 ]
303
+ : memref <?x?x?x?xf32 , strided <[?, ?, ?, ?], offset : ?>> to
304
+ memref <?x?xf32 , strided <[?, ?], offset : ?>>
305
+ %1 = vector.transfer_read %0 [%arg5 , %arg6 ], %cst , %mask {
306
+ permutation_map = affine_map <(d0 , d1 ) -> (d1 , d0 )>, in_bounds = [true , true ]}
307
+ : memref <?x?xf32 , strided <[?, ?], offset : ?>>, vector <3 x4 xf32 >
308
+ return %1 : vector <3 x4 xf32 >
309
+ }
310
+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
311
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
312
+ // CHECK: func @fold_masked_vector_transfer_read_with_rank_reducing_subview
313
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
314
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
315
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
316
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
317
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
318
+ // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
319
+ // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
320
+ // CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
321
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
322
+ // CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
323
+ // CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG5]]]
324
+ // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
325
+ // CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[PAD]], %[[MASK]] {{.*}} permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32
326
+
327
+ // -----
328
+
329
+ func.func @fold_masked_vector_transfer_write_with_subview (
330
+ %arg0 : memref <?x?xf32 , strided <[?, ?], offset : ?>>,
331
+ %arg1 : vector <4 xf32 >, %arg2: index , %arg3 : index , %arg4 : index ,
332
+ %arg5: index , %arg6 : index , %arg7 : index , %mask : vector <4 xi1 >) {
333
+ %cst = arith.constant 0.0 : f32
334
+ %0 = memref.subview %arg0 [%arg2 , %arg3 ] [%arg4 , %arg5 ] [1 , 1 ]
335
+ : memref <?x?xf32 , strided <[?, ?], offset : ?>> to
336
+ memref <?x?xf32 , strided <[?, ?], offset : ?>>
337
+ vector.transfer_write %arg1 , %0 [%arg6 , %arg7 ], %mask {in_bounds = [true ]}
338
+ : vector <4 xf32 >, memref <?x?xf32 , strided <[?, ?], offset : ?>>
339
+ return
340
+ }
341
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
342
+ // CHECK: func @fold_masked_vector_transfer_write_with_subview
343
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
344
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
345
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
346
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
347
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
348
+ // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
349
+ // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
350
+ // CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
351
+ // CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
352
+ // CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
353
+ // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
354
+ // CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
355
+
356
+ // -----
357
+
358
+ func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview (
359
+ %arg0 : memref <?x?x?x?xf32 , strided <[?, ?, ?, ?], offset : ?>>,
360
+ %arg1 : vector <3 x4 xf32 >, %arg2: index , %arg3 : index , %arg4 : index ,
361
+ %arg5: index , %arg6 : index , %arg7 : index , %mask : vector <4 x3 xi1 >) {
362
+ %cst = arith.constant 0.0 : f32
363
+ %0 = memref.subview %arg0 [0 , %arg2 , 0 , %arg3 ] [1 , %arg4 , 1 , %arg5 ] [1 , 1 , 1 , 1 ]
364
+ : memref <?x?x?x?xf32 , strided <[?, ?, ?, ?], offset : ?>> to
365
+ memref <?x?xf32 , strided <[?, ?], offset : ?>>
366
+ vector.transfer_write %arg1 , %0 [%arg6 , %arg7 ], %mask {
367
+ permutation_map = affine_map <(d0 , d1 ) -> (d1 , d0 )>, in_bounds = [true , true ]}
368
+ : vector <3 x4 xf32 >, memref <?x?xf32 , strided <[?, ?], offset : ?>>
369
+ return
370
+ }
371
+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
372
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
373
+ // CHECK: func @fold_masked_vector_transfer_write_with_rank_reducing_subview
374
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
375
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<3x4xf32>
376
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
377
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
378
+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
379
+ // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
380
+ // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
381
+ // CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
382
+ // CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
383
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
384
+ // CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
385
+ // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]]]
386
+ // CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true, true], permutation_map = #[[MAP1]]} : vector<3x4xf32>, memref<?x?x?x?xf32
387
+
388
+ // -----
389
+
269
390
// Test with affine.load/store ops. We only do a basic test here since the
270
391
// logic is identical to that with memref.load/store ops. The same affine.apply
271
392
// ops would be generated.
0 commit comments