Skip to content

[Libomptarget] Rework Record & Replay to be a plugin member (#88928) #89097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion offload/plugins-nextgen/common/include/PluginInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ namespace plugin {
struct GenericPluginTy;
struct GenericKernelTy;
struct GenericDeviceTy;
struct RecordReplayTy;

/// Class that wraps the __tgt_async_info to simplify its usage. In case the
/// object is constructed without a valid __tgt_async_info, the object will use
Expand Down Expand Up @@ -958,7 +959,8 @@ struct GenericPluginTy {

/// Construct a plugin instance.
GenericPluginTy(Triple::ArchType TA)
: GlobalHandler(nullptr), JIT(TA), RPCServer(nullptr) {}
: GlobalHandler(nullptr), JIT(TA), RPCServer(nullptr),
RecordReplay(nullptr) {}

virtual ~GenericPluginTy() {}

Expand Down Expand Up @@ -1027,6 +1029,12 @@ struct GenericPluginTy {
return *RPCServer;
}

/// Get a reference to the record and replay interface for the plugin. The
/// interface pointer starts out null in the constructor and is allocated in
/// GenericPluginTy::init(), so this accessor must only be used once the plugin
/// has been initialized; the assertion below guards against earlier use.
RecordReplayTy &getRecordReplay() {
assert(RecordReplay && "RR interface not initialized");
return *RecordReplay;
}

/// Initialize a device within the plugin.
Error initDevice(int32_t DeviceId);

Expand Down Expand Up @@ -1204,6 +1212,9 @@ struct GenericPluginTy {

/// The interface between the plugin and the GPU for host services.
RPCServerTy *RPCServer;

/// The interface to the record and replay functionality of the plugin.
RecordReplayTy *RecordReplay;
};

namespace Plugin {
Expand Down
34 changes: 23 additions & 11 deletions offload/plugins-nextgen/common/src/PluginInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ using namespace target;
using namespace plugin;

// TODO: Fix any thread safety issues for multi-threaded kernel recording.
namespace llvm::omp::target::plugin {
struct RecordReplayTy {

// Describes the state of the record replay mechanism.
Expand Down Expand Up @@ -358,8 +359,7 @@ struct RecordReplayTy {
}
}
};

static RecordReplayTy RecordReplay;
} // namespace llvm::omp::target::plugin

// Extract the mapping of host function pointers to device function pointers
// from the entry table. Functions marked as 'indirect' in OpenMP will have
Expand Down Expand Up @@ -470,7 +470,7 @@ GenericKernelTy::getKernelLaunchEnvironment(
// Ctor/Dtor have no arguments, replaying uses the original kernel launch
// environment. Older versions of the compiler do not generate a kernel
// launch environment.
if (RecordReplay.isReplaying() ||
if (GenericDevice.Plugin.getRecordReplay().isReplaying() ||
Version < OMP_KERNEL_ARG_MIN_VERSION_WITH_DYN_PTR)
return nullptr;

Expand Down Expand Up @@ -559,6 +559,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,

// Record the kernel description after we modified the argument count and num
// blocks/threads.
RecordReplayTy &RecordReplay = GenericDevice.Plugin.getRecordReplay();
if (RecordReplay.isRecording()) {
RecordReplay.saveImage(getName(), getImage());
RecordReplay.saveKernelInput(getName(), getImage());
Expand Down Expand Up @@ -833,6 +834,7 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
delete MemoryManager;
MemoryManager = nullptr;

RecordReplayTy &RecordReplay = Plugin.getRecordReplay();
if (RecordReplay.isRecordingOrReplaying())
RecordReplay.deinit();

Expand Down Expand Up @@ -886,7 +888,8 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
return std::move(Err);

// Setup the global device memory pool if needed.
if (!RecordReplay.isReplaying() && shouldSetupDeviceMemoryPool()) {
if (!Plugin.getRecordReplay().isReplaying() &&
shouldSetupDeviceMemoryPool()) {
uint64_t HeapSize;
auto SizeOrErr = getDeviceHeapSize(HeapSize);
if (SizeOrErr) {
Expand Down Expand Up @@ -1301,8 +1304,8 @@ Expected<void *> GenericDeviceTy::dataAlloc(int64_t Size, void *HostPtr,
TargetAllocTy Kind) {
void *Alloc = nullptr;

if (RecordReplay.isRecordingOrReplaying())
return RecordReplay.alloc(Size);
if (Plugin.getRecordReplay().isRecordingOrReplaying())
return Plugin.getRecordReplay().alloc(Size);

switch (Kind) {
case TARGET_ALLOC_DEFAULT:
Expand Down Expand Up @@ -1338,7 +1341,7 @@ Expected<void *> GenericDeviceTy::dataAlloc(int64_t Size, void *HostPtr,

Error GenericDeviceTy::dataDelete(void *TgtPtr, TargetAllocTy Kind) {
// Free is a noop when recording or replaying.
if (RecordReplay.isRecordingOrReplaying())
if (Plugin.getRecordReplay().isRecordingOrReplaying())
return Plugin::success();

int Res;
Expand Down Expand Up @@ -1405,7 +1408,8 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs,
KernelArgsTy &KernelArgs,
__tgt_async_info *AsyncInfo) {
AsyncInfoWrapperTy AsyncInfoWrapper(
*this, RecordReplay.isRecordingOrReplaying() ? nullptr : AsyncInfo);
*this,
Plugin.getRecordReplay().isRecordingOrReplaying() ? nullptr : AsyncInfo);

GenericKernelTy &GenericKernel =
*reinterpret_cast<GenericKernelTy *>(EntryPtr);
Expand All @@ -1416,6 +1420,7 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs,
// 'finalize' here to guarantee next record-replay actions are in-sync
AsyncInfoWrapper.finalize(Err);

RecordReplayTy &RecordReplay = Plugin.getRecordReplay();
if (RecordReplay.isRecordingOrReplaying() &&
RecordReplay.isSaveOutputEnabled())
RecordReplay.saveKernelOutputInfo(GenericKernel.getName());
Expand Down Expand Up @@ -1503,6 +1508,9 @@ Error GenericPluginTy::init() {
RPCServer = new RPCServerTy(*this);
assert(RPCServer && "Invalid RPC server");

RecordReplay = new RecordReplayTy();
assert(RecordReplay && "Invalid RR interface");

return Plugin::success();
}

Expand All @@ -1523,6 +1531,9 @@ Error GenericPluginTy::deinit() {
if (RPCServer)
delete RPCServer;

if (RecordReplay)
delete RecordReplay;

// Perform last deinitializations on the plugin.
return deinitImpl();
}
Expand Down Expand Up @@ -1633,12 +1644,12 @@ int32_t GenericPluginTy::initialize_record_replay(int32_t DeviceId,
isRecord ? RecordReplayTy::RRStatusTy::RRRecording
: RecordReplayTy::RRStatusTy::RRReplaying;

if (auto Err = RecordReplay.init(&Device, MemorySize, VAddr, Status,
SaveOutput, ReqPtrArgOffset)) {
if (auto Err = RecordReplay->init(&Device, MemorySize, VAddr, Status,
SaveOutput, ReqPtrArgOffset)) {
REPORT("WARNING RR did not intialize RR-properly with %lu bytes"
"(Error: %s)\n",
MemorySize, toString(std::move(Err)).data());
RecordReplay.setStatus(RecordReplayTy::RRStatusTy::RRDeactivated);
RecordReplay->setStatus(RecordReplayTy::RRStatusTy::RRDeactivated);

if (!isRecord) {
return OFFLOAD_FAIL;
Expand Down Expand Up @@ -1982,6 +1993,7 @@ int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size,
assert(DevicePtr && "Invalid device global's address");

// Save the loaded globals if we are recording.
RecordReplayTy &RecordReplay = Device.Plugin.getRecordReplay();
if (RecordReplay.isRecording())
RecordReplay.addEntry(Name, Size, *DevicePtr);

Expand Down
Loading