@@ -16,7 +16,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
16
16
%extracted_slice_0 = tensor.extract_slice %tensor [0 , 0 ] [1 , %c4_vscale ] [1 , 1 ] : tensor <1 x1000 xf32 > to tensor <1 x?xf32 >
17
17
%output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args (%arg = %extracted_slice_0 ) -> tensor <1 x?xf32 > {
18
18
// 1. Extract a slice.
19
- %extracted_slice_1 = tensor.extract_slice %arg [0 , 0 ] [1 , %c4_vscale ] [1 , 1 ] : tensor <1 x?xf32 > to tensor <?xf32 >
19
+ %extracted_slice_1 = tensor.extract_slice %arg [0 , %i ] [1 , %c4_vscale ] [1 , 1 ] : tensor <1 x?xf32 > to tensor <?xf32 >
20
20
21
21
// 2. Create a mask for the slice.
22
22
%dim_1 = tensor.dim %extracted_slice_1 , %c0 : tensor <?xf32 >
@@ -30,7 +30,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
30
30
%write = vector.transfer_write %new_vec , %extracted_slice_1 [%c0 ], %mask {in_bounds = [true ]} : vector <[4 ]xf32 >, tensor <?xf32 >
31
31
32
32
// 5. Insert and yield the new tensor value.
33
- %result = tensor.insert_slice %write into %arg [0 , 0 ] [1 , %c4_vscale ] [1 , 1 ] : tensor <?xf32 > into tensor <1 x?xf32 >
33
+ %result = tensor.insert_slice %write into %arg [0 , %i ] [1 , %c4_vscale ] [1 , 1 ] : tensor <?xf32 > into tensor <1 x?xf32 >
34
34
scf.yield %result : tensor <1 x?xf32 >
35
35
}
36
36
" test.some_use" (%output_tensor ) : (tensor <1 x?xf32 >) -> ()
@@ -57,8 +57,8 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
57
57
%mask = vector.create_mask %dim : vector <[4 ]xi1 >
58
58
" test.some_use" (%mask ) : (vector <[4 ]xi1 >) -> ()
59
59
// !!! Here the size of the mask could shrink in the next iteration.
60
- %next_num_els = affine.min affine_map <(d0 )[s0 ] -> (-d0 + 1000 , s0 )>(%i )[%c4_vscale ]
61
- %new_extracted_slice = tensor.extract_slice %tensor [%c4_vscale ] [%next_num_els ] [1 ] : tensor <1000 xf32 > to tensor <?xf32 >
60
+ %next_num_elts = affine.min affine_map <(d0 )[s0 ] -> (-d0 + 1000 , s0 )>(%i )[%c4_vscale ]
61
+ %new_extracted_slice = tensor.extract_slice %tensor [%c4_vscale ] [%next_num_elts ] [1 ] : tensor <1000 xf32 > to tensor <?xf32 >
62
62
scf.yield %new_extracted_slice : tensor <?xf32 >
63
63
}
64
64
" test.some_use" (%slice ) : (tensor <?xf32 >) -> ()
@@ -110,8 +110,8 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
110
110
%c4 = arith.constant 4 : index
111
111
%vscale = vector.vscale
112
112
%c4_vscale = arith.muli %vscale , %c4 : index
113
- // This is _very_ simple but since addi is not a constant value bounds will
114
- // be used to resolve it.
113
+ // This is _very_ simple, but since tensor.dim is not a constant, value bounds
114
+ // will be used to resolve it.
115
115
%dim = tensor.dim %tensor , %c0 : tensor <2 x?xf32 >
116
116
%mask = vector.create_mask %dim , %c4_vscale : vector <3 x[4 ]xi1 >
117
117
" test.some_use" (%mask ) : (vector <3 x[4 ]xi1 >) -> ()
0 commit comments