@@ -21,20 +21,26 @@ namespace scheduler_utils {
21
21
22
22
// Returns number of "valid" dimensions. e.g. if tv has
23
23
// [I1, R2, I3, I4, R3{1}]
24
- // resulting domain should be:
25
- // [I1, I3*I4, R2* R3{1}] with return value 3
24
+ // where R3{1} is in dont_merge, resulting domain should be:
25
+ // [I1, I3*I4, R2, R3{1}] with return value 3
26
26
//
27
27
// if tv has
28
28
// [R1, I2, R3, I4, R4, R5{1}, R6{1}]
29
- // resulting domain should be:
30
- // [I2*I4, R1*R3, R4* R5{1}* R6{1}]
29
+ // where R5{1} and R6{1} are in dont_merge, resulting domain should be:
30
+ // [I2*I4, R1*R3, R4, R5{1}, R6{1}]
31
31
// with return value 3
32
- size_t merge_3d (TensorView* tv) {
32
+ size_t merge_3d (
33
+ TensorView* tv,
34
+ const std::unordered_set<IterDomain*>& dont_merge) {
33
35
bool active_is_reduction = false ;
34
36
bool first_dim = true ;
35
37
int prev_i = -1 ;
36
38
37
39
for (int i = static_cast <int >(tv->nDims ()) - 1 ; i >= 0 ; i--) {
40
+ if (dont_merge.count (tv->axis (i))) {
41
+ continue ;
42
+ }
43
+
38
44
if (first_dim) {
39
45
active_is_reduction = tv->axis (i)->isReduction ();
40
46
prev_i = i;
@@ -61,6 +67,10 @@ size_t merge_3d(TensorView* tv) {
61
67
62
68
for (int i = static_cast <int >(tv->nDims ()) - 2 ; i >= 0 ; i--) {
63
69
auto id = tv->axis (i);
70
+ if (dont_merge.count (id)) {
71
+ continue ;
72
+ }
73
+
64
74
if (first_dim) {
65
75
active_is_reduction = id->isReduction ();
66
76
prev_i = i;
@@ -86,6 +96,10 @@ size_t merge_3d(TensorView* tv) {
86
96
prev_i = -1 ;
87
97
88
98
for (int i = static_cast <int >(tv->nDims ()) - 3 ; i >= 0 ; i--) {
99
+ if (dont_merge.count (tv->axis (i))) {
100
+ continue ;
101
+ }
102
+
89
103
if (first_dim) {
90
104
active_is_reduction = tv->axis (i)->isReduction ();
91
105
prev_i = i;
@@ -100,7 +114,7 @@ size_t merge_3d(TensorView* tv) {
100
114
if (prev_i == -1 ) {
101
115
// Two dimensional, put merged dimensions first
102
116
tv->reorder ({{-1 , 0 }, {-2 , 1 }});
103
- // [outer, inner]
117
+ // [outer, inner, dont_merge... ]
104
118
if (tv->axis (0 )->isReduction ()) {
105
119
// put reductions as second axis
106
120
tv->reorder ({{0 , 1 }, {1 , 0 }});
@@ -181,11 +195,13 @@ c10::optional<size_t> mergeDims(
181
195
return left;
182
196
}
183
197
184
- size_t mergeReduction (TensorView* tv) {
198
+ size_t mergeReduction (
199
+ TensorView* tv,
200
+ const std::unordered_set<IterDomain*>& dont_merge) {
185
201
int prev_i = -1 ;
186
202
size_t num_merged = 0 ;
187
203
for (int i = static_cast <int >(tv->nDims ()) - 1 ; i >= 0 ; i--) {
188
- if (!tv->axis (i)->isReduction ()) {
204
+ if (!tv->axis (i)->isReduction () || dont_merge. count (tv-> axis (i)) ) {
189
205
continue ;
190
206
}
191
207
if (prev_i == -1 ) {
@@ -203,14 +219,16 @@ size_t mergeReduction(TensorView* tv) {
203
219
return prev_i == -1 ? 0 : num_merged + 1 ;
204
220
}
205
221
206
- size_t mergeNonReduction (TensorView* tv) {
222
+ size_t mergeNonReduction (
223
+ TensorView* tv,
224
+ const std::unordered_set<IterDomain*>& dont_merge) {
207
225
int prev_i = -1 ;
208
226
size_t num_merged = 0 ;
209
227
if (tv->nDims () == 0 ) {
210
228
return 0 ;
211
229
}
212
230
for (int i = static_cast <int >(tv->nDims ()) - 1 ; i >= 0 ; i--) {
213
- if (tv->axis (i)->isReduction ()) {
231
+ if (tv->axis (i)->isReduction () || dont_merge. count (tv-> axis (i)) ) {
214
232
continue ;
215
233
}
216
234
if (prev_i == -1 ) {
@@ -887,21 +905,63 @@ PersistentBufferSizeReturn persistentBufferSize(
887
905
return persistent_buffer_size;
888
906
}
889
907
908
+ std::unordered_set<IterDomain*> getTrivialReductionMap (Fusion* fusion) {
909
+ auto all_tvs = ir_utils::allTvs (fusion);
910
+ std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
911
+ for (auto tv : all_tvs) {
912
+ // root domain vs domain shouldn't matter as at this point we shouldn't have
913
+ // any transformations.
914
+ for (auto id : tv->getRootDomain ()) {
915
+ if (id->isTrivialReduction ()) {
916
+ mapped_to_trivial_reduction.emplace (id);
917
+ }
918
+ }
919
+ }
920
+
921
+ if (!mapped_to_trivial_reduction.empty ()) {
922
+ // Use the loop map as that is the most permissive
923
+ auto ca_map = ComputeAtMap (fusion);
924
+ // Make a copy we need to check mappings of all
925
+ auto trivial_ids = mapped_to_trivial_reduction;
926
+ for (auto tv : all_tvs) {
927
+ for (auto id : tv->getRootDomain ()) {
928
+ if (!id->extent ()->isOneInt ()) {
929
+ continue ;
930
+ }
931
+ if (std::any_of (
932
+ trivial_ids.begin (),
933
+ trivial_ids.end (),
934
+ [&ca_map, &id](IterDomain* trivial_id) {
935
+ return ca_map.areMapped (
936
+ id, trivial_id, IdMappingMode::PERMISSIVE);
937
+ })) {
938
+ mapped_to_trivial_reduction.emplace (id);
939
+ }
940
+ }
941
+ }
942
+ }
943
+ return mapped_to_trivial_reduction;
944
+ }
945
+
890
946
// Merges the axes of `tv` into a canonical layout for reduction scheduling.
//
// This block is shown as a diff: lines prefixed with "-" are the previous
// implementation, lines prefixed with "+" are their replacements, and
// unprefixed lines are unchanged context.
//
// Parameters:
//   fusion      - passed to getTrivialReductionMap() to obtain the set of
//                 IterDomains handed to the merge helpers.
//   tv          - tensor view whose axes are merged; asserted to be non-null.
//   schedule_3D - selects the merge_3d() path when true, otherwise the
//                 mergeReduction() / mergeNonReduction() path.
//
// Returns {has_iter_axis, has_red_axis}. On the non-3D path each flag is true
// when the corresponding merge helper returned a value greater than zero. On
// the 3D path the result is always {true, true}, after asserting that
// merge_3d() returned exactly 3.
std::pair<bool , bool > canonicalDimReduction (
891
947
Fusion* fusion,
892
948
TensorView* tv,
893
949
bool schedule_3D) {
950
// In the "+" version this set is forwarded to every merge helper below as
// its second argument (the helpers' `dont_merge` parameter).
// NOTE(review): this is computed before the `tv != nullptr` assertion below.
+ std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
951
+ getTrivialReductionMap (fusion);
952
+
894
953
TORCH_INTERNAL_ASSERT (tv != nullptr );
895
954
896
955
if (!schedule_3D) {
897
956
// We coalesce all reduction axes to the right;
// mergeReduction() is called first, then mergeNonReduction(), both on the
// same `tv`.
898
- bool has_red_axis = mergeReduction (tv) > 0 ;
957
+ bool has_red_axis = mergeReduction (tv, mapped_to_trivial_reduction ) > 0 ;
899
958
900
- bool has_iter_axis = mergeNonReduction (tv) > 0 ;
959
+ bool has_iter_axis = mergeNonReduction (tv, mapped_to_trivial_reduction ) > 0 ;
901
960
return {has_iter_axis, has_red_axis};
902
961
} else {
903
962
// 3D scheduling: merge_3d() must report exactly three dimensions,
// otherwise the internal assert fires with the message below.
TORCH_INTERNAL_ASSERT (
904
- merge_3d (tv) == 3 , " Tried 3D merge, but result is not 3D." );
963
+ merge_3d (tv, mapped_to_trivial_reduction) == 3 ,
964
+ " Tried 3D merge, but result is not 3D." );
905
965
return {true , true };
906
966
}
907
967
}
0 commit comments