@@ -594,7 +594,7 @@ IndexCompute::IndexCompute(
594
594
std::move(extent_map),
595
595
std::move(zero_domains),
596
596
std::move(zero_merged_in),
597
- ContigIDs({}, {}, {}, {}, {} ),
597
+ ContigIDs::getNonContigIDs( ),
598
598
std::move(preferred_paths),
599
599
std::move(halo_extent_map)) {}
600
600
@@ -2326,103 +2326,71 @@ std::vector<PredicateDomainInfo> getPredicateContigIds(
2326
2326
2327
2327
const auto & consumer_root_domain = consumer_tv->getRootDomain ();
2328
2328
2329
- std::vector<IterDomain*> contiguous_ids = consumer_root_domain;
2330
-
2331
- if (contiguous_ids.empty ()) {
2329
+ if (consumer_root_domain.empty ()) {
2332
2330
return std::vector<PredicateDomainInfo>();
2333
2331
}
2334
2332
2335
- // If root IDs are partial, i.e., start is non-zero and stop is not
2336
- // equal to extent, predication can't be done with merged domains as
2337
- // start and stop information is only available with root
2338
- // domains. Similarly, merged domains don't have enough information
2339
- // about halo to do correct predication, so they must be excluded.
2340
- std::unordered_set<IterDomain*> excluded_ids;
2333
+ std::unordered_map<IterDomain*, Val*> concrete_index_map;
2334
+ for ( auto entry : consumer_index_map) {
2335
+ auto c_id = gpu_lower-> caMap ()-> getConcreteMappedID (
2336
+ entry. first , IdMappingMode::EXACT);
2337
+ concrete_index_map[c_id] = entry. second ;
2338
+ }
2341
2339
2342
- for (auto consumer_root_id : consumer_root_domain) {
2343
- if (gpu_lower->haloInfo ()->getRootAxisInfo (consumer_root_id).hasHalo ()) {
2344
- excluded_ids.insert (consumer_root_id);
2345
- continue ;
2346
- }
2347
- if (consumer_root_id->maybePartial ()) {
2348
- excluded_ids.insert (consumer_root_id);
2349
- continue ;
2350
- }
2351
- // When consumer_root_id is a broadcast domain, do not allow contig
2352
- // predication as the merged output is not mapped with the
2353
- // reference unless the concrete domain is also a broadcast
2354
- // domain.
2355
- if (consumer_root_id->isBroadcast () &&
2356
- !GpuLower::current ()
2357
- ->caMap ()
2358
- ->getConcreteMappedID (consumer_root_id, IdMappingMode::PERMISSIVE)
2359
- ->isBroadcast ()) {
2360
- excluded_ids.insert (consumer_root_id);
2340
+ std::vector<bool > predicate_contiguity (consumer_root_domain.size (), true );
2341
+ std::unordered_set<IterDomain*> final_ids;
2342
+ for (auto root_i : c10::irange (predicate_contiguity.size ())) {
2343
+ auto root_id = consumer_root_domain[root_i];
2344
+ if (root_id->maybePartial ()) {
2345
+ final_ids.insert (root_id);
2361
2346
continue ;
2362
2347
}
2363
2348
// Shifted or gathered axes need to be predicated at the root domain
2364
2349
auto shift_expr = dynamic_cast <ShiftOp*>(consumer_tv->definition ());
2365
2350
auto gather_expr = dynamic_cast <GatherOp*>(consumer_tv->definition ());
2366
- if (shift_expr == nullptr && gather_expr == nullptr ) {
2367
- continue ;
2368
- }
2369
- auto consumer_root_pos = consumer_tv->domain ()->rootPosOf (consumer_root_id);
2370
- if ((shift_expr && shift_expr->offset (consumer_root_pos) != 0 ) ||
2371
- (gather_expr && consumer_root_pos < gather_expr->windowShape ().size () &&
2372
- gather_expr->windowShape ().at (consumer_root_pos) != 1 )) {
2373
- excluded_ids.insert (consumer_root_id);
2351
+ if ((shift_expr && shift_expr->offset (root_i) != 0 ) ||
2352
+ (gather_expr && root_i < gather_expr->windowShape ().size () &&
2353
+ gather_expr->windowShape ().at (root_i) != 1 )) {
2354
+ final_ids.insert (root_id);
2374
2355
}
2375
2356
}
2376
2357
2377
- // Run through iteration domain history
2378
- auto exprs = StmtSort::getExprs (
2379
- consumer_tv->fusion (),
2380
- {consumer_tv->domain ()->domain ().begin (),
2381
- consumer_tv->domain ()->domain ().end ()});
2358
+ ContigIDs contig_finder (
2359
+ consumer_tv->domain ()->domain (),
2360
+ consumer_root_domain,
2361
+ predicate_contiguity,
2362
+ final_ids,
2363
+ concrete_index_map,
2364
+ GpuLower::current ()->divisbleSplitSet (),
2365
+ GpuLower::current ()->caMap (),
2366
+ GpuLower::current ()->haloInfo (),
2367
+ GpuLower::current ()->concretizedBroadcastDomains (),
2368
+ {},
2369
+ false ,
2370
+ true );
2382
2371
2383
- for (auto expr : exprs) {
2384
- // If not a merge, output is not contiguous
2385
- if (expr->isA <Merge>()) {
2386
- auto merge = expr->as <Merge>();
2387
- auto inner_contig_it = std::find (
2388
- contiguous_ids.begin (), contiguous_ids.end (), merge->inner ());
2389
- auto outer_contig_it = std::find (
2390
- contiguous_ids.begin (), contiguous_ids.end (), merge->outer ());
2372
+ std::vector<PredicateDomainInfo> contig_id_infos;
2373
+ std::unordered_set<IterDomain*> covered_roots;
2391
2374
2392
- if (excluded_ids.count (merge->inner ()) > 0 ||
2393
- excluded_ids.count (merge->outer ()) > 0 ) {
2394
- continue ;
2395
- }
2375
+ // Create entries and return them
2376
+ for (auto root_id : consumer_root_domain) {
2377
+ if (covered_roots.count (root_id) > 0 ) {
2378
+ continue ;
2379
+ }
2396
2380
2397
- // Do not try to predicate the merge output domain if the output
2398
- // domain has not a predicate that is mapped from the reference.
2399
- // See FusionContigPredicate_CUDA for a concrete example.
2400
- if (consumer_index_map.find (merge->out ()) == consumer_index_map.end ()) {
2401
- continue ;
2402
- }
2381
+ auto contig_id_it = contig_finder.rootToIndexedID ().find (root_id);
2403
2382
2404
- if (inner_contig_it != contiguous_ids.end () &&
2405
- outer_contig_it != contiguous_ids.end ()) {
2406
- // If inner and outer are contiguous, out must be contiguous. Remove
2407
- // inner and outer, and add out.
2408
- contiguous_ids.erase (outer_contig_it);
2409
- contiguous_ids.erase (std::find (
2410
- contiguous_ids.begin (), contiguous_ids.end (), merge->inner ()));
2411
- contiguous_ids.emplace_back (merge->out ());
2412
- }
2413
- }
2414
- }
2383
+ TORCH_INTERNAL_ASSERT (
2384
+ contig_id_it != contig_finder.rootToIndexedID ().end (),
2385
+ " Error in predicate contiguity analysis, missing index for root " ,
2386
+ root_id->toString ());
2415
2387
2416
- std::vector<PredicateDomainInfo> contig_id_infos ;
2388
+ auto contig_id = contig_id_it-> second ;
2417
2389
2418
- // Create entries and return them
2419
- for (auto contig_id : contiguous_ids) {
2420
2390
// Pick inputs from the starting domains, i.e.,
2421
2391
// reference_predicated_root_domain.
2422
- auto contig_root_vals = IterVisitor::getInputsTo (
2423
- {contig_id},
2424
- {consumer_root_domain.begin (), consumer_root_domain.end ()});
2425
- auto contig_root_ids = ir_utils::filterByType<IterDomain>(contig_root_vals);
2392
+ auto contig_root_ids = contig_finder.indexedRootIDs (contig_id);
2393
+ covered_roots.insert (contig_root_ids.begin (), contig_root_ids.end ());
2426
2394
PredicateDomainInfo contig_id_info;
2427
2395
contig_id_info.id = contig_id;
2428
2396
contig_id_info.covered_ids = std::unordered_set<IterDomain*>(
0 commit comments