Skip to content

Commit 1a0e355

Browse files
authored
Fix contiguity analysis of predicates to match updated contiguity. (#1991)
1 parent a4effa6 commit 1a0e355

File tree

8 files changed

+245
-116
lines changed

8 files changed

+245
-116
lines changed

torch/csrc/jit/codegen/cuda/contiguity.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -390,16 +390,20 @@ ContigIDs::ContigIDs(
390390
const std::vector<IterDomain*>& ids,
391391
const std::vector<IterDomain*>& root_domain,
392392
const std::vector<bool>& root_contiguity,
393-
std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref,
393+
const std::unordered_set<IterDomain*>& final_ids,
394+
const std::unordered_map<IterDomain*, Val*>& index_map,
394395
const std::unordered_set<Split*>& divisible_splits,
395396
std::unordered_map<IterDomain*, IterDomain*> p2c_id_map,
396-
bool ignore_indexability)
397+
bool ignore_indexability,
398+
bool ignore_consistent_ordering)
397399
: root_domain_(root_domain),
398400
root_contiguity_(root_contiguity),
399-
concrete_to_ref_(std::move(concrete_to_ref)),
401+
final_ids_(final_ids),
402+
index_map_(index_map),
400403
divisible_splits_(divisible_splits),
401404
p2c_id_map_(std::move(p2c_id_map)),
402405
ignore_indexability_(ignore_indexability),
406+
ignore_consistent_ordering_(ignore_consistent_ordering),
403407
non_divisible_id_info_(ids, root_domain_, divisible_splits_) {
404408
if (ids.size() > 0) {
405409
// This constructor doesn't provide the following information so it needs to
@@ -419,22 +423,26 @@ ContigIDs::ContigIDs(
419423
const std::vector<IterDomain*>& ids,
420424
const std::vector<IterDomain*>& root_domain,
421425
const std::vector<bool>& root_contiguity,
422-
std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref,
426+
const std::unordered_set<IterDomain*>& final_ids,
427+
const std::unordered_map<IterDomain*, Val*>& index_map,
423428
const std::unordered_set<Split*>& divisible_splits,
424429
std::shared_ptr<const ComputeAtMap> ca_map,
425430
std::shared_ptr<const HaloInfo> halo_info,
426431
std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info,
427432
std::unordered_map<IterDomain*, IterDomain*> p2c_id_map,
428-
bool ignore_indexability)
433+
bool ignore_indexability,
434+
bool ignore_consistent_ordering)
429435
: root_domain_(root_domain),
430436
root_contiguity_(root_contiguity),
431-
concrete_to_ref_(std::move(concrete_to_ref)),
437+
final_ids_(final_ids),
438+
index_map_(index_map),
432439
divisible_splits_(divisible_splits),
433440
ca_map_(ca_map),
434441
halo_info_(halo_info),
435442
concrete_info_(concrete_info),
436443
p2c_id_map_(std::move(p2c_id_map)),
437444
ignore_indexability_(ignore_indexability),
445+
ignore_consistent_ordering_(ignore_consistent_ordering),
438446
consistent_transform_info_(std::make_unique<const OrderedIdInformation>(
439447
ids,
440448
root_domain,
@@ -443,6 +451,10 @@ ContigIDs::ContigIDs(
443451
build(ids);
444452
}
445453

454+
// Builds a ContigIDs by calling the constructor with six empty
// brace-initialized arguments, i.e. with no ids and no root domain to analyze.
// NOTE(review): each `{}` is a temporary bound to a const-reference
// parameter, and the constructor stores several of those parameters in
// reference members (root_domain_, root_contiguity_, final_ids_,
// divisible_splits_). Those members presumably dangle once this expression
// ends; confirm they are never read on an instance returned from here.
ContigIDs ContigIDs::getNonContigIDs() {
455+
return ContigIDs({}, {}, {}, {}, {}, {});
456+
}
457+
446458
void ContigIDs::build(const std::vector<IterDomain*>& ids) {
447459
if (ids.empty() || root_domain_.empty()) {
448460
return;
@@ -488,8 +500,12 @@ void ContigIDs::handle(Merge* merge) {
488500
// If output is not consistently ordered or doesn't solely consume all root
489501
// domains in its dependencies, then it can't be a contiguously indexable
490502
// iterdomain.
491-
if (!(consistent_transform_info_->isConsistentlyOrdered(merge->out()) &&
492-
consistent_transform_info_->exclusivelyConsumesRoots(merge->out()))) {
503+
if (!(ignore_consistent_ordering_ ||
504+
consistent_transform_info_->isConsistentlyOrdered(merge->out()))) {
505+
return;
506+
}
507+
508+
if (!consistent_transform_info_->exclusivelyConsumesRoots(merge->out())) {
493509
return;
494510
}
495511

@@ -499,6 +515,11 @@ void ContigIDs::handle(Merge* merge) {
499515
return;
500516
}
501517

518+
// If inputs are marked as final, stop
519+
if (final_ids_.count(merge->inner()) || final_ids_.count(merge->outer())) {
520+
return;
521+
}
522+
502523
// Check root domains for contiguity
503524
auto root_ids_it =
504525
consistent_transform_info_->idToRootIds().find(merge->out());
@@ -512,17 +533,25 @@ void ContigIDs::handle(Merge* merge) {
512533

513534
VectorOfUniqueEntries<IterDomain*> root_ids = root_ids_it->second;
514535

536+
bool is_indexing_pass = !ignore_consistent_ordering_;
537+
515538
IterDomain* last_root = nullptr;
516539
for (auto root_id_i : c10::irange(root_domain_.size())) {
517540
auto root_id = root_domain_[root_id_i];
518541
if (root_ids.has(root_id)) {
519542
// ID found, remove it
520543
root_ids.erase(root_id);
521-
// If the last id isn't contiguous that's fine, we can use the stride of
522-
// the last iter domain to multiply the contig index.
523-
if (!root_contiguity_[root_id_i] && !root_ids.empty()) {
524-
// Otherwise this merge->out() isn't a contiguously indexable ID
525-
return;
544+
// If we're indexing:
545+
// we could still potentially consider this ID linearly indexable, as we
546+
// could multiply the index by the last root's stride.
547+
//
548+
// If we're computing predicates (ignore_consistent_ordering_==true),
549+
// then we don't have this same constraint, we can just ignore
550+
// contiguity of the roots altogether.
551+
if (!root_contiguity_[root_id_i] && is_indexing_pass) {
552+
if (!root_ids.empty()) {
553+
return;
554+
}
526555
}
527556
last_root = root_id;
528557
}
@@ -581,7 +610,7 @@ bool ContigIDs::isIndexable(IterDomain* id) const {
581610
}
582611
auto c_id =
583612
ca_map_->getConcreteMappedID(getMappedId(id), IdMappingMode::EXACT);
584-
return concrete_to_ref_.find(c_id) != concrete_to_ref_.end();
613+
return index_map_.find(c_id) != index_map_.end();
585614
}
586615

587616
} // namespace cuda

torch/csrc/jit/codegen/cuda/contiguity.h

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,22 +161,49 @@ class ContigIDs : public OptInDispatch {
161161
const std::vector<IterDomain*>& ids,
162162
const std::vector<IterDomain*>& root_domain,
163163
const std::vector<bool>& root_contiguity,
164-
std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref,
164+
const std::unordered_set<IterDomain*>& final_ids,
165+
const std::unordered_map<IterDomain*, Val*>& index_map,
165166
const std::unordered_set<Split*>& divisible_splits,
166167
std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
167-
bool ignore_indexability = false);
168-
168+
bool ignore_indexability = false,
169+
bool ignore_consistent_ordering = false);
170+
171+
//! \param ids IterDomains on the leaves of the domain we're looking for
172+
//! contiguous indexing into.
173+
//! \param root_domain the root domain of the domain we're looking for
174+
//! contiguous indexing into.
175+
//! \param root_contiguity the contiguity of the root_domain.
176+
//! \param index_map map from concrete ids of the exact map to the index used
177+
//! for them; only consulted to check whether an id has an index.
178+
//! \param divisible_splits a set of all splits in the fusion that are
179+
//! divisible.
180+
//! \param ca_map compute at map of the fusion.
181+
//! \param halo_info halo information of the fusion.
182+
//! \param concrete_info concretized broadcast information of the fusion.
183+
//! \param p2c_id_map map from producer to consumer ids used for indexing
184+
//! producer tensors.
185+
//! \param ignore_consistent_ordering false for actual indexing into tensors
186+
//! but true for predicate analysis. Ordering of merges doesn't matter for
187+
//! predicate generation as they don't map to a physical address.
188+
//! \param ignore_indexability can only be true if providing a real
189+
//! index_map, as what it's checking is whether the id is actually
190+
//! indexable based on that map.
169191
ContigIDs(
170192
const std::vector<IterDomain*>& ids,
171193
const std::vector<IterDomain*>& root_domain,
172194
const std::vector<bool>& root_contiguity,
173-
std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref,
195+
const std::unordered_set<IterDomain*>& final_ids,
196+
const std::unordered_map<IterDomain*, Val*>& index_map,
174197
const std::unordered_set<Split*>& divisible_splits,
175198
std::shared_ptr<const ComputeAtMap> ca_map,
176199
std::shared_ptr<const HaloInfo> halo_info,
177200
std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info,
178201
std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
179-
bool ignore_indexability = false);
202+
bool ignore_indexability = false,
203+
bool ignore_consistent_ordering = false);
204+
205+
//! Return an empty ContigIDs with no contiguous ID
206+
static ContigIDs getNonContigIDs();
180207

181208
const std::unordered_set<IterDomain*>& contigIDs() const {
182209
return contig_ids_;
@@ -191,6 +218,14 @@ class ContigIDs : public OptInDispatch {
191218
return root_to_indexed_id_;
192219
}
193220

221+
//! Looks up `id` in consistent_transform_info_->idToRootIds() and returns a
//! copy of the root IterDomains recorded for it, or an empty set when `id`
//! has no entry in that map.
//! NOTE(review): dereferences consistent_transform_info_ without a null
//! check; confirm it is always populated before this is called (e.g. on an
//! instance obtained from getNonContigIDs()).
VectorOfUniqueEntries<IterDomain*> indexedRootIDs(IterDomain* id) const {
222+
auto root_ids_it = consistent_transform_info_->idToRootIds().find(id);
223+
// No entry for this id: report no root ids rather than failing.
if (root_ids_it == consistent_transform_info_->idToRootIds().end()) {
224+
return {};
225+
}
226+
return root_ids_it->second;
227+
}
228+
194229
private:
195230
using OptInDispatch::handle;
196231

@@ -233,9 +268,12 @@ class ContigIDs : public OptInDispatch {
233268
const std::vector<IterDomain*>& root_domain_;
234269
//! Contiguity of root_domain_
235270
const std::vector<bool>& root_contiguity_;
236-
//! Mapping of concrete to reference domains. If a concrete domain
237-
//! is not mapped, it is not indexable as there's no mapped index.
238-
const std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref_;
271+
//! Domains where indexing/predicates cannot be done with their
272+
//! consumer domains
273+
const std::unordered_set<IterDomain*>& final_ids_;
274+
//! Mapping of concrete domains to indices. Just used to check if
275+
//! there's an index for an IterDomain.
276+
const std::unordered_map<IterDomain*, Val*> index_map_;
239277
// Divisible split information as we can still consider iter domains
240278
// contiguous through divisible splits.
241279
const std::unordered_set<Split*>& divisible_splits_;
@@ -247,7 +285,9 @@ class ContigIDs : public OptInDispatch {
247285
//! Producer-to-consumer index map in the case of analyzing replayed
248286
//! producer tensors
249287
const std::unordered_map<IterDomain*, IterDomain*> p2c_id_map_;
288+
250289
const bool ignore_indexability_ = false;
290+
const bool ignore_consistent_ordering_ = false;
251291

252292
//! Mapping of root domain to bool indicating contiguity
253293
std::unordered_map<IterDomain*, bool> is_contig_root_;

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 46 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ IndexCompute::IndexCompute(
594594
std::move(extent_map),
595595
std::move(zero_domains),
596596
std::move(zero_merged_in),
597-
ContigIDs({}, {}, {}, {}, {}),
597+
ContigIDs::getNonContigIDs(),
598598
std::move(preferred_paths),
599599
std::move(halo_extent_map)) {}
600600

@@ -2326,103 +2326,71 @@ std::vector<PredicateDomainInfo> getPredicateContigIds(
23262326

23272327
const auto& consumer_root_domain = consumer_tv->getRootDomain();
23282328

2329-
std::vector<IterDomain*> contiguous_ids = consumer_root_domain;
2330-
2331-
if (contiguous_ids.empty()) {
2329+
if (consumer_root_domain.empty()) {
23322330
return std::vector<PredicateDomainInfo>();
23332331
}
23342332

2335-
// If root IDs are partial, i.e., start is non-zero and stop is not
2336-
// equal to extent, predication can't be done with merged domains as
2337-
// start and stop information is only available with root
2338-
// domains. Similarly, merged domains don't have enough information
2339-
// about halo to do correct predication, so they must be excluded.
2340-
std::unordered_set<IterDomain*> excluded_ids;
2333+
std::unordered_map<IterDomain*, Val*> concrete_index_map;
2334+
for (auto entry : consumer_index_map) {
2335+
auto c_id = gpu_lower->caMap()->getConcreteMappedID(
2336+
entry.first, IdMappingMode::EXACT);
2337+
concrete_index_map[c_id] = entry.second;
2338+
}
23412339

2342-
for (auto consumer_root_id : consumer_root_domain) {
2343-
if (gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id).hasHalo()) {
2344-
excluded_ids.insert(consumer_root_id);
2345-
continue;
2346-
}
2347-
if (consumer_root_id->maybePartial()) {
2348-
excluded_ids.insert(consumer_root_id);
2349-
continue;
2350-
}
2351-
// When consumer_root_id is a broadcast domain, do not allow contig
2352-
// predication as the merged output is not mapped with the
2353-
// reference unless the concrete domain is also a broadcast
2354-
// domain.
2355-
if (consumer_root_id->isBroadcast() &&
2356-
!GpuLower::current()
2357-
->caMap()
2358-
->getConcreteMappedID(consumer_root_id, IdMappingMode::PERMISSIVE)
2359-
->isBroadcast()) {
2360-
excluded_ids.insert(consumer_root_id);
2340+
std::vector<bool> predicate_contiguity(consumer_root_domain.size(), true);
2341+
std::unordered_set<IterDomain*> final_ids;
2342+
for (auto root_i : c10::irange(predicate_contiguity.size())) {
2343+
auto root_id = consumer_root_domain[root_i];
2344+
if (root_id->maybePartial()) {
2345+
final_ids.insert(root_id);
23612346
continue;
23622347
}
23632348
// Shifted or gathered axes need to be predicated at the root domain
23642349
auto shift_expr = dynamic_cast<ShiftOp*>(consumer_tv->definition());
23652350
auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition());
2366-
if (shift_expr == nullptr && gather_expr == nullptr) {
2367-
continue;
2368-
}
2369-
auto consumer_root_pos = consumer_tv->domain()->rootPosOf(consumer_root_id);
2370-
if ((shift_expr && shift_expr->offset(consumer_root_pos) != 0) ||
2371-
(gather_expr && consumer_root_pos < gather_expr->windowShape().size() &&
2372-
gather_expr->windowShape().at(consumer_root_pos) != 1)) {
2373-
excluded_ids.insert(consumer_root_id);
2351+
if ((shift_expr && shift_expr->offset(root_i) != 0) ||
2352+
(gather_expr && root_i < gather_expr->windowShape().size() &&
2353+
gather_expr->windowShape().at(root_i) != 1)) {
2354+
final_ids.insert(root_id);
23742355
}
23752356
}
23762357

2377-
// Run through iteration domain history
2378-
auto exprs = StmtSort::getExprs(
2379-
consumer_tv->fusion(),
2380-
{consumer_tv->domain()->domain().begin(),
2381-
consumer_tv->domain()->domain().end()});
2358+
ContigIDs contig_finder(
2359+
consumer_tv->domain()->domain(),
2360+
consumer_root_domain,
2361+
predicate_contiguity,
2362+
final_ids,
2363+
concrete_index_map,
2364+
GpuLower::current()->divisbleSplitSet(),
2365+
GpuLower::current()->caMap(),
2366+
GpuLower::current()->haloInfo(),
2367+
GpuLower::current()->concretizedBroadcastDomains(),
2368+
{},
2369+
false,
2370+
true);
23822371

2383-
for (auto expr : exprs) {
2384-
// If not a merge, output is not contiguous
2385-
if (expr->isA<Merge>()) {
2386-
auto merge = expr->as<Merge>();
2387-
auto inner_contig_it = std::find(
2388-
contiguous_ids.begin(), contiguous_ids.end(), merge->inner());
2389-
auto outer_contig_it = std::find(
2390-
contiguous_ids.begin(), contiguous_ids.end(), merge->outer());
2372+
std::vector<PredicateDomainInfo> contig_id_infos;
2373+
std::unordered_set<IterDomain*> covered_roots;
23912374

2392-
if (excluded_ids.count(merge->inner()) > 0 ||
2393-
excluded_ids.count(merge->outer()) > 0) {
2394-
continue;
2395-
}
2375+
// Create entries and return them
2376+
for (auto root_id : consumer_root_domain) {
2377+
if (covered_roots.count(root_id) > 0) {
2378+
continue;
2379+
}
23962380

2397-
// Do not try to predicate the merge output domain if the output
2398-
// domain has not a predicate that is mapped from the reference.
2399-
// See FusionContigPredicate_CUDA for a concrete example.
2400-
if (consumer_index_map.find(merge->out()) == consumer_index_map.end()) {
2401-
continue;
2402-
}
2381+
auto contig_id_it = contig_finder.rootToIndexedID().find(root_id);
24032382

2404-
if (inner_contig_it != contiguous_ids.end() &&
2405-
outer_contig_it != contiguous_ids.end()) {
2406-
// If inner and outer are contiguous, out must be contiguous. Remove
2407-
// inner and outer, and add out.
2408-
contiguous_ids.erase(outer_contig_it);
2409-
contiguous_ids.erase(std::find(
2410-
contiguous_ids.begin(), contiguous_ids.end(), merge->inner()));
2411-
contiguous_ids.emplace_back(merge->out());
2412-
}
2413-
}
2414-
}
2383+
TORCH_INTERNAL_ASSERT(
2384+
contig_id_it != contig_finder.rootToIndexedID().end(),
2385+
"Error in predicate contiguity analysis, missing index for root ",
2386+
root_id->toString());
24152387

2416-
std::vector<PredicateDomainInfo> contig_id_infos;
2388+
auto contig_id = contig_id_it->second;
24172389

2418-
// Create entries and return them
2419-
for (auto contig_id : contiguous_ids) {
24202390
// Pick inputs from the starting domains, i.e.,
24212391
// reference_predicated_root_domain.
2422-
auto contig_root_vals = IterVisitor::getInputsTo(
2423-
{contig_id},
2424-
{consumer_root_domain.begin(), consumer_root_domain.end()});
2425-
auto contig_root_ids = ir_utils::filterByType<IterDomain>(contig_root_vals);
2392+
auto contig_root_ids = contig_finder.indexedRootIDs(contig_id);
2393+
covered_roots.insert(contig_root_ids.begin(), contig_root_ids.end());
24262394
PredicateDomainInfo contig_id_info;
24272395
contig_id_info.id = contig_id;
24282396
contig_id_info.covered_ids = std::unordered_set<IterDomain*>(

torch/csrc/jit/codegen/cuda/iter_visitor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,10 @@ void IterVisitor::traverseBetween(
214214

215215
void IterVisitor::traverseTo(
216216
Fusion* fusion,
217-
const std::vector<Val*>& To,
217+
const std::vector<Val*>& to,
218218
bool traverse_all_paths,
219219
bool traverse_into_members) {
220-
traverseBetween(fusion, {}, To, traverse_all_paths, traverse_into_members);
220+
traverseBetween(fusion, {}, to, traverse_all_paths, traverse_into_members);
221221
}
222222

223223
void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) {

0 commit comments

Comments
 (0)