@@ -340,6 +340,42 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
340
340
return splitBB (Builder, CreateBranch, Old->getName () + Suffix);
341
341
}
342
342
343
+ // This function creates a fake integer value and a fake use for the integer
344
+ // value. It returns the fake value created. This is useful in modeling the
345
+ // extra arguments to the outlined functions.
346
+ Value *createFakeIntVal (IRBuilder<> &Builder,
347
+ OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
348
+ std::stack<Instruction *> &ToBeDeleted,
349
+ OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
350
+ const Twine &Name = " " , bool AsPtr = true ) {
351
+ Builder.restoreIP (OuterAllocaIP);
352
+ Instruction *FakeVal;
353
+ AllocaInst *FakeValAddr =
354
+ Builder.CreateAlloca (Builder.getInt32Ty (), nullptr , Name + " .addr" );
355
+ ToBeDeleted.push (FakeValAddr);
356
+
357
+ if (AsPtr) {
358
+ FakeVal = FakeValAddr;
359
+ } else {
360
+ FakeVal =
361
+ Builder.CreateLoad (Builder.getInt32Ty (), FakeValAddr, Name + " .val" );
362
+ ToBeDeleted.push (FakeVal);
363
+ }
364
+
365
+ // Generate a fake use of this value
366
+ Builder.restoreIP (InnerAllocaIP);
367
+ Instruction *UseFakeVal;
368
+ if (AsPtr) {
369
+ UseFakeVal =
370
+ Builder.CreateLoad (Builder.getInt32Ty (), FakeVal, Name + " .use" );
371
+ } else {
372
+ UseFakeVal =
373
+ cast<BinaryOperator>(Builder.CreateAdd (FakeVal, Builder.getInt32 (10 )));
374
+ }
375
+ ToBeDeleted.push (UseFakeVal);
376
+ return FakeVal;
377
+ }
378
+
343
379
// ===----------------------------------------------------------------------===//
344
380
// OpenMPIRBuilderConfig
345
381
// ===----------------------------------------------------------------------===//
@@ -1496,6 +1532,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1496
1532
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1497
1533
bool Tied, Value *Final, Value *IfCondition,
1498
1534
SmallVector<DependData> Dependencies) {
1535
+
1499
1536
if (!updateToLocation (Loc))
1500
1537
return InsertPointTy ();
1501
1538
@@ -1523,41 +1560,31 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1523
1560
BasicBlock *TaskAllocaBB =
1524
1561
splitBB (Builder, /* CreateBranch=*/ true , " task.alloca" );
1525
1562
1563
+ InsertPointTy TaskAllocaIP =
1564
+ InsertPointTy (TaskAllocaBB, TaskAllocaBB->begin ());
1565
+ InsertPointTy TaskBodyIP = InsertPointTy (TaskBodyBB, TaskBodyBB->begin ());
1566
+ BodyGenCB (TaskAllocaIP, TaskBodyIP);
1567
+
1526
1568
OutlineInfo OI;
1527
1569
OI.EntryBB = TaskAllocaBB;
1528
1570
OI.OuterAllocaBB = AllocaIP.getBlock ();
1529
1571
OI.ExitBB = TaskExitBB;
1530
- OI.PostOutlineCB = [this , Ident, Tied, Final, IfCondition,
1531
- Dependencies](Function &OutlinedFn) {
1532
- // The input IR here looks like the following-
1533
- // ```
1534
- // func @current_fn() {
1535
- // outlined_fn(%args)
1536
- // }
1537
- // func @outlined_fn(%args) { ... }
1538
- // ```
1539
- //
1540
- // This is changed to the following-
1541
- //
1542
- // ```
1543
- // func @current_fn() {
1544
- // runtime_call(..., wrapper_fn, ...)
1545
- // }
1546
- // func @wrapper_fn(..., %args) {
1547
- // outlined_fn(%args)
1548
- // }
1549
- // func @outlined_fn(%args) { ... }
1550
- // ```
1551
1572
1552
- // The stale call instruction will be replaced with a new call instruction
1553
- // for runtime call with a wrapper function.
1573
+ // Add the thread ID argument.
1574
+ std::stack<Instruction *> ToBeDeleted;
1575
+ OI.ExcludeArgsFromAggregate .push_back (createFakeIntVal (
1576
+ Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, " global.tid" , false ));
1577
+
1578
+ OI.PostOutlineCB = [this , Ident, Tied, Final, IfCondition, Dependencies,
1579
+ TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1580
+ // Replace the Stale CI by appropriate RTL function call.
1554
1581
assert (OutlinedFn.getNumUses () == 1 &&
1555
1582
" there must be a single user for the outlined function" );
1556
1583
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back ());
1557
1584
1558
1585
// HasShareds is true if any variables are captured in the outlined region,
1559
1586
// false otherwise.
1560
- bool HasShareds = StaleCI->arg_size () > 0 ;
1587
+ bool HasShareds = StaleCI->arg_size () > 1 ;
1561
1588
Builder.SetInsertPoint (StaleCI);
1562
1589
1563
1590
// Gather the arguments for emitting the runtime call for
@@ -1595,7 +1622,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1595
1622
Value *SharedsSize = Builder.getInt64 (0 );
1596
1623
if (HasShareds) {
1597
1624
AllocaInst *ArgStructAlloca =
1598
- dyn_cast<AllocaInst>(StaleCI->getArgOperand (0 ));
1625
+ dyn_cast<AllocaInst>(StaleCI->getArgOperand (1 ));
1599
1626
assert (ArgStructAlloca &&
1600
1627
" Unable to find the alloca instruction corresponding to arguments "
1601
1628
" for extracted function" );
@@ -1606,31 +1633,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1606
1633
SharedsSize =
1607
1634
Builder.getInt64 (M.getDataLayout ().getTypeStoreSize (ArgStructType));
1608
1635
}
1609
-
1610
- // Argument - task_entry (the wrapper function)
1611
- // If the outlined function has some captured variables (i.e. HasShareds is
1612
- // true), then the wrapper function will have an additional argument (the
1613
- // struct containing captured variables). Otherwise, no such argument will
1614
- // be present.
1615
- SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty ()};
1616
- if (HasShareds)
1617
- WrapperArgTys.push_back (OutlinedFn.getArg (0 )->getType ());
1618
- FunctionCallee WrapperFuncVal = M.getOrInsertFunction (
1619
- (Twine (OutlinedFn.getName ()) + " .wrapper" ).str (),
1620
- FunctionType::get (Builder.getInt32Ty (), WrapperArgTys, false ));
1621
- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee ());
1622
-
1623
1636
// Emit the @__kmpc_omp_task_alloc runtime call
1624
1637
// The runtime call returns a pointer to an area where the task captured
1625
1638
// variables must be copied before the task is run (TaskData)
1626
1639
CallInst *TaskData = Builder.CreateCall (
1627
1640
TaskAllocFn, {/* loc_ref=*/ Ident, /* gtid=*/ ThreadID, /* flags=*/ Flags,
1628
1641
/* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
1629
- /* task_func=*/ WrapperFunc });
1642
+ /* task_func=*/ &OutlinedFn });
1630
1643
1631
1644
// Copy the arguments for outlined function
1632
1645
if (HasShareds) {
1633
- Value *Shareds = StaleCI->getArgOperand (0 );
1646
+ Value *Shareds = StaleCI->getArgOperand (1 );
1634
1647
Align Alignment = TaskData->getPointerAlignment (M.getDataLayout ());
1635
1648
Value *TaskShareds = Builder.CreateLoad (VoidPtr, TaskData);
1636
1649
Builder.CreateMemCpy (TaskShareds, Alignment, Shareds, Alignment,
@@ -1689,18 +1702,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1689
1702
// br label %exit
1690
1703
// else:
1691
1704
// call @__kmpc_omp_task_begin_if0(...)
1692
- // call @wrapper_fn (...)
1705
+ // call @outlined_fn (...)
1693
1706
// call @__kmpc_omp_task_complete_if0(...)
1694
1707
// br label %exit
1695
1708
// exit:
1696
1709
// ...
1697
1710
if (IfCondition) {
1698
1711
// `SplitBlockAndInsertIfThenElse` requires the block to have a
1699
1712
// terminator.
1700
- BasicBlock *NewBasicBlock =
1701
- splitBB (Builder, /* CreateBranch=*/ true , " if.end" );
1713
+ splitBB (Builder, /* CreateBranch=*/ true , " if.end" );
1702
1714
Instruction *IfTerminator =
1703
- NewBasicBlock-> getSinglePredecessor ()->getTerminator ();
1715
+ Builder. GetInsertPoint ()-> getParent ()->getTerminator ();
1704
1716
Instruction *ThenTI = IfTerminator, *ElseTI = nullptr ;
1705
1717
Builder.SetInsertPoint (IfTerminator);
1706
1718
SplitBlockAndInsertIfThenElse (IfCondition, IfTerminator, &ThenTI,
@@ -1711,10 +1723,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1711
1723
Function *TaskCompleteFn =
1712
1724
getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_complete_if0);
1713
1725
Builder.CreateCall (TaskBeginFn, {Ident, ThreadID, TaskData});
1726
+ CallInst *CI = nullptr ;
1714
1727
if (HasShareds)
1715
- Builder.CreateCall (WrapperFunc , {ThreadID, TaskData});
1728
+ CI = Builder.CreateCall (&OutlinedFn , {ThreadID, TaskData});
1716
1729
else
1717
- Builder.CreateCall (WrapperFunc, {ThreadID});
1730
+ CI = Builder.CreateCall (&OutlinedFn, {ThreadID});
1731
+ CI->setDebugLoc (StaleCI->getDebugLoc ());
1718
1732
Builder.CreateCall (TaskCompleteFn, {Ident, ThreadID, TaskData});
1719
1733
Builder.SetInsertPoint (ThenTI);
1720
1734
}
@@ -1736,26 +1750,20 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1736
1750
1737
1751
StaleCI->eraseFromParent ();
1738
1752
1739
- // Emit the body for wrapper function
1740
- BasicBlock *WrapperEntryBB =
1741
- BasicBlock::Create (M.getContext (), " " , WrapperFunc);
1742
- Builder.SetInsertPoint (WrapperEntryBB);
1753
+ Builder.SetInsertPoint (TaskAllocaBB, TaskAllocaBB->begin ());
1743
1754
if (HasShareds) {
1744
- llvm::Value *Shareds =
1745
- Builder.CreateLoad (VoidPtr, WrapperFunc->getArg (1 ));
1746
- Builder.CreateCall (&OutlinedFn, {Shareds});
1747
- } else {
1748
- Builder.CreateCall (&OutlinedFn);
1755
+ LoadInst *Shareds = Builder.CreateLoad (VoidPtr, OutlinedFn.getArg (1 ));
1756
+ OutlinedFn.getArg (1 )->replaceUsesWithIf (
1757
+ Shareds, [Shareds](Use &U) { return U.getUser () != Shareds; });
1758
+ }
1759
+
1760
+ while (!ToBeDeleted.empty ()) {
1761
+ ToBeDeleted.top ()->eraseFromParent ();
1762
+ ToBeDeleted.pop ();
1749
1763
}
1750
- Builder.CreateRet (Builder.getInt32 (0 ));
1751
1764
};
1752
1765
1753
1766
addOutlineInfo (std::move (OI));
1754
-
1755
- InsertPointTy TaskAllocaIP =
1756
- InsertPointTy (TaskAllocaBB, TaskAllocaBB->begin ());
1757
- InsertPointTy TaskBodyIP = InsertPointTy (TaskBodyBB, TaskBodyBB->begin ());
1758
- BodyGenCB (TaskAllocaIP, TaskBodyIP);
1759
1767
Builder.SetInsertPoint (TaskExitBB, TaskExitBB->begin ());
1760
1768
1761
1769
return Builder.saveIP ();
@@ -5763,84 +5771,63 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
5763
5771
BasicBlock *AllocaBB =
5764
5772
splitBB (Builder, /* CreateBranch=*/ true , " teams.alloca" );
5765
5773
5774
+ // Generate the body of teams.
5775
+ InsertPointTy AllocaIP (AllocaBB, AllocaBB->begin ());
5776
+ InsertPointTy CodeGenIP (BodyBB, BodyBB->begin ());
5777
+ BodyGenCB (AllocaIP, CodeGenIP);
5778
+
5766
5779
OutlineInfo OI;
5767
5780
OI.EntryBB = AllocaBB;
5768
5781
OI.ExitBB = ExitBB;
5769
5782
OI.OuterAllocaBB = &OuterAllocaBB;
5770
- OI.PostOutlineCB = [this , Ident](Function &OutlinedFn) {
5771
- // The input IR here looks like the following-
5772
- // ```
5773
- // func @current_fn() {
5774
- // outlined_fn(%args)
5775
- // }
5776
- // func @outlined_fn(%args) { ... }
5777
- // ```
5778
- //
5779
- // This is changed to the following-
5780
- //
5781
- // ```
5782
- // func @current_fn() {
5783
- // runtime_call(..., wrapper_fn, ...)
5784
- // }
5785
- // func @wrapper_fn(..., %args) {
5786
- // outlined_fn(%args)
5787
- // }
5788
- // func @outlined_fn(%args) { ... }
5789
- // ```
5790
5783
5784
+ // Insert fake values for global tid and bound tid.
5785
+ std::stack<Instruction *> ToBeDeleted;
5786
+ InsertPointTy OuterAllocaIP (&OuterAllocaBB, OuterAllocaBB.begin ());
5787
+ OI.ExcludeArgsFromAggregate .push_back (createFakeIntVal (
5788
+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, " gid" , true ));
5789
+ OI.ExcludeArgsFromAggregate .push_back (createFakeIntVal (
5790
+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, " tid" , true ));
5791
+
5792
+ OI.PostOutlineCB = [this , Ident, ToBeDeleted](Function &OutlinedFn) mutable {
5791
5793
// The stale call instruction will be replaced with a new call instruction
5792
- // for runtime call with a wrapper function.
5794
+ // for runtime call with the outlined function.
5793
5795
5794
5796
assert (OutlinedFn.getNumUses () == 1 &&
5795
5797
" there must be a single user for the outlined function" );
5796
5798
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back ());
5799
+ ToBeDeleted.push (StaleCI);
5800
+
5801
+ assert ((OutlinedFn.arg_size () == 2 || OutlinedFn.arg_size () == 3 ) &&
5802
+ " Outlined function must have two or three arguments only" );
5803
+
5804
+ bool HasShared = OutlinedFn.arg_size () == 3 ;
5797
5805
5798
- // Create the wrapper function.
5799
- SmallVector<Type *> WrapperArgTys{Builder.getPtrTy (), Builder.getPtrTy ()};
5800
- for (auto &Arg : OutlinedFn.args ())
5801
- WrapperArgTys.push_back (Arg.getType ());
5802
- FunctionCallee WrapperFuncVal = M.getOrInsertFunction (
5803
- (Twine (OutlinedFn.getName ()) + " .teams" ).str (),
5804
- FunctionType::get (Builder.getVoidTy (), WrapperArgTys, false ));
5805
- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee ());
5806
- WrapperFunc->getArg (0 )->setName (" global_tid" );
5807
- WrapperFunc->getArg (1 )->setName (" bound_tid" );
5808
- if (WrapperFunc->arg_size () > 2 )
5809
- WrapperFunc->getArg (2 )->setName (" data" );
5810
-
5811
- // Emit the body of the wrapper function - just a call to outlined function
5812
- // and return statement.
5813
- BasicBlock *WrapperEntryBB =
5814
- BasicBlock::Create (M.getContext (), " entrybb" , WrapperFunc);
5815
- Builder.SetInsertPoint (WrapperEntryBB);
5816
- SmallVector<Value *> Args;
5817
- for (size_t ArgIndex = 2 ; ArgIndex < WrapperFunc->arg_size (); ArgIndex++)
5818
- Args.push_back (WrapperFunc->getArg (ArgIndex));
5819
- Builder.CreateCall (&OutlinedFn, Args);
5820
- Builder.CreateRetVoid ();
5821
-
5822
- OutlinedFn.addFnAttr (Attribute::AttrKind::AlwaysInline);
5806
+ OutlinedFn.getArg (0 )->setName (" global.tid.ptr" );
5807
+ OutlinedFn.getArg (1 )->setName (" bound.tid.ptr" );
5808
+ if (HasShared)
5809
+ OutlinedFn.getArg (2 )->setName (" data" );
5823
5810
5824
5811
// Call to the runtime function for teams in the current function.
5825
5812
assert (StaleCI && " Error while outlining - no CallInst user found for the "
5826
5813
" outlined function." );
5827
5814
Builder.SetInsertPoint (StaleCI);
5828
- Args = {Ident, Builder.getInt32 (StaleCI->arg_size ()), WrapperFunc};
5829
- for (Use &Arg : StaleCI->args ())
5830
- Args.push_back (Arg);
5815
+ SmallVector<Value *> Args = {
5816
+ Ident, Builder.getInt32 (StaleCI->arg_size () - 2 ), &OutlinedFn};
5817
+ if (HasShared)
5818
+ Args.push_back (StaleCI->getArgOperand (2 ));
5831
5819
Builder.CreateCall (getOrCreateRuntimeFunctionPtr (
5832
5820
omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
5833
5821
Args);
5834
- StaleCI->eraseFromParent ();
5822
+
5823
+ while (!ToBeDeleted.empty ()) {
5824
+ ToBeDeleted.top ()->eraseFromParent ();
5825
+ ToBeDeleted.pop ();
5826
+ }
5835
5827
};
5836
5828
5837
5829
addOutlineInfo (std::move (OI));
5838
5830
5839
- // Generate the body of teams.
5840
- InsertPointTy AllocaIP (AllocaBB, AllocaBB->begin ());
5841
- InsertPointTy CodeGenIP (BodyBB, BodyBB->begin ());
5842
- BodyGenCB (AllocaIP, CodeGenIP);
5843
-
5844
5831
Builder.SetInsertPoint (ExitBB, ExitBB->begin ());
5845
5832
5846
5833
return Builder.saveIP ();
0 commit comments