diff --git a/comms/ctran/backends/ib/CtranIb.cc b/comms/ctran/backends/ib/CtranIb.cc index 3c5807a2..6bdd77da 100644 --- a/comms/ctran/backends/ib/CtranIb.cc +++ b/comms/ctran/backends/ib/CtranIb.cc @@ -67,7 +67,7 @@ CtranIbSingleton::CtranIbSingleton() { auto ibvInitResult = ibverbx::ibvInit(); FOLLY_EXPECTED_CHECKTHROW(ibvInitResult); auto maybeDeviceList = ibverbx::IbvDevice::ibvGetDeviceList( - NCCL_IB_HCA, NCCL_IB_HCA_PREFIX, CTRAN_IB_ANY_PORT); + NCCL_IB_HCA, NCCL_IB_HCA_PREFIX, CTRAN_IB_ANY_PORT, NCCL_IB_DATA_DIRECT); FOLLY_EXPECTED_CHECKTHROW(maybeDeviceList); ibvDevices = std::move(*maybeDeviceList); diff --git a/comms/ctran/backends/ib/IbvWrap.cc b/comms/ctran/backends/ib/IbvWrap.cc index 7cbe3d60..02bd9860 100644 --- a/comms/ctran/backends/ib/IbvWrap.cc +++ b/comms/ctran/backends/ib/IbvWrap.cc @@ -3,7 +3,7 @@ #include #include "comms/ctran/backends/ib/IbvWrap.h" -#include "comms/ctran/ibverbx/Ibverbx.h" +#include "comms/ctran/ibverbx/IbverbxSymbols.h" #include "comms/utils/logger/LogUtils.h" #include "comms/ctran/utils/Checks.h" @@ -128,8 +128,9 @@ commResult_t wrap_ibv_get_device_list( struct ibv_device*** ret, int* num_devices) { *ret = ibvSymbols.ibv_internal_get_device_list(num_devices); - if (*ret == nullptr) + if (*ret == nullptr) { *num_devices = 0; + } return commSuccess; } @@ -480,8 +481,9 @@ static void ibvModifyQpLog( remoteGidRes = ibvGetGidStr(remoteGid, remoteGidName, sizeof(remoteGidName)); // we need pd->context to retrieve local GID, skip if not there - if (!qp->pd->context) + if (!qp->pd->context) { goto print; + } gidIndex = avAttr->ah_attr.grh.sgid_index; union ibv_gid localGid; FB_COMMCHECKGOTO( diff --git a/comms/ctran/ibverbx/Coordinator.cc b/comms/ctran/ibverbx/Coordinator.cc new file mode 100644 index 00000000..271ff8f5 --- /dev/null +++ b/comms/ctran/ibverbx/Coordinator.cc @@ -0,0 +1,179 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +#include "comms/ctran/ibverbx/Coordinator.h" + +#include +#include "comms/ctran/ibverbx/IbvVirtualQp.h" + +namespace ibverbx { + +namespace { +folly::Singleton coordinatorSingleton{}; +} + +/*** Coordinator ***/ + +std::shared_ptr Coordinator::getCoordinator() { + return coordinatorSingleton.try_get(); +} + +// Register APIs for mapping management +void Coordinator::registerVirtualQp( + uint32_t virtualQpNum, + IbvVirtualQp* virtualQp) { + virtualQpNumToVirtualQp_[virtualQpNum] = virtualQp; +} + +void Coordinator::registerVirtualCq( + uint32_t virtualCqNum, + IbvVirtualCq* virtualCq) { + virtualCqNumToVirtualCq_[virtualCqNum] = virtualCq; +} + +void Coordinator::registerPhysicalQpToVirtualQp( + int physicalQpNum, + uint32_t virtualQpNum) { + physicalQpNumToVirtualQpNum_[physicalQpNum] = virtualQpNum; +} + +void Coordinator::registerVirtualQpToVirtualSendCq( + uint32_t virtualQpNum, + uint32_t virtualSendCqNum) { + virtualQpNumToVirtualSendCqNum_[virtualQpNum] = virtualSendCqNum; +} + +void Coordinator::registerVirtualQpToVirtualRecvCq( + uint32_t virtualQpNum, + uint32_t virtualRecvCqNum) { + virtualQpNumToVirtualRecvCqNum_[virtualQpNum] = virtualRecvCqNum; +} + +void Coordinator::registerVirtualQpWithVirtualCqMappings( + IbvVirtualQp* virtualQp, + uint32_t virtualSendCqNum, + uint32_t virtualRecvCqNum) { + // Extract virtual QP number from the virtual QP object + uint32_t virtualQpNum = virtualQp->getVirtualQpNum(); + + // Register the virtual QP + registerVirtualQp(virtualQpNum, virtualQp); + + // Register all physical QP to virtual QP mappings + for (const auto& qp : virtualQp->getQpsRef()) { + registerPhysicalQpToVirtualQp(qp.qp()->qp_num, virtualQpNum); + } + // Register notify QP + registerPhysicalQpToVirtualQp( + virtualQp->getNotifyQpRef().qp()->qp_num, virtualQpNum); + + // Register virtual QP to virtual CQ relationships + registerVirtualQpToVirtualSendCq(virtualQpNum, virtualSendCqNum); + registerVirtualQpToVirtualRecvCq(virtualQpNum, virtualRecvCqNum); +} + +// Access APIs for testing and internal use +const std::unordered_map& +Coordinator::getVirtualQpMap() const { + return virtualQpNumToVirtualQp_; +} + +const std::unordered_map& +Coordinator::getVirtualCqMap() const { + return virtualCqNumToVirtualCq_; +} + +const std::unordered_map& +Coordinator::getPhysicalQpToVirtualQpMap() const { + return physicalQpNumToVirtualQpNum_; +} + +const std::unordered_map& +Coordinator::getVirtualQpToVirtualSendCqMap() const { + return virtualQpNumToVirtualSendCqNum_; +} + +const std::unordered_map& +Coordinator::getVirtualQpToVirtualRecvCqMap() const { + return virtualQpNumToVirtualRecvCqNum_; +} + +// Update API for move operations - only need to update pointer maps +void Coordinator::updateVirtualQpPointer( + uint32_t virtualQpNum, + IbvVirtualQp* newPtr) { + virtualQpNumToVirtualQp_[virtualQpNum] = newPtr; +} + +void Coordinator::updateVirtualCqPointer( + uint32_t virtualCqNum, + IbvVirtualCq* newPtr) { + virtualCqNumToVirtualCq_[virtualCqNum] = newPtr; +} + +void Coordinator::unregisterVirtualQp( + uint32_t virtualQpNum, + IbvVirtualQp* ptr) { + // Only unregister if the pointer in the map matches the object being + // destroyed. This handles the case where the object was moved and the map was + // already updated with the new pointer. 
+ auto it = virtualQpNumToVirtualQp_.find(virtualQpNum); + if (it == virtualQpNumToVirtualQp_.end() || it->second != ptr) { + // Object was moved, map already updated, nothing to do + return; + } + + // Remove entries from all maps related to this virtual QP + virtualQpNumToVirtualQp_.erase(virtualQpNum); + virtualQpNumToVirtualSendCqNum_.erase(virtualQpNum); + virtualQpNumToVirtualRecvCqNum_.erase(virtualQpNum); + + // Remove all physical QP to virtual QP mappings that point to this virtual QP + for (auto it = physicalQpNumToVirtualQpNum_.begin(); + it != physicalQpNumToVirtualQpNum_.end();) { + if (it->second == virtualQpNum) { + it = physicalQpNumToVirtualQpNum_.erase(it); + } else { + ++it; + } + } +} + +void Coordinator::unregisterVirtualCq( + uint32_t virtualCqNum, + IbvVirtualCq* ptr) { + // Only unregister if the pointer in the map matches the object being + // destroyed. This handles the case where the object was moved and the map was + // already updated with the new pointer. + auto it = virtualCqNumToVirtualCq_.find(virtualCqNum); + if (it == virtualCqNumToVirtualCq_.end() || it->second != ptr) { + // Object was moved, map already updated, nothing to do + return; + } + + // Remove the virtual CQ from the pointer map + virtualCqNumToVirtualCq_.erase(virtualCqNum); + + // Remove all virtual QP to virtual send CQ mappings that point to this + // virtual CQ + for (auto it = virtualQpNumToVirtualSendCqNum_.begin(); + it != virtualQpNumToVirtualSendCqNum_.end();) { + if (it->second == virtualCqNum) { + it = virtualQpNumToVirtualSendCqNum_.erase(it); + } else { + ++it; + } + } + + // Remove all virtual QP to virtual recv CQ mappings that point to this + // virtual CQ + for (auto it = virtualQpNumToVirtualRecvCqNum_.begin(); + it != virtualQpNumToVirtualRecvCqNum_.end();) { + if (it->second == virtualCqNum) { + it = virtualQpNumToVirtualRecvCqNum_.erase(it); + } else { + ++it; + } + } +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/Coordinator.h b/comms/ctran/ibverbx/Coordinator.h new file mode 100644 index 00000000..06dd3166 --- /dev/null +++ b/comms/ctran/ibverbx/Coordinator.h @@ -0,0 +1,148 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include + +#include +#include "comms/ctran/ibverbx/IbvCommon.h" + +namespace ibverbx { + +class IbvVirtualQp; +class IbvVirtualCq; + +// Coordinator class responsible for routing commands and responses between +// IbvVirtualQp and IbvVirtualCq. Maintains mappings from physical QP numbers to +// IbvVirtualQp pointers, and from virtual CQ numbers to IbvVirtualCq pointers. +// Acts as a router to forward requests between these two classes. +// +// NOTE: The Coordinator APIs are NOT thread-safe. Users must ensure proper +// synchronization when accessing Coordinator methods from multiple threads. +// Thread-safe support can be added in the future if needed. 
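A minimal, illustrative usage sketch of the routing these maps provide, assuming the virtual QP/CQ objects were already registered via registerVirtualQpWithVirtualCqMappings(); routeCompletion and physicalQpNum are hypothetical names, not part of this patch:

#include "comms/ctran/ibverbx/Coordinator.h"
#include "comms/ctran/ibverbx/IbvVirtualQp.h"

void routeCompletion(int physicalQpNum) {
  auto coordinator = ibverbx::Coordinator::getCoordinator();
  if (!coordinator) {
    return; // singleton no longer available
  }
  // Physical QP number -> owning virtual QP.
  ibverbx::IbvVirtualQp* vqp =
      coordinator->getVirtualQpByPhysicalQpNum(physicalQpNum);
  if (vqp == nullptr) {
    return; // unknown physical QP
  }
  // Virtual QP -> the virtual CQs its completions surface on.
  ibverbx::IbvVirtualCq* sendCq =
      coordinator->getVirtualSendCq(vqp->getVirtualQpNum());
  ibverbx::IbvVirtualCq* recvCq =
      coordinator->getVirtualRecvCq(vqp->getVirtualQpNum());
  (void)sendCq;
  (void)recvCq;
}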
+class Coordinator { + public: + Coordinator() = default; + ~Coordinator() = default; + + // Disable copy constructor and assignment operator + Coordinator(const Coordinator&) = delete; + Coordinator& operator=(const Coordinator&) = delete; + + // Allow default move constructor and assignment operator + Coordinator(Coordinator&&) = default; + Coordinator& operator=(Coordinator&&) = default; + + inline void submitRequestToVirtualCq(VirtualCqRequest&& request); + inline folly::Expected submitRequestToVirtualQp( + VirtualQpRequest&& request); + + // Register APIs for mapping management + void registerVirtualQp(uint32_t virtualQpNum, IbvVirtualQp* virtualQp); + void registerVirtualCq(uint32_t virtualCqNum, IbvVirtualCq* virtualCq); + void registerPhysicalQpToVirtualQp(int physicalQpNum, uint32_t virtualQpNum); + void registerVirtualQpToVirtualSendCq( + uint32_t virtualQpNum, + uint32_t virtualSendCqNum); + void registerVirtualQpToVirtualRecvCq( + uint32_t virtualQpNum, + uint32_t virtualRecvCqNum); + + // Consolidated registration API for IbvVirtualQp - registers the virtual QP + // along with all its physical QPs and CQ relationships in one call + void registerVirtualQpWithVirtualCqMappings( + IbvVirtualQp* virtualQp, + uint32_t virtualSendCqNum, + uint32_t virtualRecvCqNum); + + // Getter APIs for accessing mappings + inline IbvVirtualCq* getVirtualSendCq(uint32_t virtualQpNum) const; + inline IbvVirtualCq* getVirtualRecvCq(uint32_t virtualQpNum) const; + inline IbvVirtualQp* getVirtualQpByPhysicalQpNum(int physicalQpNum) const; + inline IbvVirtualQp* getVirtualQpById(uint32_t virtualQpNum) const; + inline IbvVirtualCq* getVirtualCqById(uint32_t virtualCqNum) const; + + // Access APIs for testing and internal use + const std::unordered_map& getVirtualQpMap() const; + const std::unordered_map& getVirtualCqMap() const; + const std::unordered_map& getPhysicalQpToVirtualQpMap() const; + const std::unordered_map& getVirtualQpToVirtualSendCqMap() + const; + const std::unordered_map& getVirtualQpToVirtualRecvCqMap() + const; + + // Update API for move operations - only need to update pointer maps + void updateVirtualQpPointer(uint32_t virtualQpNum, IbvVirtualQp* newPtr); + void updateVirtualCqPointer(uint32_t virtualCqNum, IbvVirtualCq* newPtr); + + // Unregister API for cleanup during destruction + void unregisterVirtualQp(uint32_t virtualQpNum, IbvVirtualQp* ptr); + void unregisterVirtualCq(uint32_t virtualCqNum, IbvVirtualCq* ptr); + + static std::shared_ptr getCoordinator(); + + private: + // Map 1: Virtual QP Num -> Virtual QP pointer + std::unordered_map virtualQpNumToVirtualQp_; + + // Map 2: Virtual CQ Num -> Virtual CQ pointer + std::unordered_map virtualCqNumToVirtualCq_; + + // Map 3: Virtual QP Num -> Virtual Send CQ Num (relationship) + std::unordered_map virtualQpNumToVirtualSendCqNum_; + + // Map 4: Virtual QP Num -> Virtual Recv CQ Num (relationship) + std::unordered_map virtualQpNumToVirtualRecvCqNum_; + + // Map 5: Physical QP number -> Virtual QP Num (for routing) + std::unordered_map physicalQpNumToVirtualQpNum_; +}; + +// Coordinator inline functions +inline IbvVirtualCq* Coordinator::getVirtualSendCq( + uint32_t virtualQpNum) const { + auto it = virtualQpNumToVirtualSendCqNum_.find(virtualQpNum); + if (it == virtualQpNumToVirtualSendCqNum_.end()) { + return nullptr; + } + return getVirtualCqById(it->second); +} + +inline IbvVirtualCq* Coordinator::getVirtualRecvCq( + uint32_t virtualQpNum) const { + auto it = virtualQpNumToVirtualRecvCqNum_.find(virtualQpNum); + if (it == 
virtualQpNumToVirtualRecvCqNum_.end()) { + return nullptr; + } + return getVirtualCqById(it->second); +} + +inline IbvVirtualQp* Coordinator::getVirtualQpByPhysicalQpNum( + int physicalQpNum) const { + auto it = physicalQpNumToVirtualQpNum_.find(physicalQpNum); + if (it == physicalQpNumToVirtualQpNum_.end()) { + return nullptr; + } + return getVirtualQpById(it->second); +} + +inline IbvVirtualQp* Coordinator::getVirtualQpById( + uint32_t virtualQpNum) const { + auto it = virtualQpNumToVirtualQp_.find(virtualQpNum); + if (it == virtualQpNumToVirtualQp_.end()) { + return nullptr; + } + return it->second; +} + +inline IbvVirtualCq* Coordinator::getVirtualCqById( + uint32_t virtualCqNum) const { + auto it = virtualCqNumToVirtualCq_.find(virtualCqNum); + if (it == virtualCqNumToVirtualCq_.end()) { + return nullptr; + } + return it->second; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/DqplbSeqTracker.h b/comms/ctran/ibverbx/DqplbSeqTracker.h new file mode 100644 index 00000000..cdd712d1 --- /dev/null +++ b/comms/ctran/ibverbx/DqplbSeqTracker.h @@ -0,0 +1,57 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include "comms/ctran/ibverbx/IbvCommon.h" + +namespace ibverbx { + +class DqplbSeqTracker { + public: + DqplbSeqTracker() = default; + ~DqplbSeqTracker() = default; + + // Explicitly default move constructor and move assignment operator + DqplbSeqTracker(DqplbSeqTracker&&) = default; + DqplbSeqTracker& operator=(DqplbSeqTracker&&) = default; + + // This helper function calculates sender IMM message in DQPLB mode. + inline uint32_t getSendImm(int remainingMsgCnt); + // This helper function processes received IMM message and update + // receivedSeqNums_ map and receiveNext_ field. + inline int processReceivedImm(uint32_t receivedImm); + + private: + int sendNext_{0}; + int receiveNext_{0}; + std::unordered_map receivedSeqNums_; +}; + +// DqplbSeqTracker inline functions +inline uint32_t DqplbSeqTracker::getSendImm(int remainingMsgCnt) { + uint32_t immData = sendNext_; + sendNext_ = (sendNext_ + 1) % kSeqNumMask; + if (remainingMsgCnt == 1) { + immData |= (1 << kNotifyBit); + } + return immData; +} + +inline int DqplbSeqTracker::processReceivedImm(uint32_t immData) { + int notifyCount = 0; + receivedSeqNums_[immData & kSeqNumMask] = immData & (1U << kNotifyBit); + auto it = receivedSeqNums_.find(receiveNext_); + + while (it != receivedSeqNums_.end()) { + if (it->second) { + notifyCount++; + } + receivedSeqNums_.erase(it); + receiveNext_ = (receiveNext_ + 1) % kSeqNumMask; + it = receivedSeqNums_.find(receiveNext_); + } + return notifyCount; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvCommon.cc b/comms/ctran/ibverbx/IbvCommon.cc new file mode 100644 index 00000000..30540294 --- /dev/null +++ b/comms/ctran/ibverbx/IbvCommon.cc @@ -0,0 +1,19 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
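A small sketch of how the IMM encoding above round-trips between a sender-side and a receiver-side tracker; the three-message batch and the out-of-order arrival are made up for illustration:

#include "comms/ctran/ibverbx/DqplbSeqTracker.h"

void dqplbRoundTrip() {
  ibverbx::DqplbSeqTracker sender;
  ibverbx::DqplbSeqTracker receiver;

  // Sender: one user message split into three sub-messages; only the last
  // one (remainingMsgCnt == 1) carries the notify bit.
  uint32_t imm0 = sender.getSendImm(/*remainingMsgCnt=*/3); // seq 0
  uint32_t imm1 = sender.getSendImm(/*remainingMsgCnt=*/2); // seq 1
  uint32_t imm2 = sender.getSendImm(/*remainingMsgCnt=*/1); // seq 2 + notify

  // Receiver: sub-messages may complete out of order across physical QPs.
  int n2 = receiver.processReceivedImm(imm2); // returns 0: seq 0 and 1 missing
  int n0 = receiver.processReceivedImm(imm0); // returns 0: seq 2 not reachable
  int n1 = receiver.processReceivedImm(imm1); // returns 1: gap closed, notify
  (void)n0;
  (void)n1;
  (void)n2;
}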
+ +#include "comms/ctran/ibverbx/IbvCommon.h" +#include +#include + +namespace ibverbx { + +Error::Error() : errNum(errno), errStr(folly::errnoStr(errno)) {} +Error::Error(int errNum) : errNum(errNum), errStr(folly::errnoStr(errNum)) {} +Error::Error(int errNum, std::string errStr) + : errNum(errNum), errStr(std::move(errStr)) {} + +std::ostream& operator<<(std::ostream& out, Error const& err) { + out << err.errStr << " (errno=" << err.errNum << ")"; + return out; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvCommon.h b/comms/ctran/ibverbx/IbvCommon.h new file mode 100644 index 00000000..c735a87b --- /dev/null +++ b/comms/ctran/ibverbx/IbvCommon.h @@ -0,0 +1,62 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include + +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +// Default HCA prefix +constexpr std::string_view kDefaultHcaPrefix = ""; +// Default HCA list +const std::vector kDefaultHcaList{}; +// Default port +constexpr int kIbAnyPort = -1; +constexpr int kDefaultIbDataDirect = 1; +constexpr int kIbMaxMsgCntPerQp = 100; +constexpr int kIbMaxMsgSizeByte = 100; +constexpr int kIbMaxCqe_ = 100; +constexpr int kNotifyBit = 31; +constexpr uint32_t kSeqNumMask = 0xFFFFFF; // 24 bits + +// Command types for coordinator routing and operations +enum class RequestType { SEND = 0, RECV = 1, SEND_NOTIFY = 2 }; +enum class LoadBalancingScheme { SPRAY = 0, DQPLB = 1 }; + +struct Error { + Error(); + explicit Error(int errNum); + Error(int errNum, std::string errStr); + + const int errNum{0}; + const std::string errStr; +}; + +std::ostream& operator<<(std::ostream&, Error const&); + +struct VirtualQpRequest { + RequestType type{RequestType::SEND}; + uint64_t wrId{0}; + uint32_t physicalQpNum{0}; + uint32_t immData{0}; +}; + +struct VirtualQpResponse { + uint64_t virtualWrId{0}; + bool useDqplb{false}; + int notifyCount{0}; +}; + +struct VirtualCqRequest { + RequestType type{RequestType::SEND}; + int virtualQpNum{-1}; + int expectedMsgCnt{-1}; + ibv_send_wr* sendWr{nullptr}; + ibv_recv_wr* recvWr{nullptr}; + bool sendExtraNotifyImm{false}; +}; + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvCq.cc b/comms/ctran/ibverbx/IbvCq.cc new file mode 100644 index 00000000..5b3ad551 --- /dev/null +++ b/comms/ctran/ibverbx/IbvCq.cc @@ -0,0 +1,48 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#include "comms/ctran/ibverbx/IbvCq.h" +#include +#include "comms/ctran/ibverbx/IbverbxSymbols.h" + +namespace ibverbx { + +extern IbvSymbols ibvSymbols; + +/*** IbvCq ***/ + +IbvCq::IbvCq(ibv_cq* cq) : cq_(cq) {} + +IbvCq::~IbvCq() { + if (cq_) { + int rc = ibvSymbols.ibv_internal_destroy_cq(cq_); + if (rc != 0) { + XLOGF(ERR, "Failed to destroy cq rc: {}, {}", rc, strerror(errno)); + } + } +} + +IbvCq::IbvCq(IbvCq&& other) noexcept { + cq_ = other.cq_; + other.cq_ = nullptr; +} + +IbvCq& IbvCq::operator=(IbvCq&& other) noexcept { + cq_ = other.cq_; + other.cq_ = nullptr; + return *this; +} + +ibv_cq* IbvCq::cq() const { + return cq_; +} + +folly::Expected IbvCq::reqNotifyCq( + int solicited_only) const { + int rc = cq_->context->ops.req_notify_cq(cq_, solicited_only); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return folly::unit; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvCq.h b/comms/ctran/ibverbx/IbvCq.h new file mode 100644 index 00000000..6d260f1a --- /dev/null +++ b/comms/ctran/ibverbx/IbvCq.h @@ -0,0 +1,54 @@ +// (c) Meta Platforms, Inc. 
and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include "comms/ctran/ibverbx/IbvCommon.h" +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +// Ibv CompletionQueue(CQ) +class IbvCq { + public: + IbvCq() = default; + ~IbvCq(); + + // disable copy constructor + IbvCq(const IbvCq&) = delete; + IbvCq& operator=(const IbvCq&) = delete; + + // move constructor + IbvCq(IbvCq&& other) noexcept; + IbvCq& operator=(IbvCq&& other) noexcept; + + ibv_cq* cq() const; + inline folly::Expected, Error> pollCq(int numEntries); + + // Request notification when the next completion is added to this CQ + folly::Expected reqNotifyCq(int solicited_only) const; + + private: + friend class IbvDevice; + + explicit IbvCq(ibv_cq* cq); + + ibv_cq* cq_{nullptr}; +}; + +// IbvCq inline functions +inline folly::Expected, Error> IbvCq::pollCq( + int numEntries) { + std::vector wcs(numEntries); + int numPolled = cq_->context->ops.poll_cq(cq_, numEntries, wcs.data()); + if (numPolled < 0) { + wcs.clear(); + return folly::makeUnexpected( + Error(EINVAL, fmt::format("Call to pollCq() returned {}", numPolled))); + } else { + wcs.resize(numPolled); + } + return wcs; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvDevice.cc b/comms/ctran/ibverbx/IbvDevice.cc new file mode 100644 index 00000000..4acc5df3 --- /dev/null +++ b/comms/ctran/ibverbx/IbvDevice.cc @@ -0,0 +1,417 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#include "comms/ctran/ibverbx/IbvDevice.h" +#include "comms/ctran/ibverbx/IbverbxSymbols.h" + +namespace ibverbx { + +extern IbvSymbols ibvSymbols; + +namespace { + +class RoceHca { + public: + RoceHca(std::string hcaStr, int defaultPort) { + std::string s = std::move(hcaStr); + std::string delim = ":"; + + std::vector hcaStrPair; + folly::split(':', s, hcaStrPair); + if (hcaStrPair.size() == 1) { + this->name = s; + this->port = defaultPort; + } else if (hcaStrPair.size() == 2) { + this->name = hcaStrPair.at(0); + this->port = std::stoi(hcaStrPair.at(1)); + } + } + std::string name; + int port{-1}; +}; + +bool mlx5dvDmaBufDataDirectLinkCapable( + ibv_device* device, + ibv_context* context) { + if (ibvSymbols.mlx5dv_internal_is_supported == nullptr || + ibvSymbols.mlx5dv_internal_reg_dmabuf_mr == nullptr || + ibvSymbols.mlx5dv_internal_get_data_direct_sysfs_path == nullptr) { + return false; + } + + if (!ibvSymbols.mlx5dv_internal_is_supported(device)) { + return false; + } + int dev_fail = 0; + ibv_pd* pd = nullptr; + pd = ibvSymbols.ibv_internal_alloc_pd(context); + if (!pd) { + XLOG(ERR) << "ibv_alloc_pd failed: " << folly::errnoStr(errno); + return false; + } + + // Test kernel DMA-BUF support with a dummy call (fd=-1) + (void)ibvSymbols.ibv_internal_reg_dmabuf_mr( + pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/); + // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not + // supported (EBADF otherwise) + (void)ibvSymbols.mlx5dv_internal_reg_dmabuf_mr( + pd, + 0ULL /*offset*/, + 0ULL /*len*/, + 0ULL /*iova*/, + -1 /*fd*/, + 0 /*flags*/, + 0 /* mlx5 flags*/); + // mlx5dv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not + // supported (EBADF otherwise) + dev_fail |= (errno == EOPNOTSUPP) || (errno == EPROTONOSUPPORT); + if (ibvSymbols.ibv_internal_dealloc_pd(pd) != 0) { + XLOGF( + WARN, + "ibv_dealloc_pd failed: {} DMA-BUF support status: {}", + folly::errnoStr(errno), + dev_fail); + return false; + } + if (dev_fail) { + XLOGF( + INFO, + "MLX5DV Kernel DMA-BUF is not supported on 
device {}", + device->name); + return false; + } + + char dataDirectDevicePath[PATH_MAX]; + snprintf(dataDirectDevicePath, PATH_MAX, "/sys"); + return ibvSymbols.mlx5dv_internal_get_data_direct_sysfs_path( + context, dataDirectDevicePath + 4, PATH_MAX - 4) == 0; +} + +} // namespace + +/*** IbvDevice ***/ + +// hcaList format examples: +// - Without port: "mlx5_0,mlx5_1,mlx5_2" +// - With port: "mlx5_0:1,mlx5_1:0,mlx5_2:1" +// - Prefix match: "mlx5" +// hcaPrefix: use "=" for exact match, "^" for exclude match, "" for prefix +// match. See guidelines: +// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-ib-hca +folly::Expected, Error> IbvDevice::ibvGetDeviceList( + const std::vector& hcaList, + const std::string& hcaPrefix, + int defaultPort, + int ibDataDirect) { + // Get device list + ibv_device** devs{nullptr}; + int numDevs; + devs = ibvSymbols.ibv_internal_get_device_list(&numDevs); + if (!devs) { + return folly::makeUnexpected(Error(errno)); + } + auto devices = ibvFilterDeviceList( + numDevs, devs, hcaList, hcaPrefix, defaultPort, ibDataDirect); + // Free device list + ibvSymbols.ibv_internal_free_device_list(devs); + return devices; +} + +std::vector IbvDevice::ibvFilterDeviceList( + int numDevs, + ibv_device** devs, + const std::vector& hcaList, + const std::string& hcaPrefix, + int defaultPort, + int ibDataDirect) { + std::vector devices; + bool dataDirect = ibDataDirect == 1; + + if (hcaList.empty()) { + devices.reserve(numDevs); + for (int i = 0; i < numDevs; i++) { + devices.emplace_back(devs[i], defaultPort, dataDirect); + } + return devices; + } + + // Convert the provided list of HCA strings into a vector of RoceHca + // objects, which enables efficient device filter operation + std::vector hcas; + // Avoid copy triggered by resize + hcas.reserve(hcaList.size()); + for (const auto& hca : hcaList) { + // Copy value to each vector element so it can be freed automatically + hcas.emplace_back(hca, defaultPort); + } + + // Filter devices + if (hcaPrefix == "=") { + for (const auto& hca : hcas) { + for (int i = 0; i < numDevs; i++) { + if (hca.name == devs[i]->name) { + devices.emplace_back(devs[i], hca.port, dataDirect); + break; + } + } + } + return devices; + } else if (hcaPrefix == "^") { + for (const auto& hca : hcas) { + for (int i = 0; i < numDevs; i++) { + if (hca.name != devs[i]->name) { + devices.emplace_back(devs[i], defaultPort, dataDirect); + break; + } + } + } + return devices; + } else { + // Prefix match + for (const auto& hca : hcas) { + for (int i = 0; i < numDevs; i++) { + if (strncmp(devs[i]->name, hca.name.c_str(), hca.name.length()) == 0) { + devices.emplace_back(devs[i], hca.port, dataDirect); + break; + } + } + } + return devices; + } +} + +IbvDevice::IbvDevice(ibv_device* ibvDevice, int port, bool dataDirect) + : device_(ibvDevice) { + port_ = port; + context_ = ibvSymbols.ibv_internal_open_device(device_); + if (!context_) { + XLOGF(ERR, "Failed to open device {}", device_->name); + throw std::runtime_error( + fmt::format("Failed to open device {}", device_->name)); + } + if (dataDirect && (mlx5dvDmaBufDataDirectLinkCapable(device_, context_))) { + dataDirect_ = true; + XLOGF( + INFO, + "NET/IB: Data Direct DMA Interface is detected for device: {} dataDirect: {}", + device_->name, + dataDirect_); + } +} + +IbvDevice::~IbvDevice() { + if (context_) { + int rc = ibvSymbols.ibv_internal_close_device(context_); + if (rc != 0) { + XLOGF(ERR, "Failed to close device rc: {}, {}", rc, strerror(errno)); + } + } +} + 
+IbvDevice::IbvDevice(IbvDevice&& other) noexcept { + device_ = other.device_; + context_ = other.context_; + port_ = other.port_; + dataDirect_ = other.dataDirect_; + + other.device_ = nullptr; + other.context_ = nullptr; +} + +IbvDevice& IbvDevice::operator=(IbvDevice&& other) noexcept { + device_ = other.device_; + context_ = other.context_; + port_ = other.port_; + dataDirect_ = other.dataDirect_; + + other.device_ = nullptr; + other.context_ = nullptr; + return *this; +} + +ibv_device* IbvDevice::device() const { + return device_; +} + +ibv_context* IbvDevice::context() const { + return context_; +} + +int IbvDevice::port() const { + return port_; +} + +folly::Expected IbvDevice::allocPd() { + ibv_pd* pd; + pd = ibvSymbols.ibv_internal_alloc_pd(context_); + if (!pd) { + return folly::makeUnexpected(Error(errno)); + } + return IbvPd(pd, dataDirect_); +} + +folly::Expected IbvDevice::allocParentDomain( + ibv_parent_domain_init_attr* attr) { + ibv_pd* pd; + + if (ibvSymbols.ibv_internal_alloc_parent_domain == nullptr) { + return folly::makeUnexpected(Error(ENOSYS)); + } + + pd = ibvSymbols.ibv_internal_alloc_parent_domain(context_, attr); + + if (!pd) { + return folly::makeUnexpected(Error(errno)); + } + return IbvPd(pd, dataDirect_); +} + +folly::Expected IbvDevice::queryDevice() const { + ibv_device_attr deviceAttr{}; + int rc = ibvSymbols.ibv_internal_query_device(context_, &deviceAttr); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return deviceAttr; +} + +folly::Expected IbvDevice::queryPort( + uint8_t portNum) const { + ibv_port_attr portAttr{}; + int rc = ibvSymbols.ibv_internal_query_port(context_, portNum, &portAttr); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return portAttr; +} + +folly::Expected IbvDevice::queryGid( + uint8_t portNum, + int gidIndex) const { + ibv_gid gid{}; + int rc = ibvSymbols.ibv_internal_query_gid(context_, portNum, gidIndex, &gid); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return gid; +} + +folly::Expected IbvDevice::createCq( + int cqe, + void* cq_context, + ibv_comp_channel* channel, + int comp_vector) const { + ibv_cq* cq; + cq = ibvSymbols.ibv_internal_create_cq( + context_, cqe, cq_context, channel, comp_vector); + if (!cq) { + return folly::makeUnexpected(Error(errno)); + } + return IbvCq(cq); +} + +folly::Expected IbvDevice::createVirtualCq( + int cqe, + void* cq_context, + ibv_comp_channel* channel, + int comp_vector) { + auto maybeCq = createCq(cqe, cq_context, channel, comp_vector); + if (maybeCq.hasError()) { + return folly::makeUnexpected(maybeCq.error()); + } + return IbvVirtualCq(std::move(*maybeCq), cqe); +} + +folly::Expected IbvDevice::createCq( + ibv_cq_init_attr_ex* attr) const { + ibv_cq_ex* cqEx; + cqEx = ibvSymbols.ibv_internal_create_cq_ex(context_, attr); + if (!cqEx) { + return folly::makeUnexpected(Error(errno)); + } + ibv_cq* cq = ibv_cq_ex_to_cq(cqEx); + return IbvCq(cq); +} + +folly::Expected IbvDevice::createCompChannel() const { + ibv_comp_channel* channel; + channel = ibvSymbols.ibv_internal_create_comp_channel(context_); + if (!channel) { + return folly::makeUnexpected(Error(errno)); + } + return channel; +} + +folly::Expected IbvDevice::destroyCompChannel( + ibv_comp_channel* channel) const { + int rc = ibvSymbols.ibv_internal_destroy_comp_channel(channel); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return folly::unit; +} + +folly::Expected IbvDevice::isPortActive( + uint8_t portNum, + std::unordered_set linkLayers) const { + auto 
maybePortAttr = queryPort(portNum); + if (maybePortAttr.hasError()) { + return folly::makeUnexpected(maybePortAttr.error()); + } + + auto portAttr = maybePortAttr.value(); + + // Check if port is active + if (portAttr.state != IBV_PORT_ACTIVE) { + return false; + } + + // Check if link layer matches (if specified) + if (!linkLayers.empty() && + linkLayers.find(portAttr.link_layer) == linkLayers.end()) { + return false; + } + + return true; +} + +folly::Expected IbvDevice::findActivePort( + std::unordered_set const& linkLayers) const { + // If specific port requested, check if it is active + if (port_ != kIbAnyPort) { + auto maybeActive = isPortActive(port_, linkLayers); + if (maybeActive.hasError()) { + return folly::makeUnexpected(maybeActive.error()); + } + + if (!maybeActive.value()) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "Port {} is not active on device {}", port_, device_->name))); + } + return port_; + } + + // No specific port requested, find any active port + auto maybeDeviceAttr = queryDevice(); + if (maybeDeviceAttr.hasError()) { + return folly::makeUnexpected(maybeDeviceAttr.error()); + } + + for (uint8_t port = 1; port <= maybeDeviceAttr->phys_port_cnt; port++) { + auto maybeActive = isPortActive(port, linkLayers); + if (maybeActive.hasError()) { + continue; // Skip ports we can't query + } + + if (maybeActive.value()) { + return port; + } + } + + return folly::makeUnexpected(Error( + ENODEV, fmt::format("No active port found on device {}", device_->name))); +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvDevice.h b/comms/ctran/ibverbx/IbvDevice.h new file mode 100644 index 00000000..c53fa64c --- /dev/null +++ b/comms/ctran/ibverbx/IbvDevice.h @@ -0,0 +1,94 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
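A brief sketch of the port-selection helpers above; IBV_LINK_LAYER_ETHERNET is the standard ibverbs constant, while the set's element type and the Expected return type are assumed:

#include "comms/ctran/ibverbx/IbvDevice.h"

void pickRocePort(ibverbx::IbvDevice& device) {
  // Restrict the search to ports with an Ethernet (RoCE) link layer.
  auto maybePort = device.findActivePort({IBV_LINK_LAYER_ETHERNET});
  if (maybePort.hasError()) {
    return; // no active port matched the filter
  }
  uint8_t activePort = maybePort.value();
  (void)activePort;
}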
+ +#pragma once + +#include +#include "comms/ctran/ibverbx/IbvCommon.h" +#include "comms/ctran/ibverbx/IbvCq.h" +#include "comms/ctran/ibverbx/IbvPd.h" +#include "comms/ctran/ibverbx/IbvVirtualCq.h" +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +// IbvDevice +class IbvDevice { + public: + static folly::Expected, Error> ibvGetDeviceList( + const std::vector& hcaList = kDefaultHcaList, + const std::string& hcaPrefix = std::string(kDefaultHcaPrefix), + int defaultPort = kIbAnyPort, + int ibDataDirect = kDefaultIbDataDirect); + IbvDevice(ibv_device* ibvDevice, int port, bool dataDirect = false); + ~IbvDevice(); + + // disable copy constructor + IbvDevice(const IbvDevice&) = delete; + IbvDevice& operator=(const IbvDevice&) = delete; + + // move constructor + IbvDevice(IbvDevice&& other) noexcept; + IbvDevice& operator=(IbvDevice&& other) noexcept; + + ibv_device* device() const; + ibv_context* context() const; + int port() const; + + folly::Expected allocPd(); + folly::Expected allocParentDomain( + ibv_parent_domain_init_attr* attr); + folly::Expected queryDevice() const; + folly::Expected queryPort(uint8_t portNum) const; + folly::Expected queryGid(uint8_t portNum, int gidIndex) const; + + folly::Expected createCq( + int cqe, + void* cq_context, + ibv_comp_channel* channel, + int comp_vector) const; + + // create Cq with attributes + folly::Expected createCq(ibv_cq_init_attr_ex* attr) const; + + // Create a completion channel for event-driven completion handling + folly::Expected createCompChannel() const; + + // Destroy a completion channel + folly::Expected destroyCompChannel( + ibv_comp_channel* channel) const; + + // When creating an IbvVirtualCq for an IbvVirtualQp, ensure that cqe >= + // (number of QPs * capacity per QP). If send queue and recv queue intend to + // share the same cqe, then ensure cqe >= (2 * number of QPs * capacity per + // QP). Failing to meet this condition may result in lost CQEs. TODO: Enforce + // this requirement in the low-level API. If a higher-level API is introduced + // in the future, ensure this guarantee is handled within Ibverbx when + // creating a IbvVirtualCq for the user. + folly::Expected createVirtualCq( + int cqe, + void* cq_context, + ibv_comp_channel* channel, + int comp_vector); + + folly::Expected isPortActive( + uint8_t portNum, + std::unordered_set linkLayers) const; + folly::Expected findActivePort( + std::unordered_set const& linkLayers) const; + + private: + ibv_device* device_{nullptr}; + ibv_context* context_{nullptr}; + int port_{-1}; + bool dataDirect_{false}; // Relevant only to mlx5 + + static std::vector ibvFilterDeviceList( + int numDevs, + ibv_device** devs, + const std::vector& hcaList = kDefaultHcaList, + const std::string& hcaPrefix = std::string(kDefaultHcaPrefix), + int defaultPort = kIbAnyPort, + int ibDataDirect = kDefaultIbDataDirect); +}; + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvMr.cc b/comms/ctran/ibverbx/IbvMr.cc new file mode 100644 index 00000000..f7a8baae --- /dev/null +++ b/comms/ctran/ibverbx/IbvMr.cc @@ -0,0 +1,39 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
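A sketch of the cqe sizing rule documented for createVirtualCq above; the QP count and per-QP capacity are made up:

#include "comms/ctran/ibverbx/IbvDevice.h"

void makeSharedVirtualCq(ibverbx::IbvDevice& device) {
  constexpr int kNumQps = 4; // hypothetical
  constexpr int kMaxWrPerQp = 32; // hypothetical
  // One virtual CQ shared by send and recv needs
  // cqe >= 2 * number of QPs * capacity per QP.
  const int cqe = 2 * kNumQps * kMaxWrPerQp;
  auto maybeVcq = device.createVirtualCq(
      cqe, /*cq_context=*/nullptr, /*channel=*/nullptr, /*comp_vector=*/0);
  if (maybeVcq.hasError()) {
    return;
  }
  // The returned IbvVirtualCq owns the underlying physical CQ and registers
  // itself with the Coordinator under a unique virtual CQ number.
}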
+#include "comms/ctran/ibverbx/IbvMr.h" + +#include +#include "comms/ctran/ibverbx/IbverbxSymbols.h" + +namespace ibverbx { + +extern IbvSymbols ibvSymbols; + +/*** IbvMr ***/ + +IbvMr::IbvMr(ibv_mr* mr) : mr_(mr) {} + +IbvMr::IbvMr(IbvMr&& other) noexcept { + mr_ = other.mr_; + other.mr_ = nullptr; +} + +IbvMr& IbvMr::operator=(IbvMr&& other) noexcept { + mr_ = other.mr_; + other.mr_ = nullptr; + return *this; +} + +IbvMr::~IbvMr() { + if (mr_) { + int rc = ibvSymbols.ibv_internal_dereg_mr(mr_); + if (rc != 0) { + XLOGF(ERR, "Failed to deregister mr rc: {}, {}", rc, strerror(errno)); + } + } +} + +ibv_mr* IbvMr::mr() const { + return mr_; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvMr.h b/comms/ctran/ibverbx/IbvMr.h new file mode 100644 index 00000000..d0e8f9d0 --- /dev/null +++ b/comms/ctran/ibverbx/IbvMr.h @@ -0,0 +1,32 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +// IbvMr: Memory Region +class IbvMr { + public: + ~IbvMr(); + + // disable copy constructor + IbvMr(const IbvMr&) = delete; + IbvMr& operator=(const IbvMr&) = delete; + + // move constructor + IbvMr(IbvMr&& other) noexcept; + IbvMr& operator=(IbvMr&& other) noexcept; + + ibv_mr* mr() const; + + private: + friend class IbvPd; + + explicit IbvMr(ibv_mr* mr); + + ibv_mr* mr_{nullptr}; +}; + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvPd.cc b/comms/ctran/ibverbx/IbvPd.cc new file mode 100644 index 00000000..6ec5978f --- /dev/null +++ b/comms/ctran/ibverbx/IbvPd.cc @@ -0,0 +1,142 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#include "comms/ctran/ibverbx/IbvPd.h" +#include "comms/ctran/ibverbx/IbverbxSymbols.h" + +namespace ibverbx { + +extern IbvSymbols ibvSymbols; + +/*** IbvPd ***/ + +IbvPd::IbvPd(ibv_pd* pd, bool dataDirect) : pd_(pd), dataDirect_(dataDirect) {} + +IbvPd::IbvPd(IbvPd&& other) noexcept { + pd_ = other.pd_; + dataDirect_ = other.dataDirect_; + other.pd_ = nullptr; +} + +IbvPd& IbvPd::operator=(IbvPd&& other) noexcept { + pd_ = other.pd_; + dataDirect_ = other.dataDirect_; + other.pd_ = nullptr; + return *this; +} + +IbvPd::~IbvPd() { + if (pd_) { + int rc = ibvSymbols.ibv_internal_dealloc_pd(pd_); + if (rc != 0) { + XLOGF(ERR, "Failed to deallocate pd rc: {}, {}", rc, strerror(errno)); + } + } +} + +ibv_pd* IbvPd::pd() const { + return pd_; +} + +bool IbvPd::useDataDirect() const { + return dataDirect_; +} + +folly::Expected +IbvPd::regMr(void* addr, size_t length, ibv_access_flags access) const { + ibv_mr* mr; + mr = ibvSymbols.ibv_internal_reg_mr(pd_, addr, length, access); + if (!mr) { + return folly::makeUnexpected(Error(errno)); + } + return IbvMr(mr); +} + +folly::Expected IbvPd::regDmabufMr( + uint64_t offset, + size_t length, + uint64_t iova, + int fd, + ibv_access_flags access) const { + ibv_mr* mr; + if (dataDirect_) { + mr = ibvSymbols.mlx5dv_internal_reg_dmabuf_mr( + pd_, + offset, + length, + iova, + fd, + access, + MLX5DV_REG_DMABUF_ACCESS_DATA_DIRECT); + } else { + mr = ibvSymbols.ibv_internal_reg_dmabuf_mr( + pd_, offset, length, iova, fd, access); + } + if (!mr) { + return folly::makeUnexpected(Error(errno)); + } + return IbvMr(mr); +} + +folly::Expected IbvPd::createQp( + ibv_qp_init_attr* initAttr) const { + ibv_qp* qp; + qp = ibvSymbols.ibv_internal_create_qp(pd_, initAttr); + if (!qp) { + return folly::makeUnexpected(Error(errno)); + } + return IbvQp(qp); +} + +folly::Expected IbvPd::createVirtualQp( + int totalQps, + 
ibv_qp_init_attr* initAttr, + IbvVirtualCq* sendCq, + IbvVirtualCq* recvCq, + int maxMsgCntPerQp, + int maxMsgSize, + LoadBalancingScheme loadBalancingScheme) const { + std::vector qps; + qps.reserve(totalQps); + + if (sendCq == nullptr) { + return folly::makeUnexpected( + Error(EINVAL, "Empty sendCq being provided to createVirtualQp")); + } + + if (recvCq == nullptr) { + return folly::makeUnexpected( + Error(EINVAL, "Empty recvCq being provided to createVirtualQp")); + } + + // Overwrite the CQs in the initAttr to point to the virtual CQ + initAttr->send_cq = sendCq->getPhysicalCqRef().cq(); + initAttr->recv_cq = recvCq->getPhysicalCqRef().cq(); + + // First create all the data QPs + for (int i = 0; i < totalQps; i++) { + auto maybeQp = createQp(initAttr); + if (maybeQp.hasError()) { + return folly::makeUnexpected(maybeQp.error()); + } + qps.emplace_back(std::move(*maybeQp)); + } + + // Create notify QP + auto maybeNotifyQp = createQp(initAttr); + if (maybeNotifyQp.hasError()) { + return folly::makeUnexpected(maybeNotifyQp.error()); + } + + // Create the IbvVirtualQp instance, with coordinator registartion happens + // within IbvVirtualQp constructor + return IbvVirtualQp( + std::move(qps), + std::move(*maybeNotifyQp), + sendCq, + recvCq, + maxMsgCntPerQp, + maxMsgSize, + loadBalancingScheme); +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvPd.h b/comms/ctran/ibverbx/IbvPd.h new file mode 100644 index 00000000..10d664fe --- /dev/null +++ b/comms/ctran/ibverbx/IbvPd.h @@ -0,0 +1,66 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include "comms/ctran/ibverbx/IbvCommon.h" +#include "comms/ctran/ibverbx/IbvMr.h" +#include "comms/ctran/ibverbx/IbvQp.h" +#include "comms/ctran/ibverbx/IbvVirtualQp.h" +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +class IbvVirtualCq; + +// IbvPd: Protection Domain +class IbvPd { + public: + ~IbvPd(); + + // disable copy constructor + IbvPd(const IbvPd&) = delete; + IbvPd& operator=(const IbvPd&) = delete; + + // move constructor + IbvPd(IbvPd&& other) noexcept; + IbvPd& operator=(IbvPd&& other) noexcept; + + ibv_pd* pd() const; + bool useDataDirect() const; + + folly::Expected + regMr(void* addr, size_t length, ibv_access_flags access) const; + + folly::Expected regDmabufMr( + uint64_t offset, + size_t length, + uint64_t iova, + int fd, + ibv_access_flags access) const; + + folly::Expected createQp(ibv_qp_init_attr* initAttr) const; + + // The send_cq and recv_cq fields in initAttr are ignored. + // Instead, initAttr.send_cq and initAttr.recv_cq will be set to the physical + // CQs contained within sendCq and recvCq, respectively. + folly::Expected createVirtualQp( + int totalQps, + ibv_qp_init_attr* initAttr, + IbvVirtualCq* sendCq, + IbvVirtualCq* recvCq, + int maxMsgCntPerQp = kIbMaxMsgCntPerQp, + int maxMsgSize = kIbMaxMsgSizeByte, + LoadBalancingScheme loadBalancingScheme = + LoadBalancingScheme::SPRAY) const; + + private: + friend class IbvDevice; + + IbvPd(ibv_pd* pd, bool dataDirect = false); + + ibv_pd* pd_{nullptr}; + bool dataDirect_{false}; // Relevant only to mlx5 +}; + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvQp.cc b/comms/ctran/ibverbx/IbvQp.cc new file mode 100644 index 00000000..48cba1ca --- /dev/null +++ b/comms/ctran/ibverbx/IbvQp.cc @@ -0,0 +1,95 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
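A sketch of how a caller might use createVirtualQp above, given a protection domain and a shared virtual CQ already created from the device; the QP capabilities and striping parameters are hypothetical:

#include "comms/ctran/ibverbx/IbvPd.h"
#include "comms/ctran/ibverbx/IbvVirtualCq.h"

void buildVirtualQp(ibverbx::IbvPd& pd, ibverbx::IbvVirtualCq& virtualCq) {
  ibv_qp_init_attr initAttr{};
  initAttr.qp_type = IBV_QPT_RC;
  initAttr.cap.max_send_wr = 32;
  initAttr.cap.max_recv_wr = 32;
  initAttr.cap.max_send_sge = 1;
  initAttr.cap.max_recv_sge = 1;
  // send_cq/recv_cq in initAttr are overwritten with virtualCq's physical CQ.
  auto maybeVqp = pd.createVirtualQp(
      /*totalQps=*/4, &initAttr, &virtualCq, &virtualCq,
      /*maxMsgCntPerQp=*/32, /*maxMsgSize=*/1 << 20,
      ibverbx::LoadBalancingScheme::SPRAY);
  if (maybeVqp.hasError()) {
    return;
  }
  // maybeVqp now holds 4 data QPs plus one notify QP, all registered with the
  // Coordinator against virtualCq.
}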
+ +#include "comms/ctran/ibverbx/IbvQp.h" + +#include +#include "comms/ctran/ibverbx/Ibvcore.h" +#include "comms/ctran/ibverbx/IbverbxSymbols.h" + +namespace ibverbx { + +extern IbvSymbols ibvSymbols; + +/*** IbvQp ***/ +IbvQp::IbvQp(ibv_qp* qp) : qp_(qp) {} + +IbvQp::~IbvQp() { + if (qp_) { + int rc = ibvSymbols.ibv_internal_destroy_qp(qp_); + if (rc != 0) { + XLOGF(ERR, "Failed to destroy qp rc: {}, {}", rc, strerror(errno)); + } + } +} + +IbvQp::IbvQp(IbvQp&& other) noexcept { + qp_ = other.qp_; + physicalSendWrStatus_ = std::move(other.physicalSendWrStatus_); + physicalRecvWrStatus_ = std::move(other.physicalRecvWrStatus_); + other.qp_ = nullptr; +} + +IbvQp& IbvQp::operator=(IbvQp&& other) noexcept { + qp_ = other.qp_; + physicalSendWrStatus_ = std::move(other.physicalSendWrStatus_); + physicalRecvWrStatus_ = std::move(other.physicalRecvWrStatus_); + other.qp_ = nullptr; + return *this; +} + +ibv_qp* IbvQp::qp() const { + return qp_; +} + +folly::Expected IbvQp::modifyQp( + ibv_qp_attr* attr, + int attrMask) { + int rc = ibvSymbols.ibv_internal_modify_qp(qp_, attr, attrMask); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return folly::unit; +} + +folly::Expected, Error> IbvQp::queryQp( + int attrMask) const { + ibv_qp_attr qpAttr{}; + ibv_qp_init_attr initAttr{}; + int rc = ibvSymbols.ibv_internal_query_qp(qp_, &qpAttr, attrMask, &initAttr); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return std::make_pair(qpAttr, initAttr); +} + +void IbvQp::enquePhysicalSendWrStatus(int physicalWrId, int virtualWrId) { + physicalSendWrStatus_.emplace_back(physicalWrId, virtualWrId); +} + +void IbvQp::dequePhysicalSendWrStatus() { + physicalSendWrStatus_.pop_front(); +} + +void IbvQp::dequePhysicalRecvWrStatus() { + physicalRecvWrStatus_.pop_front(); +} + +void IbvQp::enquePhysicalRecvWrStatus(int physicalWrId, int virtualWrId) { + physicalRecvWrStatus_.emplace_back(physicalWrId, virtualWrId); +} + +bool IbvQp::isSendQueueAvailable(int maxMsgCntPerQp) const { + if (maxMsgCntPerQp < 0) { + return true; + } + return physicalSendWrStatus_.size() < maxMsgCntPerQp; +} + +bool IbvQp::isRecvQueueAvailable(int maxMsgCntPerQp) const { + if (maxMsgCntPerQp < 0) { + return true; + } + return physicalRecvWrStatus_.size() < maxMsgCntPerQp; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvQp.h b/comms/ctran/ibverbx/IbvQp.h new file mode 100644 index 00000000..6c3f7b9e --- /dev/null +++ b/comms/ctran/ibverbx/IbvQp.h @@ -0,0 +1,97 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +#pragma once + +#include +#include +#include +#include "comms/ctran/ibverbx/IbvCommon.h" +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +// Ibv Queue Pair +class IbvQp { + public: + ~IbvQp(); + + // disable copy constructor + IbvQp(const IbvQp&) = delete; + IbvQp& operator=(const IbvQp&) = delete; + + // move constructor + IbvQp(IbvQp&& other) noexcept; + IbvQp& operator=(IbvQp&& other) noexcept; + + ibv_qp* qp() const; + + folly::Expected modifyQp(ibv_qp_attr* attr, int attrMask); + folly::Expected, Error> queryQp( + int attrMask) const; + + inline uint32_t getQpNum() const; + inline folly::Expected postRecv( + ibv_recv_wr* recvWr, + ibv_recv_wr* recvWrBad); + inline folly::Expected postSend( + ibv_send_wr* sendWr, + ibv_send_wr* sendWrBad); + + void enquePhysicalSendWrStatus(int physicalWrId, int virtualWrId); + void enquePhysicalRecvWrStatus(int physicalWrId, int virtualWrId); + void dequePhysicalSendWrStatus(); + void dequePhysicalRecvWrStatus(); + bool isSendQueueAvailable(int maxMsgCntPerQp) const; + bool isRecvQueueAvailable(int maxMsgCntPerQp) const; + + private: + friend class IbvPd; + friend class IbvVirtualQp; + friend class IbvVirtualCq; + + struct PhysicalSendWrStatus { + PhysicalSendWrStatus(uint64_t physicalWrId, uint64_t virtualWrId) + : physicalWrId(physicalWrId), virtualWrId(virtualWrId) {} + uint64_t physicalWrId{0}; + uint64_t virtualWrId{0}; + }; + struct PhysicalRecvWrStatus { + PhysicalRecvWrStatus(uint64_t physicalWrId, uint64_t virtualWrId) + : physicalWrId(physicalWrId), virtualWrId(virtualWrId) {} + uint64_t physicalWrId{0}; + uint64_t virtualWrId{0}; + }; + explicit IbvQp(ibv_qp* qp); + + ibv_qp* qp_{nullptr}; + std::deque physicalSendWrStatus_; + std::deque physicalRecvWrStatus_; +}; + +// IbvQp inline functions +inline uint32_t IbvQp::getQpNum() const { + XCHECK_NE(qp_, nullptr); + return qp_->qp_num; +} + +inline folly::Expected IbvQp::postRecv( + ibv_recv_wr* recvWr, + ibv_recv_wr* recvWrBad) { + int rc = qp_->context->ops.post_recv(qp_, recvWr, &recvWrBad); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return folly::unit; +} + +inline folly::Expected IbvQp::postSend( + ibv_send_wr* sendWr, + ibv_send_wr* sendWrBad) { + int rc = qp_->context->ops.post_send(qp_, sendWr, &sendWrBad); + if (rc != 0) { + return folly::makeUnexpected(Error(rc)); + } + return folly::unit; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvVirtualCq.cc b/comms/ctran/ibverbx/IbvVirtualCq.cc new file mode 100644 index 00000000..1eddec9e --- /dev/null +++ b/comms/ctran/ibverbx/IbvVirtualCq.cc @@ -0,0 +1,79 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +#include "comms/ctran/ibverbx/IbvVirtualCq.h" + +namespace ibverbx { + +/*** IbvVirtualCq ***/ + +IbvVirtualCq::IbvVirtualCq(IbvCq&& physicalCq, int maxCqe) + : physicalCq_(std::move(physicalCq)), maxCqe_(maxCqe) { + virtualCqNum_ = + nextVirtualCqNum_.fetch_add(1); // Assign unique virtual CQ number + + // Register the virtual CQ with Coordinator + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualCq construction!"; + coordinator->registerVirtualCq(virtualCqNum_, this); +} + +IbvVirtualCq::IbvVirtualCq(IbvVirtualCq&& other) noexcept { + physicalCq_ = std::move(other.physicalCq_); + pendingSendVirtualWcQue_ = std::move(other.pendingSendVirtualWcQue_); + pendingRecvVirtualWcQue_ = std::move(other.pendingRecvVirtualWcQue_); + maxCqe_ = other.maxCqe_; + virtualWrIdToVirtualWc_ = std::move(other.virtualWrIdToVirtualWc_); + virtualCqNum_ = other.virtualCqNum_; + + // Update coordinator pointer mapping for this virtual CQ after move + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualCq move construction!"; + coordinator->updateVirtualCqPointer(virtualCqNum_, this); +} + +IbvVirtualCq& IbvVirtualCq::operator=(IbvVirtualCq&& other) noexcept { + if (this != &other) { + physicalCq_ = std::move(other.physicalCq_); + pendingSendVirtualWcQue_ = std::move(other.pendingSendVirtualWcQue_); + pendingRecvVirtualWcQue_ = std::move(other.pendingRecvVirtualWcQue_); + maxCqe_ = other.maxCqe_; + virtualWrIdToVirtualWc_ = std::move(other.virtualWrIdToVirtualWc_); + virtualCqNum_ = other.virtualCqNum_; + + // Update coordinator pointer mapping for this virtual CQ after move + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualCq move construction!"; + coordinator->updateVirtualCqPointer(virtualCqNum_, this); + } + return *this; +} + +IbvCq& IbvVirtualCq::getPhysicalCqRef() { + return physicalCq_; +} + +uint32_t IbvVirtualCq::getVirtualCqNum() const { + return virtualCqNum_; +} + +void IbvVirtualCq::enqueSendCq(VirtualWc virtualWc) { + pendingSendVirtualWcQue_.push_back(std::move(virtualWc)); +} + +void IbvVirtualCq::enqueRecvCq(VirtualWc virtualWc) { + pendingRecvVirtualWcQue_.push_back(std::move(virtualWc)); +} + +IbvVirtualCq::~IbvVirtualCq() { + // Always call unregister - the coordinator will check if the pointer matches + // and do nothing if the object was moved + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualCq destruction!"; + coordinator->unregisterVirtualCq(virtualCqNum_, this); +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvVirtualCq.h b/comms/ctran/ibverbx/IbvVirtualCq.h new file mode 100644 index 00000000..bbbf3058 --- /dev/null +++ b/comms/ctran/ibverbx/IbvVirtualCq.h @@ -0,0 +1,312 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "comms/ctran/ibverbx/Coordinator.h" +#include "comms/ctran/ibverbx/IbvCommon.h" +#include "comms/ctran/ibverbx/IbvCq.h" +#include "comms/ctran/ibverbx/IbvVirtualQp.h" +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +struct VirtualWc { + VirtualWc() = default; + ~VirtualWc() = default; + + struct ibv_wc wc{}; + int expectedMsgCnt{0}; + int remainingMsgCnt{0}; + bool sendExtraNotifyImm{ + false}; // Whether to expect an extra notify IMM + // message to be sent for the current virtualWc +}; + +// Ibv Virtual Completion Queue (CQ): Provides a virtual CQ abstraction for the +// user. When the user calls IbvVirtualQp::postSend() or +// IbvVirtualQp::postRecv(), they can track the completion of messages posted on +// the Virtual QP through this virtual CQ. +class IbvVirtualCq { + public: + IbvVirtualCq(IbvCq&& cq, int maxCqe); + ~IbvVirtualCq(); + + // disable copy constructor + IbvVirtualCq(const IbvVirtualCq&) = delete; + IbvVirtualCq& operator=(const IbvVirtualCq&) = delete; + + // move constructor + IbvVirtualCq(IbvVirtualCq&& other) noexcept; + IbvVirtualCq& operator=(IbvVirtualCq&& other) noexcept; + + inline folly::Expected, Error> pollCq(int numEntries); + + IbvCq& getPhysicalCqRef(); + uint32_t getVirtualCqNum() const; + + void enqueSendCq(VirtualWc virtualWc); + void enqueRecvCq(VirtualWc virtualWc); + + inline void processRequest(VirtualCqRequest&& request); + + private: + friend class IbvPd; + friend class IbvVirtualQp; + + inline static std::atomic nextVirtualCqNum_{ + 0}; // Static counter for assigning unique virtual CQ numbers + uint32_t virtualCqNum_{ + 0}; // The unique virtual CQ number assigned to instance of IbvVirtualCq + + IbvCq physicalCq_; + int maxCqe_{0}; + std::deque pendingSendVirtualWcQue_; + std::deque pendingRecvVirtualWcQue_; + inline void updateVirtualWcFromPhysicalWc( + const ibv_wc& physicalWc, + VirtualWc* virtualWc); + std::unordered_map virtualWrIdToVirtualWc_; + + // Helper function for IbvVirtualCq::pollCq. + // Continuously polls the underlying physical Completion Queue (CQ) in a loop, + // retrieving all available Completion Queue Entries (CQEs) until none remain. + // For each physical CQE polled, the corresponding virtual CQE entries in the + // virtual CQ are also updated. This function ensures that all ready physical + // CQEs are polled, processed, and reflected in the virtual CQ state. + inline folly::Expected loopPollPhysicalCqUntilEmpty(); + + // Helper function for IbvVirtualCq::pollCq. + // Continuously polls the underlying virtual Completion Queues (CQs) in a + // loop. The function collects up to numEntries virtual Completion Queue + // Entries (CQEs), or stops early if there are no more virtual CQEs available + // to poll. Returns a vector containing the polled virtual CQEs. 
+ inline std::vector loopPollVirtualCqUntil(int numEntries); +}; + +inline void Coordinator::submitRequestToVirtualCq(VirtualCqRequest&& request) { + if (request.type == RequestType::SEND) { + auto virtualCq = getVirtualSendCq(request.virtualQpNum); + virtualCq->processRequest(std::move(request)); + } else { + auto virtualCq = getVirtualRecvCq(request.virtualQpNum); + virtualCq->processRequest(std::move(request)); + } +} + +// IbvVirtualCq inline functions +inline folly::Expected, Error> IbvVirtualCq::pollCq( + int numEntries) { + auto maybeLoopPollPhysicalCq = loopPollPhysicalCqUntilEmpty(); + if (maybeLoopPollPhysicalCq.hasError()) { + return folly::makeUnexpected(maybeLoopPollPhysicalCq.error()); + } + + return loopPollVirtualCqUntil(numEntries); +} + +inline folly::Expected +IbvVirtualCq::loopPollPhysicalCqUntilEmpty() { + // Poll from physical CQ one by one and process immediately + while (true) { + // Poll one completion at a time + auto maybePhysicalWcsVector = physicalCq_.pollCq(1); + if (maybePhysicalWcsVector.hasError()) { + return folly::makeUnexpected(maybePhysicalWcsVector.error()); + } + + // If no completions available, break the loop + if (maybePhysicalWcsVector->empty()) { + break; + } + + // Process the single completion immediately + const auto& physicalWc = maybePhysicalWcsVector->front(); + + if (physicalWc.opcode == IBV_WC_RECV || + physicalWc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + VirtualQpRequest request = { + .type = RequestType::RECV, + .wrId = physicalWc.wr_id, + .physicalQpNum = physicalWc.qp_num}; + if (physicalWc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + request.immData = physicalWc.imm_data; + } + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) << "Coordinator should not be nullptr during pollCq!"; + auto response = coordinator->submitRequestToVirtualQp(std::move(request)); + if (response.hasError()) { + return folly::makeUnexpected(response.error()); + } + + if (response->useDqplb) { + int processedCount = 0; + for (int i = 0; i < pendingRecvVirtualWcQue_.size() && + processedCount < response->notifyCount; + i++) { + if (pendingRecvVirtualWcQue_.at(i).remainingMsgCnt != 0) { + pendingRecvVirtualWcQue_.at(i).remainingMsgCnt = 0; + processedCount++; + } + } + } else { + auto virtualWc = virtualWrIdToVirtualWc_.at(response->virtualWrId); + virtualWc->remainingMsgCnt--; + updateVirtualWcFromPhysicalWc(physicalWc, virtualWc); + } + } else { + // Except for the above two conditions, all other conditions indicate a + // send message, and we should poll from send queue + VirtualQpRequest request = { + .type = RequestType::SEND, + .wrId = physicalWc.wr_id, + .physicalQpNum = physicalWc.qp_num}; + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) << "Coordinator should not be nullptr during pollCq!"; + auto response = coordinator->submitRequestToVirtualQp(std::move(request)); + if (response.hasError()) { + return folly::makeUnexpected(response.error()); + } + + auto virtualWc = virtualWrIdToVirtualWc_.at(response->virtualWrId); + virtualWc->remainingMsgCnt--; + updateVirtualWcFromPhysicalWc(physicalWc, virtualWc); + if (virtualWc->remainingMsgCnt == 1 && virtualWc->sendExtraNotifyImm) { + VirtualQpRequest request = { + .type = RequestType::SEND_NOTIFY, + .wrId = response->virtualWrId, + .physicalQpNum = physicalWc.qp_num}; + + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during pollCq!"; + auto response = + 
coordinator->submitRequestToVirtualQp(std::move(request)); + if (response.hasError()) { + return folly::makeUnexpected(response.error()); + } + } + } + } + + return folly::unit; +} + +inline std::vector IbvVirtualCq::loopPollVirtualCqUntil( + int numEntries) { + std::vector wcs; + wcs.reserve(numEntries); + bool virtualSendCqPollComplete = false; + bool virtualRecvCqPollComplete = false; + while (wcs.size() < static_cast(numEntries) && + (!virtualSendCqPollComplete || !virtualRecvCqPollComplete)) { + if (!virtualSendCqPollComplete) { + if (pendingSendVirtualWcQue_.empty() || + pendingSendVirtualWcQue_.front().remainingMsgCnt > 0) { + virtualSendCqPollComplete = true; + } else { + auto vSendCqHead = pendingSendVirtualWcQue_.front(); + virtualWrIdToVirtualWc_.erase(vSendCqHead.wc.wr_id); + wcs.push_back(std::move(vSendCqHead.wc)); + pendingSendVirtualWcQue_.pop_front(); + } + } + + if (!virtualRecvCqPollComplete) { + if (pendingRecvVirtualWcQue_.empty() || + pendingRecvVirtualWcQue_.front().remainingMsgCnt > 0) { + virtualRecvCqPollComplete = true; + } else { + auto vRecvCqHead = pendingRecvVirtualWcQue_.front(); + virtualWrIdToVirtualWc_.erase(vRecvCqHead.wc.wr_id); + wcs.push_back(std::move(vRecvCqHead.wc)); + pendingRecvVirtualWcQue_.pop_front(); + } + } + } + + return wcs; +} + +inline void IbvVirtualCq::updateVirtualWcFromPhysicalWc( + const ibv_wc& physicalWc, + VirtualWc* virtualWc) { + // Updates the vWc status field based on the statuses of all pWc instances. + // If all physicalWc statuses indicate success, returns success. + // If any of the physicalWc statuses indicate an error, return the first + // encountered error code. + // Additionally, log all error statuses for debug purposes. + if (physicalWc.status != IBV_WC_SUCCESS) { + if (virtualWc->wc.status == IBV_WC_SUCCESS) { + virtualWc->wc.status = physicalWc.status; + } + + // Log the error + XLOGF( + ERR, + "Physical WC error: status={}, vendor_err={}, qp_num={}, wr_id={}", + physicalWc.status, + physicalWc.vendor_err, + physicalWc.qp_num, + physicalWc.wr_id); + } + + // Update the OP code in virtualWc. Note that for the same user message, the + // opcode must remain consistent, because all sub-messages within that user + // message will be postSend using the same opcode. + virtualWc->wc.opcode = physicalWc.opcode; + + // Update the vendor error in virtualWc. For now, assume that all pWc + // instances will report the same vendor_error across all sub-messages + // within a single user message. 
+ virtualWc->wc.vendor_err = physicalWc.vendor_err; + + virtualWc->wc.src_qp = physicalWc.src_qp; + virtualWc->wc.byte_len += physicalWc.byte_len; + virtualWc->wc.imm_data = physicalWc.imm_data; + virtualWc->wc.wc_flags = physicalWc.wc_flags; + virtualWc->wc.pkey_index = physicalWc.pkey_index; + virtualWc->wc.slid = physicalWc.slid; + virtualWc->wc.sl = physicalWc.sl; + virtualWc->wc.dlid_path_bits = physicalWc.dlid_path_bits; +} + +inline void IbvVirtualCq::processRequest(VirtualCqRequest&& request) { + VirtualWc* virtualWcPtr = nullptr; + uint64_t wrId; + if (request.type == RequestType::SEND) { + wrId = request.sendWr->wr_id; + if (request.sendWr->send_flags & IBV_SEND_SIGNALED || + request.sendWr->opcode == IBV_WR_RDMA_WRITE_WITH_IMM) { + VirtualWc virtualWc{}; + virtualWc.wc.wr_id = request.sendWr->wr_id; + virtualWc.wc.qp_num = request.virtualQpNum; + virtualWc.wc.status = IBV_WC_SUCCESS; + virtualWc.wc.byte_len = 0; + virtualWc.expectedMsgCnt = request.expectedMsgCnt; + virtualWc.remainingMsgCnt = request.expectedMsgCnt; + virtualWc.sendExtraNotifyImm = request.sendExtraNotifyImm; + pendingSendVirtualWcQue_.push_back(std::move(virtualWc)); + virtualWcPtr = &pendingSendVirtualWcQue_.back(); + } + } else { + wrId = request.recvWr->wr_id; + VirtualWc virtualWc{}; + virtualWc.wc.wr_id = request.recvWr->wr_id; + virtualWc.wc.qp_num = request.virtualQpNum; + virtualWc.wc.status = IBV_WC_SUCCESS; + virtualWc.wc.byte_len = 0; + virtualWc.expectedMsgCnt = request.expectedMsgCnt; + virtualWc.remainingMsgCnt = request.expectedMsgCnt; + pendingRecvVirtualWcQue_.push_back(std::move(virtualWc)); + virtualWcPtr = &pendingRecvVirtualWcQue_.back(); + } + virtualWrIdToVirtualWc_[wrId] = virtualWcPtr; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvVirtualQp.cc b/comms/ctran/ibverbx/IbvVirtualQp.cc new file mode 100644 index 00000000..32da6a30 --- /dev/null +++ b/comms/ctran/ibverbx/IbvVirtualQp.cc @@ -0,0 +1,265 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
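For orientation, a minimal caller-side sketch of the two-phase pollCq defined above: loopPollPhysicalCqUntilEmpty first drains the physical CQ and routes each physical completion through the coordinator, then loopPollVirtualCqUntil pops fully completed virtual work completions in posting order. The helper below is illustrative only; it assumes pollCq returns folly::Expected<std::vector<ibv_wc>, Error> (template arguments are elided in the listing above) and that the virtual CQ was created elsewhere.

#include "comms/ctran/ibverbx/IbvVirtualCq.h"

using namespace ibverbx;

// Hedged sketch: harvest up to `budget` virtual completions from a virtual CQ.
int drainVirtualCompletions(IbvVirtualCq* virtualCq, int budget) {
  auto maybeWcs = virtualCq->pollCq(budget);
  if (maybeWcs.hasError()) {
    return -1;  // ibverbx::Error carries an errno plus a message
  }
  // Each returned wc is a *virtual* completion: wc.wr_id is the wr_id the user
  // posted, wc.byte_len is the full user-message length accumulated across
  // sub-messages, and wc.status carries the first physical error seen, if any.
  return static_cast<int>(maybeWcs->size());
}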
+ +#include "comms/ctran/ibverbx/IbvVirtualQp.h" + +#include +#include "comms/ctran/ibverbx/IbvVirtualCq.h" +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +/*** IbvVirtualQp ***/ + +IbvVirtualQp::IbvVirtualQp( + std::vector&& qps, + IbvQp&& notifyQp, + IbvVirtualCq* sendCq, + IbvVirtualCq* recvCq, + int maxMsgCntPerQp, + int maxMsgSize, + LoadBalancingScheme loadBalancingScheme) + : physicalQps_(std::move(qps)), + maxMsgCntPerQp_(maxMsgCntPerQp), + maxMsgSize_(maxMsgSize), + loadBalancingScheme_(loadBalancingScheme), + notifyQp_(std::move(notifyQp)) { + virtualQpNum_ = + nextVirtualQpNum_.fetch_add(1); // Assign unique virtual QP number + + for (int i = 0; i < physicalQps_.size(); i++) { + qpNumToIdx_[physicalQps_.at(i).qp()->qp_num] = i; + } + + // Register the virtual QP and all its mappings with the coordinator + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualQp construction!"; + + // Use the consolidated registration API + coordinator->registerVirtualQpWithVirtualCqMappings( + this, sendCq->getVirtualCqNum(), recvCq->getVirtualCqNum()); +} + +size_t IbvVirtualQp::getTotalQps() const { + return physicalQps_.size(); +} + +const std::vector& IbvVirtualQp::getQpsRef() const { + return physicalQps_; +} + +std::vector& IbvVirtualQp::getQpsRef() { + return physicalQps_; +} + +const IbvQp& IbvVirtualQp::getNotifyQpRef() const { + return notifyQp_; +} + +uint32_t IbvVirtualQp::getVirtualQpNum() const { + return virtualQpNum_; +} + +IbvVirtualQp::IbvVirtualQp(IbvVirtualQp&& other) noexcept + : pendingSendVirtualWrQue_(std::move(other.pendingSendVirtualWrQue_)), + pendingRecvVirtualWrQue_(std::move(other.pendingRecvVirtualWrQue_)), + virtualQpNum_(std::move(other.virtualQpNum_)), + physicalQps_(std::move(other.physicalQps_)), + qpNumToIdx_(std::move(other.qpNumToIdx_)), + nextSendPhysicalQpIdx_(std::move(other.nextSendPhysicalQpIdx_)), + nextRecvPhysicalQpIdx_(std::move(other.nextRecvPhysicalQpIdx_)), + maxMsgCntPerQp_(std::move(other.maxMsgCntPerQp_)), + maxMsgSize_(std::move(other.maxMsgSize_)), + nextPhysicalWrId_(std::move(other.nextPhysicalWrId_)), + loadBalancingScheme_(std::move(other.loadBalancingScheme_)), + pendingSendNotifyVirtualWrQue_( + std::move(other.pendingSendNotifyVirtualWrQue_)), + notifyQp_(std::move(other.notifyQp_)), + dqplbSeqTracker(std::move(other.dqplbSeqTracker)), + dqplbReceiverInitialized_(std::move(other.dqplbReceiverInitialized_)) { + // Update coordinator pointer mapping for this virtual QP after move + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualQp move construction!"; + coordinator->updateVirtualQpPointer(virtualQpNum_, this); +} + +IbvVirtualQp& IbvVirtualQp::operator=(IbvVirtualQp&& other) noexcept { + if (this != &other) { + physicalQps_ = std::move(other.physicalQps_); + notifyQp_ = std::move(other.notifyQp_); + nextSendPhysicalQpIdx_ = std::move(other.nextSendPhysicalQpIdx_); + nextRecvPhysicalQpIdx_ = std::move(other.nextRecvPhysicalQpIdx_); + qpNumToIdx_ = std::move(other.qpNumToIdx_); + maxMsgCntPerQp_ = std::move(other.maxMsgCntPerQp_); + maxMsgSize_ = std::move(other.maxMsgSize_); + loadBalancingScheme_ = std::move(other.loadBalancingScheme_); + pendingSendVirtualWrQue_ = std::move(other.pendingSendVirtualWrQue_); + pendingRecvVirtualWrQue_ = std::move(other.pendingRecvVirtualWrQue_); + virtualQpNum_ = std::move(other.virtualQpNum_); + nextPhysicalWrId_ = 
std::move(other.nextPhysicalWrId_); + pendingSendNotifyVirtualWrQue_ = + std::move(other.pendingSendNotifyVirtualWrQue_); + dqplbSeqTracker = std::move(other.dqplbSeqTracker); + dqplbReceiverInitialized_ = std::move(other.dqplbReceiverInitialized_); + + // Update coordinator pointer mapping for this virtual QP after move + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualQp move construction!"; + coordinator->updateVirtualQpPointer(virtualQpNum_, this); + } + return *this; +} + +IbvVirtualQp::~IbvVirtualQp() { + // Always call unregister - the coordinator will check if the pointer matches + // and do nothing if the object was moved + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) + << "Coordinator should not be nullptr during IbvVirtualQp destruction!"; + coordinator->unregisterVirtualQp(virtualQpNum_, this); +} + +folly::Expected IbvVirtualQp::modifyVirtualQp( + ibv_qp_attr* attr, + int attrMask, + const IbvVirtualQpBusinessCard& businessCard) { + // If businessCard is not empty, use it to modify QPs with specific + // dest_qp_num values + if (!businessCard.qpNums_.empty()) { + // Make sure the businessCard has the same number of QPs as physicalQps_ + if (businessCard.qpNums_.size() != physicalQps_.size()) { + return folly::makeUnexpected(Error( + EINVAL, "BusinessCard QP count doesn't match physical QP count")); + } + + // Modify each QP with its corresponding dest_qp_num from the businessCard + for (auto i = 0; i < physicalQps_.size(); i++) { + attr->dest_qp_num = businessCard.qpNums_.at(i); + auto maybeModifyQp = physicalQps_.at(i).modifyQp(attr, attrMask); + if (maybeModifyQp.hasError()) { + return folly::makeUnexpected(maybeModifyQp.error()); + } + } + attr->dest_qp_num = businessCard.notifyQpNum_; + auto maybeModifyQp = notifyQp_.modifyQp(attr, attrMask); + if (maybeModifyQp.hasError()) { + return folly::makeUnexpected(maybeModifyQp.error()); + } + } else { + // If no businessCard provided, modify all QPs with the same attributes + for (auto& qp : physicalQps_) { + auto maybeModifyQp = qp.modifyQp(attr, attrMask); + if (maybeModifyQp.hasError()) { + return folly::makeUnexpected(maybeModifyQp.error()); + } + } + auto maybeModifyQp = notifyQp_.modifyQp(attr, attrMask); + if (maybeModifyQp.hasError()) { + return folly::makeUnexpected(maybeModifyQp.error()); + } + } + return folly::unit; +} + +IbvVirtualQpBusinessCard IbvVirtualQp::getVirtualQpBusinessCard() const { + std::vector qpNums; + qpNums.reserve(physicalQps_.size()); + for (auto& qp : physicalQps_) { + qpNums.push_back(qp.qp()->qp_num); + } + return IbvVirtualQpBusinessCard(std::move(qpNums), notifyQp_.qp()->qp_num); +} + +LoadBalancingScheme IbvVirtualQp::getLoadBalancingScheme() const { + return loadBalancingScheme_; +} + +/*** IbvVirtualQpBusinessCard ***/ + +IbvVirtualQpBusinessCard::IbvVirtualQpBusinessCard( + std::vector qpNums, + uint32_t notifyQpNum) + : qpNums_(std::move(qpNums)), notifyQpNum_(notifyQpNum) {} + +folly::dynamic IbvVirtualQpBusinessCard::toDynamic() const { + folly::dynamic obj = folly::dynamic::object; + folly::dynamic qpNumsArray = folly::dynamic::array; + + // Use fixed-width string formatting to ensure consistent size + // All uint32_t values will be formatted as 10-digit zero-padded strings + for (const auto& qpNum : qpNums_) { + std::string paddedQpNum = fmt::format("{:010d}", qpNum); + qpNumsArray.push_back(paddedQpNum); + } + + obj["qpNums"] = std::move(qpNumsArray); + obj["notifyQpNum"] = 
fmt::format("{:010d}", notifyQpNum_); + return obj; +} + +folly::Expected +IbvVirtualQpBusinessCard::fromDynamic(const folly::dynamic& obj) { + std::vector qpNums; + + if (obj.count("qpNums") > 0 && obj["qpNums"].isArray()) { + const auto& qpNumsArray = obj["qpNums"]; + qpNums.reserve(qpNumsArray.size()); + + for (const auto& qpNum : qpNumsArray) { + CHECK(qpNum.isString()) << "qp num is not string!"; + try { + uint32_t qpNumValue = + static_cast(std::stoul(qpNum.asString())); + qpNums.push_back(qpNumValue); + } catch (const std::exception& e) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "Invalid QP number string format: {}. Exception: {}", + qpNum.asString(), + e.what()))); + } + } + } else { + return folly::makeUnexpected( + Error(EINVAL, "Invalid qpNums array received from remote side")); + } + + uint32_t notifyQpNum = 0; // Default value for backwards compatibility + if (obj.count("notifyQpNum") > 0 && obj["notifyQpNum"].isString()) { + try { + notifyQpNum = + static_cast(std::stoul(obj["notifyQpNum"].asString())); + } catch (const std::exception& e) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "Invalid notifyQpNum string format: {}. Exception: {}", + obj["notifyQpNum"].asString(), + e.what()))); + } + } + + return IbvVirtualQpBusinessCard(std::move(qpNums), notifyQpNum); +} + +std::string IbvVirtualQpBusinessCard::serialize() const { + return folly::toJson(toDynamic()); +} + +folly::Expected +IbvVirtualQpBusinessCard::deserialize(const std::string& jsonStr) { + try { + folly::dynamic obj = folly::parseJson(jsonStr); + return fromDynamic(obj); + } catch (const std::exception& e) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "Failed to parse JSON in IbvVirtualQpBusinessCard Deserialize. Exception: {}", + e.what()))); + } +} +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvVirtualQp.h b/comms/ctran/ibverbx/IbvVirtualQp.h new file mode 100644 index 00000000..fd06b5ca --- /dev/null +++ b/comms/ctran/ibverbx/IbvVirtualQp.h @@ -0,0 +1,694 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
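To make the business-card plumbing above concrete, here is a hedged sketch of the intended connection bring-up. The out-of-band exchange helper and the RTR attribute details are assumptions for illustration, not part of this patch.

#include <string>
#include "comms/ctran/ibverbx/IbvVirtualQp.h"

using namespace ibverbx;

// Hypothetical out-of-band transport (TCP bootstrap, MPI, etc.), provided by
// the application.
std::string exchangeWithPeer(const std::string& localJson);

bool connectToPeer(IbvVirtualQp& localVirtualQp) {
  std::string localJson =
      localVirtualQp.getVirtualQpBusinessCard().serialize();
  std::string remoteJson = exchangeWithPeer(localJson);

  auto maybeRemoteCard = IbvVirtualQpBusinessCard::deserialize(remoteJson);
  if (maybeRemoteCard.hasError()) {
    return false;  // malformed JSON or qpNums array
  }

  ibv_qp_attr attr{};
  attr.qp_state = IBV_QPS_RTR;
  // ... fill path, MTU and PSN fields exactly as for a plain ibverbs QP ...
  int attrMask = IBV_QP_STATE;  // plus the usual RTR mask bits, elided here

  // With a business card, modifyVirtualQp fills attr.dest_qp_num per physical
  // QP, so local QP i connects to the peer's qpNums_[i], and the two notify
  // QPs are paired the same way.
  auto maybeRtr =
      localVirtualQp.modifyVirtualQp(&attr, attrMask, *maybeRemoteCard);
  return !maybeRtr.hasError();
}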
+ +#pragma once + +#include +#include +#include + +#include "comms/ctran/ibverbx/Coordinator.h" +#include "comms/ctran/ibverbx/DqplbSeqTracker.h" +#include "comms/ctran/ibverbx/IbvCommon.h" +#include "comms/ctran/ibverbx/IbvQp.h" +#include "comms/ctran/ibverbx/IbvVirtualCq.h" +#include "comms/ctran/ibverbx/IbvVirtualWr.h" +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +class IbvVirtualCq; + +// IbvVirtualQpBusinessCard +struct IbvVirtualQpBusinessCard { + explicit IbvVirtualQpBusinessCard( + std::vector qpNums, + uint32_t notifyQpNum = 0); + IbvVirtualQpBusinessCard() = default; + ~IbvVirtualQpBusinessCard() = default; + + // Default copy constructor and assignment operator + IbvVirtualQpBusinessCard(const IbvVirtualQpBusinessCard& other) = default; + IbvVirtualQpBusinessCard& operator=(const IbvVirtualQpBusinessCard& other) = + default; + + // Default move constructor and assignment operator + IbvVirtualQpBusinessCard(IbvVirtualQpBusinessCard&& other) = default; + IbvVirtualQpBusinessCard& operator=(IbvVirtualQpBusinessCard&& other) = + default; + + // Convert to/from folly::dynamic for serialization + folly::dynamic toDynamic() const; + static folly::Expected fromDynamic( + const folly::dynamic& obj); + + // JSON serialization methods + std::string serialize() const; + static folly::Expected deserialize( + const std::string& jsonStr); + + // The qpNums_ vector is ordered: the ith QP in qpNums_ will be + // connected to the ith QP in the remote side's qpNums_ vector. + std::vector qpNums_; + uint32_t notifyQpNum_{0}; +}; + +// Ibv Virtual Queue Pair +class IbvVirtualQp { + public: + ~IbvVirtualQp(); + + // disable copy constructor + IbvVirtualQp(const IbvVirtualQp&) = delete; + IbvVirtualQp& operator=(const IbvVirtualQp&) = delete; + + // move constructor + IbvVirtualQp(IbvVirtualQp&& other) noexcept; + IbvVirtualQp& operator=(IbvVirtualQp&& other) noexcept; + + size_t getTotalQps() const; + const std::vector& getQpsRef() const; + std::vector& getQpsRef(); + const IbvQp& getNotifyQpRef() const; + uint32_t getVirtualQpNum() const; + // If businessCard is not provided, all physical QPs will be updated with the + // universal attributes specified in attr. This is typically used for changing + // the state to INIT or RTS. + // If businessCard is provided, attr.qp_num for each physical QP will be set + // individually to the corresponding qpNum stored in qpNums_ within + // businessCard. This is typically used for changing the state to RTR. 
+ folly::Expected modifyVirtualQp( + ibv_qp_attr* attr, + int attrMask, + const IbvVirtualQpBusinessCard& businessCard = + IbvVirtualQpBusinessCard()); + IbvVirtualQpBusinessCard getVirtualQpBusinessCard() const; + LoadBalancingScheme getLoadBalancingScheme() const; + + inline folly::Expected postSend( + ibv_send_wr* sendWr, + ibv_send_wr* sendWrBad); + + inline folly::Expected postRecv( + ibv_recv_wr* ibvRecvWr, + ibv_recv_wr* badIbvRecvWr); + + inline int findAvailableSendQp(); + inline int findAvailableRecvQp(); + + inline folly::Expected processRequest( + VirtualQpRequest&& request); + + private: +#ifdef IBVERBX_TEST_FRIENDS + IBVERBX_TEST_FRIENDS +#endif + + // updatePhysicalSendWrFromVirtualSendWr is a helper function to update + // physical send work request (ibv_send_wr) from virtual send work request + inline void updatePhysicalSendWrFromVirtualSendWr( + VirtualSendWr& virtualSendWr, + ibv_send_wr* sendWr, + ibv_sge* sendSg); + + friend class IbvPd; + friend class IbvVirtualCq; + + std::deque pendingSendVirtualWrQue_; + std::deque pendingRecvVirtualWrQue_; + + inline static std::atomic nextVirtualQpNum_{ + 0}; // Static counter for assigning unique virtual QP numbers + uint32_t virtualQpNum_{0}; // The unique virtual QP number assigned to + // instance of IbvVirtualQp. + + std::vector physicalQps_; + std::unordered_map qpNumToIdx_; + + int nextSendPhysicalQpIdx_{0}; + int nextRecvPhysicalQpIdx_{0}; + + int maxMsgCntPerQp_{ + -1}; // Maximum number of messages that can be sent on each physical QP. A + // value of -1 indicates there is no limit. + int maxMsgSize_{0}; + + uint64_t nextPhysicalWrId_{0}; // ID of the next physical work request to + // be posted on the physical QP + + LoadBalancingScheme loadBalancingScheme_{ + LoadBalancingScheme::SPRAY}; // Load balancing scheme for this virtual QP + + // Spray mode specific fields + std::deque pendingSendNotifyVirtualWrQue_; + IbvQp notifyQp_; + + // DQPLB mode specific fields and functions + DqplbSeqTracker dqplbSeqTracker; + bool dqplbReceiverInitialized_{ + false}; // flag to indicate if dqplb receiver is initialized + inline folly::Expected initializeDqplbReceiver(); + + IbvVirtualQp( + std::vector&& qps, + IbvQp&& notifyQp, + IbvVirtualCq* sendCq, + IbvVirtualCq* recvCq, + int maxMsgCntPerQp = kIbMaxMsgCntPerQp, + int maxMsgSize = kIbMaxMsgSizeByte, + LoadBalancingScheme loadBalancingScheme = LoadBalancingScheme::SPRAY); + + // mapPendingSendQueToPhysicalQp is a helper function to iterate through + // virtualSendWr in the pendingSendVirtualWrQue_, construct physical wrs and + // call postSend on physical QP. If qpIdx is provided, this function will + // postSend physicalWr on qpIdx. If qpIdx is not provided, then the function + // will find an available Qp to postSend the physical work request on. 
+ inline folly::Expected mapPendingSendQueToPhysicalQp( + int qpIdx = -1); + + // postSendNotifyImm is a helper function to send IMM notification message + // after all previous messages are sent in a large message + inline folly::Expected postSendNotifyImm(); + inline folly::Expected mapPendingRecvQueToPhysicalQp( + int qpIdx = -1); + inline folly::Expected postRecvNotifyImm(int qpIdx = -1); +}; + +inline folly::Expected +Coordinator::submitRequestToVirtualQp(VirtualQpRequest&& request) { + auto virtualQp = getVirtualQpByPhysicalQpNum(request.physicalQpNum); + return virtualQp->processRequest(std::move(request)); +} + +// IbvVirtualQp inline functions +inline folly::Expected +IbvVirtualQp::mapPendingSendQueToPhysicalQp(int qpIdx) { + while (!pendingSendVirtualWrQue_.empty()) { + // Get the front of vSendQ_ and obtain the send information + VirtualSendWr& virtualSendWr = pendingSendVirtualWrQue_.front(); + + // For Send opcodes related to RDMA_WRITE operations, use user selected load + // balancing scheme specified in loadBalancingScheme_. For all other + // opcodes, default to using physical QP 0. + auto availableQpIdx = -1; + if (virtualSendWr.wr.opcode == IBV_WR_RDMA_WRITE || + virtualSendWr.wr.opcode == IBV_WR_RDMA_WRITE_WITH_IMM || + virtualSendWr.wr.opcode == IBV_WR_RDMA_READ) { + // Find an available Qp to send + availableQpIdx = qpIdx == -1 ? findAvailableSendQp() : qpIdx; + qpIdx = -1; // If qpIdx is provided, it indicates that one slot has been + // freed for the corresponding qpIdx. After using this slot, + // reset qpIdx to -1. + } else if ( + physicalQps_.at(0).physicalSendWrStatus_.size() < maxMsgCntPerQp_) { + availableQpIdx = 0; + } + if (availableQpIdx == -1) { + break; + } + + // Update the physical send work request with virtual one + ibv_send_wr sendWr_{}; + ibv_sge sendSg_{}; + updatePhysicalSendWrFromVirtualSendWr(virtualSendWr, &sendWr_, &sendSg_); + + // Call ibv_post_send to send the message + ibv_send_wr badSendWr_{}; + auto maybeSend = + physicalQps_.at(availableQpIdx).postSend(&sendWr_, &badSendWr_); + if (maybeSend.hasError()) { + return folly::makeUnexpected(maybeSend.error()); + } + + // Enqueue the send information to physicalQps_ + physicalQps_.at(availableQpIdx) + .physicalSendWrStatus_.emplace_back( + sendWr_.wr_id, virtualSendWr.wr.wr_id); + + // Decide if need to deque the front of vSendQ_ + virtualSendWr.offset += sendWr_.sg_list->length; + virtualSendWr.remainingMsgCnt--; + if (virtualSendWr.remainingMsgCnt == 0) { + pendingSendVirtualWrQue_.pop_front(); + } else if ( + virtualSendWr.remainingMsgCnt == 1 && + virtualSendWr.sendExtraNotifyImm) { + // Move front entry from pendingSendVirtualWrQue_ to + // pendingSendNotifyVirtualWrQue_ + pendingSendNotifyVirtualWrQue_.push_back( + std::move(pendingSendVirtualWrQue_.front())); + pendingSendVirtualWrQue_.pop_front(); + } + } + return folly::unit; +} + +inline int IbvVirtualQp::findAvailableSendQp() { + // maxMsgCntPerQp_ with a value of -1 indicates there is no limit. 
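+  // With no limit we hand out QPs round-robin; otherwise scan at most
+  // physicalQps_.size() candidates, skipping any QP whose outstanding
+  // physicalSendWrStatus_ already holds maxMsgCntPerQp_ entries, and
+  // return -1 if every QP is saturated.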
+ if (maxMsgCntPerQp_ == -1) { + auto availableQpIdx = nextSendPhysicalQpIdx_; + nextSendPhysicalQpIdx_ = (nextSendPhysicalQpIdx_ + 1) % physicalQps_.size(); + return availableQpIdx; + } + + for (int i = 0; i < physicalQps_.size(); i++) { + if (physicalQps_.at(nextSendPhysicalQpIdx_).physicalSendWrStatus_.size() < + maxMsgCntPerQp_) { + auto availableQpIdx = nextSendPhysicalQpIdx_; + nextSendPhysicalQpIdx_ = + (nextSendPhysicalQpIdx_ + 1) % physicalQps_.size(); + return availableQpIdx; + } + nextSendPhysicalQpIdx_ = (nextSendPhysicalQpIdx_ + 1) % physicalQps_.size(); + } + return -1; +} + +inline folly::Expected IbvVirtualQp::postSendNotifyImm() { + auto virtualSendWr = pendingSendNotifyVirtualWrQue_.front(); + ibv_send_wr sendWr_{}; + ibv_send_wr badSendWr_{}; + ibv_sge sendSg_{}; + sendWr_.next = nullptr; + sendWr_.sg_list = &sendSg_; + sendWr_.num_sge = 0; + sendWr_.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + sendWr_.send_flags = IBV_SEND_SIGNALED; + sendWr_.wr.rdma.remote_addr = virtualSendWr.wr.wr.rdma.remote_addr; + sendWr_.wr.rdma.rkey = virtualSendWr.wr.wr.rdma.rkey; + sendWr_.imm_data = virtualSendWr.wr.imm_data; + sendWr_.wr_id = nextPhysicalWrId_++; + auto maybeSend = notifyQp_.postSend(&sendWr_, &badSendWr_); + if (maybeSend.hasError()) { + return folly::makeUnexpected(maybeSend.error()); + } + notifyQp_.physicalSendWrStatus_.emplace_back( + sendWr_.wr_id, virtualSendWr.wr.wr_id); + virtualSendWr.remainingMsgCnt = 0; + pendingSendNotifyVirtualWrQue_.pop_front(); + return folly::unit; +} + +inline void IbvVirtualQp::updatePhysicalSendWrFromVirtualSendWr( + VirtualSendWr& virtualSendWr, + ibv_send_wr* sendWr, + ibv_sge* sendSg) { + sendWr->wr_id = nextPhysicalWrId_++; + + auto lenToSend = std::min( + int(virtualSendWr.wr.sg_list->length - virtualSendWr.offset), + maxMsgSize_); + sendSg->addr = virtualSendWr.wr.sg_list->addr + virtualSendWr.offset; + sendSg->length = lenToSend; + sendSg->lkey = virtualSendWr.wr.sg_list->lkey; + sendWr->next = nullptr; + sendWr->sg_list = sendSg; + sendWr->num_sge = 1; + + // Set the opcode to the same as virtual wr, except for RDMA_WRITE_WITH_IMM, + // we'll handle the notification message separately + switch (virtualSendWr.wr.opcode) { + case IBV_WR_RDMA_WRITE: + case IBV_WR_RDMA_READ: + sendWr->opcode = virtualSendWr.wr.opcode; + sendWr->send_flags = virtualSendWr.wr.send_flags; + sendWr->wr.rdma.remote_addr = + virtualSendWr.wr.wr.rdma.remote_addr + virtualSendWr.offset; + sendWr->wr.rdma.rkey = virtualSendWr.wr.wr.rdma.rkey; + break; + case IBV_WR_RDMA_WRITE_WITH_IMM: + sendWr->opcode = (loadBalancingScheme_ == LoadBalancingScheme::SPRAY) + ? 
IBV_WR_RDMA_WRITE + : IBV_WR_RDMA_WRITE_WITH_IMM; + sendWr->send_flags = IBV_SEND_SIGNALED; + sendWr->wr.rdma.remote_addr = + virtualSendWr.wr.wr.rdma.remote_addr + virtualSendWr.offset; + sendWr->wr.rdma.rkey = virtualSendWr.wr.wr.rdma.rkey; + break; + case IBV_WR_SEND: + sendWr->opcode = virtualSendWr.wr.opcode; + sendWr->send_flags = virtualSendWr.wr.send_flags; + break; + + default: + break; + } + + if (sendWr->opcode == IBV_WR_RDMA_WRITE_WITH_IMM && + loadBalancingScheme_ == LoadBalancingScheme::DQPLB) { + sendWr->imm_data = + dqplbSeqTracker.getSendImm(virtualSendWr.remainingMsgCnt); + } +} + +inline folly::Expected IbvVirtualQp::postSend( + ibv_send_wr* sendWr, + ibv_send_wr* sendWrBad) { + // Report error if num_sge is more than 1 + if (sendWr->num_sge > 1) { + return folly::makeUnexpected(Error( + EINVAL, "In IbvVirtualQp::postSend, num_sge > 1 is not supported")); + } + + // Report error if opcode is not supported by Ibverbx virtualQp + switch (sendWr->opcode) { + case IBV_WR_SEND_WITH_IMM: + case IBV_WR_ATOMIC_CMP_AND_SWP: + case IBV_WR_ATOMIC_FETCH_AND_ADD: + return folly::makeUnexpected(Error( + EINVAL, + "In IbvVirtualQp::postSend, opcode IBV_WR_SEND_WITH_IMM, IBV_WR_ATOMIC_CMP_AND_SWP, IBV_WR_ATOMIC_FETCH_AND_ADD are not supported")); + + default: + break; + } + + // Calculate the chunk number for the current message and update sendWqe + bool sendExtraNotifyImm = + (sendWr->opcode == IBV_WR_RDMA_WRITE_WITH_IMM && + loadBalancingScheme_ == LoadBalancingScheme::SPRAY); + int expectedMsgCnt = + (sendWr->sg_list->length + maxMsgSize_ - 1) / maxMsgSize_; + if (sendExtraNotifyImm) { + expectedMsgCnt += 1; // After post send all data messages, will post send + // 1 more notification message on QP 0 + } + + // Submit request to virtualCq to enqueue VirtualWc + VirtualCqRequest request = { + .type = RequestType::SEND, + .virtualQpNum = (int)virtualQpNum_, + .expectedMsgCnt = expectedMsgCnt, + .sendWr = sendWr, + .sendExtraNotifyImm = sendExtraNotifyImm}; + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) << "Coordinator should not be nullptr during postSend!"; + coordinator->submitRequestToVirtualCq(std::move(request)); + + // Set up the send work request with the completion queue entry and enqueue + // Note: virtualWcPtr can be nullptr - this is intentional and supported + // The VirtualSendWr constructor will handle deep copying of sendWr and + // sg_list + pendingSendVirtualWrQue_.emplace_back( + *sendWr, expectedMsgCnt, expectedMsgCnt, sendExtraNotifyImm); + + // Map large messages from vSendQ_ to pQps_ + if (mapPendingSendQueToPhysicalQp().hasError()) { + *sendWrBad = *sendWr; + return folly::makeUnexpected(Error(errno)); + } + + return folly::unit; +} + +inline folly::Expected IbvVirtualQp::processRequest( + VirtualQpRequest&& request) { + VirtualQpResponse response; + // If request.physicalQpNum differs from notifyQpNum, locate the corresponding + // physical qpIdx to process this request. + auto qpIdx = request.physicalQpNum == notifyQp_.getQpNum() + ? -1 + : qpNumToIdx_.at(request.physicalQpNum); + // If qpIdx is -1, physicalQp is notifyQp; otherwise, physicalQp is the qpIdx + // entry of physicalQps_ + auto& physicalQp = qpIdx == -1 ? 
notifyQp_ : physicalQps_.at(qpIdx); + + if (request.type == RequestType::RECV) { + if (physicalQp.physicalRecvWrStatus_.empty()) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "In pollCq, after calling submit command to IbvVirtualQp, \ + physicalRecvWrStatus_ at physicalQp {} is empty!", + request.physicalQpNum))); + } + + auto& physicalRecvWrStatus = physicalQp.physicalRecvWrStatus_.front(); + + if (physicalRecvWrStatus.physicalWrId != request.wrId) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "In pollCq, after calling submit command to IbvVirtualQp, \ + physicalRecvWrStatus.physicalWrId({}) != request.wrId({})", + physicalRecvWrStatus.physicalWrId, + request.wrId))); + } + + response.virtualWrId = physicalRecvWrStatus.virtualWrId; + physicalQp.physicalRecvWrStatus_.pop_front(); + if (loadBalancingScheme_ == LoadBalancingScheme::DQPLB) { + if (postRecvNotifyImm(qpIdx).hasError()) { + return folly::makeUnexpected( + Error(errno, fmt::format("postRecvNotifyImm() failed!"))); + } + response.notifyCount = + dqplbSeqTracker.processReceivedImm(request.immData); + response.useDqplb = true; + } else if (qpIdx != -1) { + if (mapPendingRecvQueToPhysicalQp(qpIdx).hasError()) { + return folly::makeUnexpected(Error( + errno, + fmt::format("mapPendingRecvQueToPhysicalQp({}) failed!", qpIdx))); + } + } + } else if (request.type == RequestType::SEND) { + if (physicalQp.physicalSendWrStatus_.empty()) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "In pollCq, after calling submit command to IbvVirtualQp, \ + physicalSendWrStatus_ at physicalQp {} is empty!", + request.physicalQpNum))); + } + + auto physicalSendWrStatus = physicalQp.physicalSendWrStatus_.front(); + + if (physicalSendWrStatus.physicalWrId != request.wrId) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "In pollCq, after calling submit command to IbvVirtualQp, \ + physicalSendWrStatus.physicalWrId({}) != request.wrId({})", + physicalSendWrStatus.physicalWrId, + request.wrId))); + } + + response.virtualWrId = physicalSendWrStatus.virtualWrId; + physicalQp.physicalSendWrStatus_.pop_front(); + if (qpIdx != -1) { + if (mapPendingSendQueToPhysicalQp(qpIdx).hasError()) { + return folly::makeUnexpected(Error( + errno, + fmt::format("mapPendingSendQueToPhysicalQp({}) failed!", qpIdx))); + } + } + } else if (request.type == RequestType::SEND_NOTIFY) { + if (pendingSendNotifyVirtualWrQue_.empty()) { + return folly::makeUnexpected(Error( + EINVAL, + fmt::format( + "Tried to post send notify IMM message for wrId {} when pendingSendNotifyVirtualWrQue_ is empty", + request.wrId))); + } + + if (pendingSendNotifyVirtualWrQue_.front().wr.wr_id == request.wrId) { + if (postSendNotifyImm().hasError()) { + return folly::makeUnexpected( + Error(errno, fmt::format("postSendNotifyImm() failed!"))); + } + } + } + return response; +} + +// Currently, this function is only invoked to receive messages with opcode +// IBV_WR_SEND. Therefore, we restrict its usage to physical QP 0. +// Note: If Dynamic QP Load Balancing (DQPLB) or other load balancing techniques +// are required in the future, this function can be updated to support more +// advanced usage. +inline int IbvVirtualQp::findAvailableRecvQp() { + // maxMsgCntPerQp_ with a value of -1 indicates there is no limit. 
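+  // Consequently only physical QP 0 is consulted: return 0 when it has room
+  // (or when there is no limit), and -1 otherwise.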
+ auto availableQpIdx = -1; + if (maxMsgCntPerQp_ == -1 || + physicalQps_.at(0).physicalRecvWrStatus_.size() < maxMsgCntPerQp_) { + availableQpIdx = 0; + } + + return availableQpIdx; +} + +inline folly::Expected IbvVirtualQp::postRecvNotifyImm( + int qpIdx) { + auto& qp = qpIdx == -1 ? notifyQp_ : physicalQps_.at(qpIdx); + auto virtualRecvWrId = loadBalancingScheme_ == LoadBalancingScheme::SPRAY + ? pendingRecvVirtualWrQue_.front().wr.wr_id + : -1; + ibv_recv_wr recvWr_{}; + ibv_recv_wr badRecvWr_{}; + ibv_sge recvSg_{}; + recvWr_.next = nullptr; + recvWr_.sg_list = &recvSg_; + recvWr_.num_sge = 0; + recvWr_.wr_id = nextPhysicalWrId_++; + auto maybeRecv = qp.postRecv(&recvWr_, &badRecvWr_); + if (maybeRecv.hasError()) { + return folly::makeUnexpected(maybeRecv.error()); + } + qp.physicalRecvWrStatus_.emplace_back(recvWr_.wr_id, virtualRecvWrId); + + if (loadBalancingScheme_ == LoadBalancingScheme::SPRAY) { + pendingRecvVirtualWrQue_.pop_front(); + } + return folly::unit; +} + +inline folly::Expected +IbvVirtualQp::initializeDqplbReceiver() { + ibv_recv_wr recvWr_{}; + ibv_recv_wr badRecvWr_{}; + ibv_sge recvSg_{}; + recvWr_.next = nullptr; + recvWr_.sg_list = &recvSg_; + recvWr_.num_sge = 0; + for (int i = 0; i < maxMsgCntPerQp_; i++) { + for (int j = 0; j < physicalQps_.size(); j++) { + recvWr_.wr_id = nextPhysicalWrId_++; + auto maybeRecv = physicalQps_.at(j).postRecv(&recvWr_, &badRecvWr_); + if (maybeRecv.hasError()) { + return folly::makeUnexpected(maybeRecv.error()); + } + physicalQps_.at(j).physicalRecvWrStatus_.emplace_back(recvWr_.wr_id, -1); + } + } + + dqplbReceiverInitialized_ = true; + return folly::unit; +} + +inline folly::Expected +IbvVirtualQp::mapPendingRecvQueToPhysicalQp(int qpIdx) { + while (!pendingRecvVirtualWrQue_.empty()) { + VirtualRecvWr& virtualRecvWr = pendingRecvVirtualWrQue_.front(); + + if (virtualRecvWr.wr.num_sge == 0) { + auto maybeRecvNotifyImm = postRecvNotifyImm(); + if (maybeRecvNotifyImm.hasError()) { + return folly::makeUnexpected(maybeRecvNotifyImm.error()); + } + continue; + } + + // If num_sge is > 0, then the receive work request is used to receive + // messages with opcode IBV_WR_SEND. In this scenario, we restrict usage to + // physical QP 0 only. The reason behind is that, IBV_WR_SEND requires a + // strict one-to-one correspondence between send and receive WRs. If Dynamic + // QP Load Balancing (DQPLB) is applied, send and receive WRs may be posted + // to different physical QPs within the QP list. This mismatch can result in + // data being delivered to the wrong address, causing data integrity issues. + auto availableQpIdx = qpIdx != 0 ? findAvailableRecvQp() : qpIdx; + qpIdx = -1; // If qpIdx is provided, it indicates that one slot has been + // freed for the corresponding qpIdx. After using this slot, + // reset qpIdx to -1. 
+ if (availableQpIdx == -1) { + break; + } + + // Get the front of vRecvQ_ and obtain the receive information + ibv_recv_wr recvWr_{}; + ibv_recv_wr badRecvWr_{}; + ibv_sge recvSg_{}; + int lenToRecv = 0; + if (virtualRecvWr.wr.num_sge == 1) { + lenToRecv = std::min( + int(virtualRecvWr.wr.sg_list->length - virtualRecvWr.offset), + maxMsgSize_); + recvSg_.addr = virtualRecvWr.wr.sg_list->addr + virtualRecvWr.offset; + recvSg_.length = lenToRecv; + recvSg_.lkey = virtualRecvWr.wr.sg_list->lkey; + + recvWr_.sg_list = &recvSg_; + recvWr_.num_sge = 1; + } else { + recvWr_.sg_list = nullptr; + recvWr_.num_sge = 0; + } + recvWr_.wr_id = nextPhysicalWrId_++; + recvWr_.next = nullptr; + + // Call ibv_post_recv to receive the message + auto maybeRecv = + physicalQps_.at(availableQpIdx).postRecv(&recvWr_, &badRecvWr_); + if (maybeRecv.hasError()) { + return folly::makeUnexpected(maybeRecv.error()); + } + + // Enqueue the receive information to physicalQps_ + physicalQps_.at(availableQpIdx) + .physicalRecvWrStatus_.emplace_back( + recvWr_.wr_id, virtualRecvWr.wr.wr_id); + + // Decide if need to deque the front of vRecvQ_ + if (virtualRecvWr.wr.num_sge == 1) { + virtualRecvWr.offset += lenToRecv; + } + virtualRecvWr.remainingMsgCnt--; + if (virtualRecvWr.remainingMsgCnt == 0) { + pendingRecvVirtualWrQue_.pop_front(); + } + } + return folly::unit; +} + +inline folly::Expected IbvVirtualQp::postRecv( + ibv_recv_wr* recvWr, + ibv_recv_wr* recvWrBad) { + // Report error if num_sge is more than 1 + if (recvWr->num_sge > 1) { + return folly::makeUnexpected(Error(EINVAL)); + } + + int expectedMsgCnt = 1; + + if (recvWr->num_sge == 0) { // recvWr->num_sge == 0 mean it's receiving a + // IMM notification message + expectedMsgCnt = 1; + } else if (recvWr->num_sge == 1) { // Calculate the chunk number for the + // current message and update recvWqe if + // num_sge is 1 + expectedMsgCnt = (recvWr->sg_list->length + maxMsgSize_ - 1) / maxMsgSize_; + } + + // Submit request to virtualCq to enqueue VirtualWc + VirtualCqRequest request = { + .type = RequestType::RECV, + .virtualQpNum = (int)virtualQpNum_, + .expectedMsgCnt = expectedMsgCnt, + .recvWr = recvWr}; + auto coordinator = Coordinator::getCoordinator(); + CHECK(coordinator) << "Coordinator should not be nullptr during postRecv!"; + coordinator->submitRequestToVirtualCq(std::move(request)); + + // Set up the recv work request with the completion queue entry and enqueue + pendingRecvVirtualWrQue_.emplace_back( + *recvWr, expectedMsgCnt, expectedMsgCnt); + + if (loadBalancingScheme_ != LoadBalancingScheme::DQPLB) { + if (mapPendingRecvQueToPhysicalQp().hasError()) { + // For non-DQPLB modes: map messages from pendingRecvVirtualWrQue_ to + // physicalQps_. In DQPLB mode, this mapping is unnecessary because all + // receive notify IMM operations are pre-posted to the QPs before postRecv + // is called. + *recvWrBad = *recvWr; + return folly::makeUnexpected(Error(errno)); + } + } else if (dqplbReceiverInitialized_ == false) { + if (initializeDqplbReceiver().hasError()) { + *recvWrBad = *recvWr; + return folly::makeUnexpected(Error(errno)); + } + } + + return folly::unit; +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbvVirtualWr.h b/comms/ctran/ibverbx/IbvVirtualWr.h new file mode 100644 index 00000000..b174dc78 --- /dev/null +++ b/comms/ctran/ibverbx/IbvVirtualWr.h @@ -0,0 +1,83 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
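The per-WR bookkeeping in this header mirrors the chunking arithmetic used by postSend and postRecv above: a user message of length bytes becomes ceil(length / maxMsgSize) physical work requests, and a SPRAY RDMA_WRITE_WITH_IMM adds one extra notification IMM. A small illustration, with a made-up helper name and example numbers:

#include <cstdint>

constexpr int expectedMsgCnt(
    uint32_t lengthBytes,
    uint32_t maxMsgSize,
    bool sprayWriteWithImm) {
  // Ceiling division: number of maxMsgSize-sized chunks covering the message.
  int cnt = static_cast<int>((lengthBytes + maxMsgSize - 1) / maxMsgSize);
  // SPRAY RDMA_WRITE_WITH_IMM posts one trailing notification IMM as well.
  return sprayWriteWithImm ? cnt + 1 : cnt;
}

// A 10 MiB RDMA_WRITE_WITH_IMM with maxMsgSize = 1 MiB under SPRAY becomes
// 10 sprayed data chunks plus 1 notification IMM:
static_assert(expectedMsgCnt(10u << 20, 1u << 20, /*spray=*/true) == 11, "");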
+ +#pragma once + +#include +#include "comms/ctran/ibverbx/Ibvcore.h" + +namespace ibverbx { + +struct VirtualSendWr { + VirtualSendWr( + const ibv_send_wr& wr, + int expectedMsgCnt, + int remainingMsgCnt, + bool sendExtraNotifyImm) + : expectedMsgCnt(expectedMsgCnt), + remainingMsgCnt(remainingMsgCnt), + sendExtraNotifyImm(sendExtraNotifyImm) { + // Make an explicit copy of the ibv_send_wr structure + this->wr = wr; + + // Deep copy the scatter-gather list + if (wr.sg_list != nullptr && wr.num_sge > 0) { + sgList.resize(wr.num_sge); + std::copy(wr.sg_list, wr.sg_list + wr.num_sge, sgList.begin()); + // Update the copied work request to point to our own scatter-gather list + this->wr.sg_list = sgList.data(); + } else { + // Handle case where there's no scatter-gather list + this->wr.sg_list = nullptr; + this->wr.num_sge = 0; + } + } + VirtualSendWr() = default; + ~VirtualSendWr() = default; + + ibv_send_wr wr{}; // Copy of the work request being posted by the user + std::vector sgList; // Copy of the scatter-gather list + int expectedMsgCnt{0}; // Expected message count resulting from splitting a + // large user message into multiple parts + int remainingMsgCnt{0}; // Number of message segments left to transmit after + // splitting a large user messaget + int offset{ + 0}; // Address offset to be used for the next message send operation + bool sendExtraNotifyImm{false}; // Whether to send an extra notify IMM message + // for the current VirtualSendWr +}; + +struct VirtualRecvWr { + inline VirtualRecvWr( + const ibv_recv_wr& wr, + int expectedMsgCnt, + int remainingMsgCnt) + : expectedMsgCnt(expectedMsgCnt), remainingMsgCnt(remainingMsgCnt) { + // Make an explicit copy of the ibv_recv_wr structure + this->wr = wr; + + // Deep copy the scatter-gather list + if (wr.sg_list != nullptr && wr.num_sge > 0) { + sgList.resize(wr.num_sge); + std::copy(wr.sg_list, wr.sg_list + wr.num_sge, sgList.begin()); + // Update the copied work request to point to our own scatter-gather list + this->wr.sg_list = sgList.data(); + } else { + // Handle case where there's no scatter-gather list + this->wr.sg_list = nullptr; + this->wr.num_sge = 0; + } + }; + VirtualRecvWr() = default; + ~VirtualRecvWr() = default; + + ibv_recv_wr wr{}; // Copy of the work request being posted by the user + std::vector sgList; // Copy of the scatter-gather list + int expectedMsgCnt{0}; // Expected message count resulting from splitting a + // large user message into multiple parts + int remainingMsgCnt{0}; // Number of message segments left to transmit after + // splitting a large user messaget + int offset{ + 0}; // Address offset to be used for the next message send operation +}; + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/Ibverbx.cc b/comms/ctran/ibverbx/Ibverbx.cc index 7d888e6c..0365e591 100644 --- a/comms/ctran/ibverbx/Ibverbx.cc +++ b/comms/ctran/ibverbx/Ibverbx.cc @@ -1,6 +1,7 @@ // Copyright (c) Meta Platforms, Inc. and affiliates. 
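Because VirtualSendWr and VirtualRecvWr above deep-copy both the work request and its scatter-gather list, the caller's ibv_send_wr and ibv_sge only need to outlive the postSend call itself, as with plain libibverbs. A hedged caller-side sketch; buffer registration and remote-key exchange are assumed to have happened elsewhere:

#include <cstdint>
#include "comms/ctran/ibverbx/IbvVirtualQp.h"

using namespace ibverbx;

bool postOneWrite(
    IbvVirtualQp& virtualQp,
    void* localBuf,        // assumed already registered
    uint32_t lengthBytes,
    uint32_t lkey,
    uint64_t remoteAddr,   // obtained out of band
    uint32_t remoteRkey) {
  ibv_sge sge{};
  sge.addr = reinterpret_cast<uint64_t>(localBuf);
  sge.length = lengthBytes;
  sge.lkey = lkey;

  ibv_send_wr wr{};
  ibv_send_wr bad{};
  wr.wr_id = 42;                          // echoed back in the virtual ibv_wc
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wr.send_flags = IBV_SEND_SIGNALED;
  wr.sg_list = &sge;
  wr.num_sge = 1;                         // num_sge > 1 is rejected by postSend
  wr.wr.rdma.remote_addr = remoteAddr;
  wr.wr.rdma.rkey = remoteRkey;
  wr.imm_data = 0x1234;

  // Stack-allocated `wr`/`sge` are safe: postSend copies them into the pending
  // VirtualSendWr before returning.
  auto maybePosted = virtualQp.postSend(&wr, &bad);
  return !maybePosted.hasError();
}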
#include "comms/ctran/ibverbx/Ibverbx.h" +#include "comms/ctran/ibverbx/IbverbxSymbols.h" #ifdef IBVERBX_BUILD_RDMA_CORE #include @@ -13,548 +14,23 @@ #include #include #include -#include #include "comms/utils/cvars/nccl_cvars.h" namespace ibverbx { -namespace { +extern IbvSymbols ibvSymbols; -folly::Singleton coordinatorSingleton{}; +namespace { -IbvSymbols ibvSymbols; folly::once_flag initIbvSymbolOnce; -#define IBVERBS_VERSION "IBVERBS_1.1" - -#define MLX5DV_VERSION "MLX5_1.8" - -#ifdef IBVERBX_BUILD_RDMA_CORE -// Wrapper functions to handle type conversions between custom and real types -struct ibv_device** linked_get_device_list(int* num_devices) { - return reinterpret_cast( - ibv_get_device_list(num_devices)); -} - -void linked_free_device_list(struct ibv_device** list) { - ibv_free_device_list(reinterpret_cast<::ibv_device**>(list)); -} - -const char* linked_get_device_name(struct ibv_device* device) { - return ibv_get_device_name(reinterpret_cast<::ibv_device*>(device)); -} - -struct ibv_context* linked_open_device(struct ibv_device* device) { - return reinterpret_cast( - ibv_open_device(reinterpret_cast<::ibv_device*>(device))); -} - -int linked_close_device(struct ibv_context* context) { - return ibv_close_device(reinterpret_cast<::ibv_context*>(context)); -} - -int linked_query_device( - struct ibv_context* context, - struct ibv_device_attr* device_attr) { - return ibv_query_device( - reinterpret_cast<::ibv_context*>(context), - reinterpret_cast<::ibv_device_attr*>(device_attr)); -} - -int linked_query_port( - struct ibv_context* context, - uint8_t port_num, - struct ibv_port_attr* port_attr) { - return ibv_query_port( - reinterpret_cast<::ibv_context*>(context), - port_num, - reinterpret_cast<::ibv_port_attr*>(port_attr)); -} - -int linked_query_gid( - struct ibv_context* context, - uint8_t port_num, - int index, - union ibv_gid* gid) { - return ibv_query_gid( - reinterpret_cast<::ibv_context*>(context), - port_num, - index, - reinterpret_cast<::ibv_gid*>(gid)); -} - -struct ibv_pd* linked_alloc_pd(struct ibv_context* context) { - return reinterpret_cast( - ibv_alloc_pd(reinterpret_cast<::ibv_context*>(context))); -} - -struct ibv_pd* linked_alloc_parent_domain( - struct ibv_context* context, - struct ibv_parent_domain_init_attr* attr) { - return reinterpret_cast(ibv_alloc_parent_domain( - reinterpret_cast<::ibv_context*>(context), - reinterpret_cast<::ibv_parent_domain_init_attr*>(attr))); -} - -int linked_dealloc_pd(struct ibv_pd* pd) { - return ibv_dealloc_pd(reinterpret_cast<::ibv_pd*>(pd)); -} - -struct ibv_mr* -linked_reg_mr(struct ibv_pd* pd, void* addr, size_t length, int access) { - return reinterpret_cast( - ibv_reg_mr(reinterpret_cast<::ibv_pd*>(pd), addr, length, access)); -} - -int linked_dereg_mr(struct ibv_mr* mr) { - return ibv_dereg_mr(reinterpret_cast<::ibv_mr*>(mr)); -} - -struct ibv_cq* linked_create_cq( - struct ibv_context* context, - int cqe, - void* cq_context, - struct ibv_comp_channel* channel, - int comp_vector) { - return reinterpret_cast(ibv_create_cq( - reinterpret_cast<::ibv_context*>(context), - cqe, - cq_context, - reinterpret_cast<::ibv_comp_channel*>(channel), - comp_vector)); -} - -struct ibv_cq_ex* linked_create_cq_ex( - struct ibv_context* context, - struct ibv_cq_init_attr_ex* attr) { - return reinterpret_cast(ibv_create_cq_ex( - reinterpret_cast<::ibv_context*>(context), - reinterpret_cast<::ibv_cq_init_attr_ex*>(attr))); -} - -int linked_destroy_cq(struct ibv_cq* cq) { - return ibv_destroy_cq(reinterpret_cast<::ibv_cq*>(cq)); -} - 
-struct ibv_qp* linked_create_qp( - struct ibv_pd* pd, - struct ibv_qp_init_attr* qp_init_attr) { - return reinterpret_cast(ibv_create_qp( - reinterpret_cast<::ibv_pd*>(pd), - reinterpret_cast<::ibv_qp_init_attr*>(qp_init_attr))); -} - -int linked_modify_qp( - struct ibv_qp* qp, - struct ibv_qp_attr* attr, - int attr_mask) { - return ibv_modify_qp( - reinterpret_cast<::ibv_qp*>(qp), - reinterpret_cast<::ibv_qp_attr*>(attr), - attr_mask); -} - -int linked_destroy_qp(struct ibv_qp* qp) { - return ibv_destroy_qp(reinterpret_cast<::ibv_qp*>(qp)); -} - -const char* linked_event_type_str(enum ibv_event_type event) { - return ibv_event_type_str(static_cast<::ibv_event_type>(event)); -} - -int linked_get_async_event( - struct ibv_context* context, - struct ibv_async_event* event) { - return ibv_get_async_event( - reinterpret_cast<::ibv_context*>(context), - reinterpret_cast<::ibv_async_event*>(event)); -} - -void linked_ack_async_event(struct ibv_async_event* event) { - ibv_ack_async_event(reinterpret_cast<::ibv_async_event*>(event)); -} - -int linked_query_qp( - struct ibv_qp* qp, - struct ibv_qp_attr* attr, - int attr_mask, - struct ibv_qp_init_attr* init_attr) { - return ibv_query_qp( - reinterpret_cast<::ibv_qp*>(qp), - reinterpret_cast<::ibv_qp_attr*>(attr), - attr_mask, - reinterpret_cast<::ibv_qp_init_attr*>(init_attr)); -} - -struct ibv_mr* linked_reg_mr_iova2( - struct ibv_pd* pd, - void* addr, - size_t length, - uint64_t iova, - unsigned int access) { - return reinterpret_cast(ibv_reg_mr_iova2( - reinterpret_cast<::ibv_pd*>(pd), addr, length, iova, access)); -} - -struct ibv_mr* linked_reg_dmabuf_mr( - struct ibv_pd* pd, - uint64_t offset, - size_t length, - uint64_t iova, - int fd, - int access) { - return reinterpret_cast(ibv_reg_dmabuf_mr( - reinterpret_cast<::ibv_pd*>(pd), offset, length, iova, fd, access)); -} - -int linked_query_ece(struct ibv_qp* qp, struct ibv_ece* ece) { - return ibv_query_ece( - reinterpret_cast<::ibv_qp*>(qp), reinterpret_cast<::ibv_ece*>(ece)); -} - -int linked_set_ece(struct ibv_qp* qp, struct ibv_ece* ece) { - return ibv_set_ece( - reinterpret_cast<::ibv_qp*>(qp), reinterpret_cast<::ibv_ece*>(ece)); -} - -enum ibv_fork_status linked_is_fork_initialized() { - return static_cast(ibv_is_fork_initialized()); -} - -struct ibv_comp_channel* linked_create_comp_channel( - struct ibv_context* context) { - return reinterpret_cast( - ibv_create_comp_channel(reinterpret_cast<::ibv_context*>(context))); -} - -int linked_destroy_comp_channel(struct ibv_comp_channel* channel) { - return ibv_destroy_comp_channel( - reinterpret_cast<::ibv_comp_channel*>(channel)); -} - -int linked_req_notify_cq(struct ibv_cq* cq, int solicited_only) { - return ibv_req_notify_cq(reinterpret_cast<::ibv_cq*>(cq), solicited_only); -} - -int linked_get_cq_event( - struct ibv_comp_channel* channel, - struct ibv_cq** cq, - void** cq_context) { - return ibv_get_cq_event( - reinterpret_cast<::ibv_comp_channel*>(channel), - reinterpret_cast<::ibv_cq**>(cq), - cq_context); -} - -void linked_ack_cq_events(struct ibv_cq* cq, unsigned int nevents) { - ibv_ack_cq_events(reinterpret_cast<::ibv_cq*>(cq), nevents); -} - -bool linked_mlx5dv_is_supported(struct ibv_device* device) { - return mlx5dv_is_supported(reinterpret_cast<::ibv_device*>(device)); -} - -int linked_mlx5dv_init_obj(mlx5dv_obj* obj, uint64_t obj_type) { - return mlx5dv_init_obj(reinterpret_cast<::mlx5dv_obj*>(obj), obj_type); -} - -int linked_mlx5dv_get_data_direct_sysfs_path( - struct ibv_context* context, - char* buf, - size_t buf_len) { - 
return mlx5dv_get_data_direct_sysfs_path( - reinterpret_cast<::ibv_context*>(context), buf, buf_len); -} - -struct ibv_mr* linked_mlx5dv_reg_dmabuf_mr( - struct ibv_pd* pd, - uint64_t offset, - size_t length, - uint64_t iova, - int fd, - int access, - int mlx5_access) { - return reinterpret_cast(mlx5dv_reg_dmabuf_mr( - reinterpret_cast<::ibv_pd*>(pd), - offset, - length, - iova, - fd, - access, - mlx5_access)); -} -#endif - -bool mlx5dvDmaBufDataDirectLinkCapable( - ibv_device* device, - ibv_context* context) { - if (ibvSymbols.mlx5dv_internal_is_supported == nullptr || - ibvSymbols.mlx5dv_internal_reg_dmabuf_mr == nullptr || - ibvSymbols.mlx5dv_internal_get_data_direct_sysfs_path == nullptr) { - return false; - } - - if (!ibvSymbols.mlx5dv_internal_is_supported(device)) { - return false; - } - int dev_fail = 0; - ibv_pd* pd = nullptr; - pd = ibvSymbols.ibv_internal_alloc_pd(context); - if (!pd) { - XLOG(ERR) << "ibv_alloc_pd failed: " << folly::errnoStr(errno); - return false; - } - - // Test kernel DMA-BUF support with a dummy call (fd=-1) - (void)ibvSymbols.ibv_internal_reg_dmabuf_mr( - pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/); - // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not - // supported (EBADF otherwise) - (void)ibvSymbols.mlx5dv_internal_reg_dmabuf_mr( - pd, - 0ULL /*offset*/, - 0ULL /*len*/, - 0ULL /*iova*/, - -1 /*fd*/, - 0 /*flags*/, - 0 /* mlx5 flags*/); - // mlx5dv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not - // supported (EBADF otherwise) - dev_fail |= (errno == EOPNOTSUPP) || (errno == EPROTONOSUPPORT); - if (ibvSymbols.ibv_internal_dealloc_pd(pd) != 0) { - XLOGF( - WARN, - "ibv_dealloc_pd failed: {} DMA-BUF support status: {}", - folly::errnoStr(errno), - dev_fail); - return false; - } - if (dev_fail) { - XLOGF(INFO, "Kernel DMA-BUF is not supported on device {}", device->name); - return false; - } - - char dataDirectDevicePath[PATH_MAX]; - snprintf(dataDirectDevicePath, PATH_MAX, "/sys"); - return ibvSymbols.mlx5dv_internal_get_data_direct_sysfs_path( - context, dataDirectDevicePath + 4, PATH_MAX - 4) == 0; -} - } // namespace -int buildIbvSymbols(IbvSymbols& symbols) { -#ifdef IBVERBX_BUILD_RDMA_CORE - // Direct linking mode - use wrapper functions to handle type conversions - symbols.ibv_internal_get_device_list = &linked_get_device_list; - symbols.ibv_internal_free_device_list = &linked_free_device_list; - symbols.ibv_internal_get_device_name = &linked_get_device_name; - symbols.ibv_internal_open_device = &linked_open_device; - symbols.ibv_internal_close_device = &linked_close_device; - symbols.ibv_internal_get_async_event = &linked_get_async_event; - symbols.ibv_internal_ack_async_event = &linked_ack_async_event; - symbols.ibv_internal_query_device = &linked_query_device; - symbols.ibv_internal_query_port = &linked_query_port; - symbols.ibv_internal_query_gid = &linked_query_gid; - symbols.ibv_internal_query_qp = &linked_query_qp; - symbols.ibv_internal_alloc_pd = &linked_alloc_pd; - symbols.ibv_internal_alloc_parent_domain = &linked_alloc_parent_domain; - symbols.ibv_internal_dealloc_pd = &linked_dealloc_pd; - symbols.ibv_internal_reg_mr = &linked_reg_mr; - - symbols.ibv_internal_reg_mr_iova2 = &linked_reg_mr_iova2; - symbols.ibv_internal_reg_dmabuf_mr = &linked_reg_dmabuf_mr; - symbols.ibv_internal_query_ece = &linked_query_ece; - symbols.ibv_internal_set_ece = &linked_set_ece; - symbols.ibv_internal_is_fork_initialized = &linked_is_fork_initialized; - - symbols.ibv_internal_dereg_mr = 
&linked_dereg_mr; - symbols.ibv_internal_create_cq = &linked_create_cq; - symbols.ibv_internal_create_cq_ex = &linked_create_cq_ex; - symbols.ibv_internal_destroy_cq = &linked_destroy_cq; - symbols.ibv_internal_create_comp_channel = &linked_create_comp_channel; - symbols.ibv_internal_destroy_comp_channel = &linked_destroy_comp_channel; - symbols.ibv_internal_req_notify_cq = &linked_req_notify_cq; - symbols.ibv_internal_get_cq_event = &linked_get_cq_event; - symbols.ibv_internal_ack_cq_events = &linked_ack_cq_events; - symbols.ibv_internal_create_qp = &linked_create_qp; - symbols.ibv_internal_modify_qp = &linked_modify_qp; - symbols.ibv_internal_destroy_qp = &linked_destroy_qp; - symbols.ibv_internal_fork_init = &ibv_fork_init; - symbols.ibv_internal_event_type_str = &linked_event_type_str; - - // mlx5dv symbols - symbols.mlx5dv_internal_is_supported = &linked_mlx5dv_is_supported; - symbols.mlx5dv_internal_init_obj = &linked_mlx5dv_init_obj; - symbols.mlx5dv_internal_get_data_direct_sysfs_path = - &linked_mlx5dv_get_data_direct_sysfs_path; - symbols.mlx5dv_internal_reg_dmabuf_mr = &linked_mlx5dv_reg_dmabuf_mr; - return 0; -#else - // Dynamic loading mode - use dlopen/dlsym - static void* ibvhandle = nullptr; - static void* mlx5dvhandle = nullptr; - void* tmp; - void** cast; - - // Use folly::ScopedGuard to ensure resources are cleaned up upon failure - auto guard = folly::makeGuard([&]() { - if (ibvhandle != nullptr) { - dlclose(ibvhandle); - } - if (mlx5dvhandle != nullptr) { - dlclose(mlx5dvhandle); - } - symbols = {}; // Reset all function pointers to nullptr - }); - - if (!NCCL_IBVERBS_PATH.empty()) { - ibvhandle = dlopen(NCCL_IBVERBS_PATH.c_str(), RTLD_NOW); - } - if (!ibvhandle) { - ibvhandle = dlopen("libibverbs.so.1", RTLD_NOW); - if (!ibvhandle) { - XLOG(ERR) << "Failed to open libibverbs.so.1"; - return 1; - } - } - - // Load mlx5dv symbols if available, do not abort if failed - mlx5dvhandle = dlopen("libmlx5.so", RTLD_NOW); - if (!mlx5dvhandle) { - mlx5dvhandle = dlopen("libmlx5.so.1", RTLD_NOW); - if (!mlx5dvhandle) { - XLOG(WARN) - << "Failed to open libmlx5.so[.1]. 
Advance features like CX-8 Direct-NIC will be disabled."; - } - } - -#define LOAD_SYM(handle, symbol, funcptr, version) \ - { \ - cast = (void**)&funcptr; \ - tmp = dlvsym(handle, symbol, version); \ - if (tmp == nullptr) { \ - XLOG(ERR) << fmt::format( \ - "dlvsym failed on {} - {} version {}", symbol, dlerror(), version); \ - return 1; \ - } \ - *cast = tmp; \ - } - -#define LOAD_SYM_WARN_ONLY(handle, symbol, funcptr, version) \ - { \ - cast = (void**)&funcptr; \ - tmp = dlvsym(handle, symbol, version); \ - if (tmp == nullptr) { \ - XLOG(WARN) << fmt::format( \ - "dlvsym failed on {} - {} version {}, set null", \ - symbol, \ - dlerror(), \ - version); \ - } \ - *cast = tmp; \ - } - -#define LOAD_IBVERBS_SYM(symbol, funcptr) \ - LOAD_SYM(ibvhandle, symbol, funcptr, IBVERBS_VERSION) - -#define LOAD_IBVERBS_SYM_VERSION(symbol, funcptr, version) \ - LOAD_SYM_WARN_ONLY(ibvhandle, symbol, funcptr, version) - -#define LOAD_IBVERBS_SYM_WARN_ONLY(symbol, funcptr) \ - LOAD_SYM_WARN_ONLY(ibvhandle, symbol, funcptr, IBVERBS_VERSION) - -// mlx5 -#define LOAD_MLX5DV_SYM(symbol, funcptr) \ - if (mlx5dvhandle != nullptr) { \ - LOAD_SYM_WARN_ONLY(mlx5dvhandle, symbol, funcptr, MLX5DV_VERSION) \ - } - -#define LOAD_MLX5DV_SYM_VERSION(symbol, funcptr, version) \ - if (mlx5dvhandle != nullptr) { \ - LOAD_SYM_WARN_ONLY(mlx5dvhandle, symbol, funcptr, version) \ - } - - LOAD_IBVERBS_SYM("ibv_get_device_list", symbols.ibv_internal_get_device_list); - LOAD_IBVERBS_SYM( - "ibv_free_device_list", symbols.ibv_internal_free_device_list); - LOAD_IBVERBS_SYM("ibv_get_device_name", symbols.ibv_internal_get_device_name); - LOAD_IBVERBS_SYM("ibv_open_device", symbols.ibv_internal_open_device); - LOAD_IBVERBS_SYM("ibv_close_device", symbols.ibv_internal_close_device); - LOAD_IBVERBS_SYM("ibv_get_async_event", symbols.ibv_internal_get_async_event); - LOAD_IBVERBS_SYM("ibv_ack_async_event", symbols.ibv_internal_ack_async_event); - LOAD_IBVERBS_SYM("ibv_query_device", symbols.ibv_internal_query_device); - LOAD_IBVERBS_SYM("ibv_query_port", symbols.ibv_internal_query_port); - LOAD_IBVERBS_SYM("ibv_query_gid", symbols.ibv_internal_query_gid); - LOAD_IBVERBS_SYM("ibv_query_qp", symbols.ibv_internal_query_qp); - LOAD_IBVERBS_SYM("ibv_alloc_pd", symbols.ibv_internal_alloc_pd); - LOAD_IBVERBS_SYM_WARN_ONLY( - "ibv_alloc_parent_domain", symbols.ibv_internal_alloc_parent_domain); - LOAD_IBVERBS_SYM("ibv_dealloc_pd", symbols.ibv_internal_dealloc_pd); - LOAD_IBVERBS_SYM("ibv_reg_mr", symbols.ibv_internal_reg_mr); - // Cherry-pick the ibv_reg_mr_iova2 API from IBVERBS 1.8 - LOAD_IBVERBS_SYM_VERSION( - "ibv_reg_mr_iova2", symbols.ibv_internal_reg_mr_iova2, "IBVERBS_1.8"); - // Cherry-pick the ibv_reg_dmabuf_mr API from IBVERBS 1.12 - LOAD_IBVERBS_SYM_VERSION( - "ibv_reg_dmabuf_mr", symbols.ibv_internal_reg_dmabuf_mr, "IBVERBS_1.12"); - LOAD_IBVERBS_SYM("ibv_dereg_mr", symbols.ibv_internal_dereg_mr); - LOAD_IBVERBS_SYM("ibv_create_cq", symbols.ibv_internal_create_cq); - LOAD_IBVERBS_SYM("ibv_destroy_cq", symbols.ibv_internal_destroy_cq); - LOAD_IBVERBS_SYM("ibv_create_qp", symbols.ibv_internal_create_qp); - LOAD_IBVERBS_SYM("ibv_modify_qp", symbols.ibv_internal_modify_qp); - LOAD_IBVERBS_SYM("ibv_destroy_qp", symbols.ibv_internal_destroy_qp); - LOAD_IBVERBS_SYM("ibv_fork_init", symbols.ibv_internal_fork_init); - LOAD_IBVERBS_SYM("ibv_event_type_str", symbols.ibv_internal_event_type_str); - - LOAD_IBVERBS_SYM_VERSION( - "ibv_create_comp_channel", - symbols.ibv_internal_create_comp_channel, - "IBVERBS_1.0"); - LOAD_IBVERBS_SYM_VERSION( - 
"ibv_destroy_comp_channel", - symbols.ibv_internal_destroy_comp_channel, - "IBVERBS_1.0"); - LOAD_IBVERBS_SYM_VERSION( - "ibv_get_cq_event", symbols.ibv_internal_get_cq_event, "IBVERBS_1.0"); - LOAD_IBVERBS_SYM_VERSION( - "ibv_ack_cq_events", symbols.ibv_internal_ack_cq_events, "IBVERBS_1.0"); - // TODO: ibv_req_notify_cq is found not in any version of IBVERBS - LOAD_IBVERBS_SYM_VERSION( - "ibv_req_notify_cq", symbols.ibv_internal_req_notify_cq, "IBVERBS_1.0"); - LOAD_IBVERBS_SYM_VERSION( - "ibv_query_ece", symbols.ibv_internal_query_ece, "IBVERBS_1.10"); - LOAD_IBVERBS_SYM_VERSION( - "ibv_set_ece", symbols.ibv_internal_set_ece, "IBVERBS_1.10"); - LOAD_IBVERBS_SYM_VERSION( - "ibv_is_fork_initialized", - symbols.ibv_internal_is_fork_initialized, - "IBVERBS_1.13"); - - LOAD_MLX5DV_SYM("mlx5dv_is_supported", symbols.mlx5dv_internal_is_supported); - // Cherry-pick the mlx5dv_get_data_direct_sysfs_path API from MLX5 1.2 - LOAD_MLX5DV_SYM_VERSION( - "mlx5dv_init_obj", symbols.mlx5dv_internal_init_obj, "MLX5_1.2"); - // Cherry-pick the mlx5dv_get_data_direct_sysfs_path API from MLX5 1.25 - LOAD_MLX5DV_SYM_VERSION( - "mlx5dv_get_data_direct_sysfs_path", - symbols.mlx5dv_internal_get_data_direct_sysfs_path, - "MLX5_1.25"); - // Cherry-pick the ibv_reg_dmabuf_mr API from MLX5 1.25 - LOAD_MLX5DV_SYM_VERSION( - "mlx5dv_reg_dmabuf_mr", - symbols.mlx5dv_internal_reg_dmabuf_mr, - "MLX5_1.25"); - - // all symbols were loaded successfully, dismiss guard - guard.dismiss(); - return 0; -#endif -} - folly::Expected ibvInit() { static std::atomic errNum{1}; - folly::call_once( - initIbvSymbolOnce, [&]() { errNum = buildIbvSymbols(ibvSymbols); }); + folly::call_once(initIbvSymbolOnce, [&]() { + errNum = buildIbvSymbols(ibvSymbols, NCCL_IBVERBS_PATH); + }); if (errNum != 0) { return folly::makeUnexpected(Error(errNum)); } @@ -574,1128 +50,6 @@ void ibvAckCqEvents(ibv_cq* cq, unsigned int nevents) { ibvSymbols.ibv_internal_ack_cq_events(cq, nevents); } -/*** Error ***/ - -Error::Error() : errNum(errno), errStr(folly::errnoStr(errno)) {} -Error::Error(int errNum) : errNum(errNum), errStr(folly::errnoStr(errNum)) {} -Error::Error(int errNum, std::string errStr) - : errNum(errNum), errStr(std::move(errStr)) {} - -std::ostream& operator<<(std::ostream& out, Error const& err) { - out << err.errStr << " (errno=" << err.errNum << ")"; - return out; -} - -/*** IbvMr ***/ - -IbvMr::IbvMr(ibv_mr* mr) : mr_(mr) {} - -IbvMr::IbvMr(IbvMr&& other) noexcept { - mr_ = other.mr_; - other.mr_ = nullptr; -} - -IbvMr& IbvMr::operator=(IbvMr&& other) noexcept { - mr_ = other.mr_; - other.mr_ = nullptr; - return *this; -} - -IbvMr::~IbvMr() { - if (mr_) { - int rc = ibvSymbols.ibv_internal_dereg_mr(mr_); - if (rc != 0) { - XLOGF(ERR, "Failed to deregister mr rc: {}, {}", rc, strerror(errno)); - } - } -} - -ibv_mr* IbvMr::mr() const { - return mr_; -} - -/*** IbvPd ***/ - -IbvPd::IbvPd(ibv_pd* pd, bool dataDirect) : pd_(pd), dataDirect_(dataDirect) {} - -IbvPd::IbvPd(IbvPd&& other) noexcept { - pd_ = other.pd_; - dataDirect_ = other.dataDirect_; - other.pd_ = nullptr; -} - -IbvPd& IbvPd::operator=(IbvPd&& other) noexcept { - pd_ = other.pd_; - dataDirect_ = other.dataDirect_; - other.pd_ = nullptr; - return *this; -} - -IbvPd::~IbvPd() { - if (pd_) { - int rc = ibvSymbols.ibv_internal_dealloc_pd(pd_); - if (rc != 0) { - XLOGF(ERR, "Failed to deallocate pd rc: {}, {}", rc, strerror(errno)); - } - } -} - -ibv_pd* IbvPd::pd() const { - return pd_; -} - -bool IbvPd::useDataDirect() const { - return dataDirect_; -} - -folly::Expected 
-IbvPd::regMr(void* addr, size_t length, ibv_access_flags access) const { - ibv_mr* mr; - mr = ibvSymbols.ibv_internal_reg_mr(pd_, addr, length, access); - if (!mr) { - return folly::makeUnexpected(Error(errno)); - } - return IbvMr(mr); -} - -folly::Expected IbvPd::regDmabufMr( - uint64_t offset, - size_t length, - uint64_t iova, - int fd, - ibv_access_flags access) const { - ibv_mr* mr; - if (dataDirect_) { - mr = ibvSymbols.mlx5dv_internal_reg_dmabuf_mr( - pd_, - offset, - length, - iova, - fd, - access, - MLX5DV_REG_DMABUF_ACCESS_DATA_DIRECT); - } else { - mr = ibvSymbols.ibv_internal_reg_dmabuf_mr( - pd_, offset, length, iova, fd, access); - } - if (!mr) { - return folly::makeUnexpected(Error(errno)); - } - return IbvMr(mr); -} - -folly::Expected IbvPd::createQp( - ibv_qp_init_attr* initAttr) const { - ibv_qp* qp; - qp = ibvSymbols.ibv_internal_create_qp(pd_, initAttr); - if (!qp) { - return folly::makeUnexpected(Error(errno)); - } - return IbvQp(qp); -} - -folly::Expected IbvPd::createVirtualQp( - int totalQps, - ibv_qp_init_attr* initAttr, - IbvVirtualCq* sendCq, - IbvVirtualCq* recvCq, - int maxMsgCntPerQp, - int maxMsgSize, - LoadBalancingScheme loadBalancingScheme) const { - std::vector qps; - qps.reserve(totalQps); - - if (sendCq == nullptr) { - return folly::makeUnexpected( - Error(EINVAL, "Empty sendCq being provided to createVirtualQp")); - } - - if (recvCq == nullptr) { - return folly::makeUnexpected( - Error(EINVAL, "Empty recvCq being provided to createVirtualQp")); - } - - // Overwrite the CQs in the initAttr to point to the virtual CQ - initAttr->send_cq = sendCq->getPhysicalCqRef().cq(); - initAttr->recv_cq = recvCq->getPhysicalCqRef().cq(); - - // First create all the data QPs - for (int i = 0; i < totalQps; i++) { - auto maybeQp = createQp(initAttr); - if (maybeQp.hasError()) { - return folly::makeUnexpected(maybeQp.error()); - } - qps.emplace_back(std::move(*maybeQp)); - } - - // Create notify QP - auto maybeNotifyQp = createQp(initAttr); - if (maybeNotifyQp.hasError()) { - return folly::makeUnexpected(maybeNotifyQp.error()); - } - - // Create the IbvVirtualQp instance, with coordinator registartion happens - // within IbvVirtualQp constructor - return IbvVirtualQp( - std::move(qps), - std::move(*maybeNotifyQp), - sendCq, - recvCq, - maxMsgCntPerQp, - maxMsgSize, - loadBalancingScheme); -} - -/*** IbvCq ***/ - -IbvCq::IbvCq(ibv_cq* cq) : cq_(cq) {} - -IbvCq::~IbvCq() { - if (cq_) { - int rc = ibvSymbols.ibv_internal_destroy_cq(cq_); - if (rc != 0) { - XLOGF(ERR, "Failed to destroy cq rc: {}, {}", rc, strerror(errno)); - } - } -} - -IbvCq::IbvCq(IbvCq&& other) noexcept { - cq_ = other.cq_; - other.cq_ = nullptr; -} - -IbvCq& IbvCq::operator=(IbvCq&& other) noexcept { - cq_ = other.cq_; - other.cq_ = nullptr; - return *this; -} - -ibv_cq* IbvCq::cq() const { - return cq_; -} - -folly::Expected IbvCq::reqNotifyCq( - int solicited_only) const { - int rc = ibvSymbols.ibv_internal_req_notify_cq(cq_, solicited_only); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return folly::unit; -} - -/*** IbvDevice ***/ - -// hcaList format examples: -// - Without port: "mlx5_0,mlx5_1,mlx5_2" -// - With port: "mlx5_0:1,mlx5_1:0,mlx5_2:1" -// - Prefix match: "mlx5" -// hcaPrefix: use "=" for exact match, "^" for exclude match, "" for prefix -// match. 
See guidelines: -// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-ib-hca -folly::Expected, Error> IbvDevice::ibvGetDeviceList( - const std::vector& hcaList, - const std::string& hcaPrefix, - int defaultPort) { - // Get device list - ibv_device** devs{nullptr}; - int numDevs; - devs = ibvSymbols.ibv_internal_get_device_list(&numDevs); - if (!devs) { - return folly::makeUnexpected(Error(errno)); - } - auto devices = - ibvFilterDeviceList(numDevs, devs, hcaList, hcaPrefix, defaultPort); - // Free device list - ibvSymbols.ibv_internal_free_device_list(devs); - return devices; -} - -std::vector IbvDevice::ibvFilterDeviceList( - int numDevs, - ibv_device** devs, - const std::vector& hcaList, - const std::string& hcaPrefix, - int defaultPort) { - std::vector devices; - - if (hcaList.empty()) { - devices.reserve(numDevs); - for (int i = 0; i < numDevs; i++) { - devices.emplace_back(devs[i], defaultPort); - } - return devices; - } - - // Convert the provided list of HCA strings into a vector of RoceHca - // objects, which enables efficient device filter operation - std::vector hcas; - // Avoid copy triggered by resize - hcas.reserve(hcaList.size()); - for (const auto& hca : hcaList) { - // Copy value to each vector element so it can be freed automatically - hcas.emplace_back(hca, defaultPort); - } - - // Filter devices - if (hcaPrefix == "=") { - for (const auto& hca : hcas) { - for (int i = 0; i < numDevs; i++) { - if (hca.name == devs[i]->name) { - devices.emplace_back(devs[i], hca.port); - break; - } - } - } - return devices; - } else if (hcaPrefix == "^") { - for (const auto& hca : hcas) { - for (int i = 0; i < numDevs; i++) { - if (hca.name != devs[i]->name) { - devices.emplace_back(devs[i], defaultPort); - break; - } - } - } - return devices; - } else { - // Prefix match - for (const auto& hca : hcas) { - for (int i = 0; i < numDevs; i++) { - if (strncmp(devs[i]->name, hca.name.c_str(), hca.name.length()) == 0) { - devices.emplace_back(devs[i], hca.port); - break; - } - } - } - return devices; - } -} - -IbvDevice::IbvDevice(ibv_device* ibvDevice, int port) : device_(ibvDevice) { - port_ = port; - context_ = ibvSymbols.ibv_internal_open_device(device_); - if (!context_) { - XLOGF(ERR, "Failed to open device {}", device_->name); - throw std::runtime_error( - fmt::format("Failed to open device {}", device_->name)); - } - if ((mlx5dvDmaBufDataDirectLinkCapable(device_, context_))) { - // Now check whether Data Direct has been disabled by the user - dataDirect_ = NCCL_IB_DATA_DIRECT == 1; - XLOGF( - INFO, - "NET/IB: Data Direct DMA Interface is detected for device: {} dataDirect: {}", - device_->name, - dataDirect_); - } -} - -IbvDevice::~IbvDevice() { - if (context_) { - int rc = ibvSymbols.ibv_internal_close_device(context_); - if (rc != 0) { - XLOGF(ERR, "Failed to close device rc: {}, {}", rc, strerror(errno)); - } - } -} - -IbvDevice::IbvDevice(IbvDevice&& other) noexcept { - device_ = other.device_; - context_ = other.context_; - port_ = other.port_; - dataDirect_ = other.dataDirect_; - - other.device_ = nullptr; - other.context_ = nullptr; -} - -IbvDevice& IbvDevice::operator=(IbvDevice&& other) noexcept { - device_ = other.device_; - context_ = other.context_; - port_ = other.port_; - dataDirect_ = other.dataDirect_; - - other.device_ = nullptr; - other.context_ = nullptr; - return *this; -} - -ibv_device* IbvDevice::device() const { - return device_; -} - -ibv_context* IbvDevice::context() const { - return context_; -} - -int IbvDevice::port() const { - return 
port_; -} - -folly::Expected IbvDevice::allocPd() { - ibv_pd* pd; - pd = ibvSymbols.ibv_internal_alloc_pd(context_); - if (!pd) { - return folly::makeUnexpected(Error(errno)); - } - return IbvPd(pd, dataDirect_); -} - -folly::Expected IbvDevice::allocParentDomain( - ibv_parent_domain_init_attr* attr) { - ibv_pd* pd; - - if (ibvSymbols.ibv_internal_alloc_parent_domain == nullptr) { - return folly::makeUnexpected(Error(ENOSYS)); - } - - pd = ibvSymbols.ibv_internal_alloc_parent_domain(context_, attr); - - if (!pd) { - return folly::makeUnexpected(Error(errno)); - } - return IbvPd(pd, dataDirect_); -} - -folly::Expected IbvDevice::queryDevice() const { - ibv_device_attr deviceAttr{}; - int rc = ibvSymbols.ibv_internal_query_device(context_, &deviceAttr); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return deviceAttr; -} - -folly::Expected IbvDevice::queryPort( - uint8_t portNum) const { - ibv_port_attr portAttr{}; - int rc = ibvSymbols.ibv_internal_query_port(context_, portNum, &portAttr); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return portAttr; -} - -folly::Expected IbvDevice::queryGid( - uint8_t portNum, - int gidIndex) const { - ibv_gid gid{}; - int rc = ibvSymbols.ibv_internal_query_gid(context_, portNum, gidIndex, &gid); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return gid; -} - -folly::Expected IbvDevice::createCq( - int cqe, - void* cq_context, - ibv_comp_channel* channel, - int comp_vector) const { - ibv_cq* cq; - cq = ibvSymbols.ibv_internal_create_cq( - context_, cqe, cq_context, channel, comp_vector); - if (!cq) { - return folly::makeUnexpected(Error(errno)); - } - return IbvCq(cq); -} - -folly::Expected IbvDevice::createVirtualCq( - int cqe, - void* cq_context, - ibv_comp_channel* channel, - int comp_vector) { - auto maybeCq = createCq(cqe, cq_context, channel, comp_vector); - if (maybeCq.hasError()) { - return folly::makeUnexpected(maybeCq.error()); - } - return IbvVirtualCq(std::move(*maybeCq), cqe); -} - -folly::Expected IbvDevice::createCq( - ibv_cq_init_attr_ex* attr) const { - ibv_cq_ex* cqEx; - cqEx = ibvSymbols.ibv_internal_create_cq_ex(context_, attr); - if (!cqEx) { - return folly::makeUnexpected(Error(errno)); - } - ibv_cq* cq = ibv_cq_ex_to_cq(cqEx); - return IbvCq(cq); -} - -folly::Expected IbvDevice::createCompChannel() const { - ibv_comp_channel* channel; - channel = ibvSymbols.ibv_internal_create_comp_channel(context_); - if (!channel) { - return folly::makeUnexpected(Error(errno)); - } - return channel; -} - -folly::Expected IbvDevice::destroyCompChannel( - ibv_comp_channel* channel) const { - int rc = ibvSymbols.ibv_internal_destroy_comp_channel(channel); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return folly::unit; -} - -folly::Expected IbvDevice::isPortActive( - uint8_t portNum, - std::unordered_set linkLayers) const { - auto maybePortAttr = queryPort(portNum); - if (maybePortAttr.hasError()) { - return folly::makeUnexpected(maybePortAttr.error()); - } - - auto portAttr = maybePortAttr.value(); - - // Check if port is active - if (portAttr.state != IBV_PORT_ACTIVE) { - return false; - } - - // Check if link layer matches (if specified) - if (!linkLayers.empty() && - linkLayers.find(portAttr.link_layer) == linkLayers.end()) { - return false; - } - - return true; -} - -folly::Expected IbvDevice::findActivePort( - std::unordered_set const& linkLayers) const { - // If specific port requested, check if it is active - if (port_ != kIbAnyPort) { - auto maybeActive = 
isPortActive(port_, linkLayers); - if (maybeActive.hasError()) { - return folly::makeUnexpected(maybeActive.error()); - } - - if (!maybeActive.value()) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "Port {} is not active on device {}", port_, device_->name))); - } - return port_; - } - - // No specific port requested, find any active port - auto maybeDeviceAttr = queryDevice(); - if (maybeDeviceAttr.hasError()) { - return folly::makeUnexpected(maybeDeviceAttr.error()); - } - - for (uint8_t port = 1; port <= maybeDeviceAttr->phys_port_cnt; port++) { - auto maybeActive = isPortActive(port, linkLayers); - if (maybeActive.hasError()) { - continue; // Skip ports we can't query - } - - if (maybeActive.value()) { - return port; - } - } - - return folly::makeUnexpected(Error( - ENODEV, fmt::format("No active port found on device {}", device_->name))); -} - -/*** IbvQp ***/ -IbvQp::IbvQp(ibv_qp* qp) : qp_(qp) {} - -IbvQp::~IbvQp() { - if (qp_) { - int rc = ibvSymbols.ibv_internal_destroy_qp(qp_); - if (rc != 0) { - XLOGF(ERR, "Failed to destroy qp rc: {}, {}", rc, strerror(errno)); - } - } -} - -IbvQp::IbvQp(IbvQp&& other) noexcept { - qp_ = other.qp_; - physicalSendWrStatus_ = std::move(other.physicalSendWrStatus_); - physicalRecvWrStatus_ = std::move(other.physicalRecvWrStatus_); - other.qp_ = nullptr; -} - -IbvQp& IbvQp::operator=(IbvQp&& other) noexcept { - qp_ = other.qp_; - physicalSendWrStatus_ = std::move(other.physicalSendWrStatus_); - physicalRecvWrStatus_ = std::move(other.physicalRecvWrStatus_); - other.qp_ = nullptr; - return *this; -} - -ibv_qp* IbvQp::qp() const { - return qp_; -} - -folly::Expected IbvQp::modifyQp( - ibv_qp_attr* attr, - int attrMask) { - int rc = ibvSymbols.ibv_internal_modify_qp(qp_, attr, attrMask); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return folly::unit; -} - -folly::Expected, Error> IbvQp::queryQp( - int attrMask) const { - ibv_qp_attr qpAttr{}; - ibv_qp_init_attr initAttr{}; - int rc = ibvSymbols.ibv_internal_query_qp(qp_, &qpAttr, attrMask, &initAttr); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return std::make_pair(qpAttr, initAttr); -} - -void IbvQp::enquePhysicalSendWrStatus(int physicalWrId, int virtualWrId) { - physicalSendWrStatus_.emplace_back(physicalWrId, virtualWrId); -} - -void IbvQp::dequePhysicalSendWrStatus() { - physicalSendWrStatus_.pop_front(); -} - -void IbvQp::dequePhysicalRecvWrStatus() { - physicalRecvWrStatus_.pop_front(); -} - -void IbvQp::enquePhysicalRecvWrStatus(int physicalWrId, int virtualWrId) { - physicalRecvWrStatus_.emplace_back(physicalWrId, virtualWrId); -} - -bool IbvQp::isSendQueueAvailable(int maxMsgCntPerQp) const { - if (maxMsgCntPerQp < 0) { - return true; - } - return physicalSendWrStatus_.size() < maxMsgCntPerQp; -} - -bool IbvQp::isRecvQueueAvailable(int maxMsgCntPerQp) const { - if (maxMsgCntPerQp < 0) { - return true; - } - return physicalRecvWrStatus_.size() < maxMsgCntPerQp; -} - -/*** IbvVirtualQp ***/ - -IbvVirtualQp::IbvVirtualQp( - std::vector&& qps, - IbvQp&& notifyQp, - IbvVirtualCq* sendCq, - IbvVirtualCq* recvCq, - int maxMsgCntPerQp, - int maxMsgSize, - LoadBalancingScheme loadBalancingScheme) - : physicalQps_(std::move(qps)), - maxMsgCntPerQp_(maxMsgCntPerQp), - maxMsgSize_(maxMsgSize), - loadBalancingScheme_(loadBalancingScheme), - notifyQp_(std::move(notifyQp)) { - virtualQpNum_ = - nextVirtualQpNum_.fetch_add(1); // Assign unique virtual QP number - - for (int i = 0; i < physicalQps_.size(); i++) { - 
qpNumToIdx_[physicalQps_.at(i).qp()->qp_num] = i; - } - - // Register the virtual QP and all its mappings with the coordinator - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualQp construction!"; - - // Use the consolidated registration API - coordinator->registerVirtualQpWithVirtualCqMappings( - this, sendCq->getVirtualCqNum(), recvCq->getVirtualCqNum()); -} - -size_t IbvVirtualQp::getTotalQps() const { - return physicalQps_.size(); -} - -const std::vector& IbvVirtualQp::getQpsRef() const { - return physicalQps_; -} - -std::vector& IbvVirtualQp::getQpsRef() { - return physicalQps_; -} - -const IbvQp& IbvVirtualQp::getNotifyQpRef() const { - return notifyQp_; -} - -uint32_t IbvVirtualQp::getVirtualQpNum() const { - return virtualQpNum_; -} - -IbvVirtualQp::IbvVirtualQp(IbvVirtualQp&& other) noexcept - : pendingSendVirtualWrQue_(std::move(other.pendingSendVirtualWrQue_)), - pendingRecvVirtualWrQue_(std::move(other.pendingRecvVirtualWrQue_)), - virtualQpNum_(std::move(other.virtualQpNum_)), - physicalQps_(std::move(other.physicalQps_)), - qpNumToIdx_(std::move(other.qpNumToIdx_)), - nextSendPhysicalQpIdx_(std::move(other.nextSendPhysicalQpIdx_)), - nextRecvPhysicalQpIdx_(std::move(other.nextRecvPhysicalQpIdx_)), - maxMsgCntPerQp_(std::move(other.maxMsgCntPerQp_)), - maxMsgSize_(std::move(other.maxMsgSize_)), - nextPhysicalWrId_(std::move(other.nextPhysicalWrId_)), - loadBalancingScheme_(std::move(other.loadBalancingScheme_)), - pendingSendNotifyVirtualWrQue_( - std::move(other.pendingSendNotifyVirtualWrQue_)), - notifyQp_(std::move(other.notifyQp_)), - dqplbSeqTracker(std::move(other.dqplbSeqTracker)), - dqplbReceiverInitialized_(std::move(other.dqplbReceiverInitialized_)) { - // Update coordinator pointer mapping for this virtual QP after move - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualQp move construction!"; - coordinator->updateVirtualQpPointer(virtualQpNum_, this); -} - -IbvVirtualQp& IbvVirtualQp::operator=(IbvVirtualQp&& other) noexcept { - if (this != &other) { - physicalQps_ = std::move(other.physicalQps_); - notifyQp_ = std::move(other.notifyQp_); - nextSendPhysicalQpIdx_ = std::move(other.nextSendPhysicalQpIdx_); - nextRecvPhysicalQpIdx_ = std::move(other.nextRecvPhysicalQpIdx_); - qpNumToIdx_ = std::move(other.qpNumToIdx_); - maxMsgCntPerQp_ = std::move(other.maxMsgCntPerQp_); - maxMsgSize_ = std::move(other.maxMsgSize_); - loadBalancingScheme_ = std::move(other.loadBalancingScheme_); - pendingSendVirtualWrQue_ = std::move(other.pendingSendVirtualWrQue_); - pendingRecvVirtualWrQue_ = std::move(other.pendingRecvVirtualWrQue_); - virtualQpNum_ = std::move(other.virtualQpNum_); - nextPhysicalWrId_ = std::move(other.nextPhysicalWrId_); - pendingSendNotifyVirtualWrQue_ = - std::move(other.pendingSendNotifyVirtualWrQue_); - dqplbSeqTracker = std::move(other.dqplbSeqTracker); - dqplbReceiverInitialized_ = std::move(other.dqplbReceiverInitialized_); - - // Update coordinator pointer mapping for this virtual QP after move - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualQp move construction!"; - coordinator->updateVirtualQpPointer(virtualQpNum_, this); - } - return *this; -} - -IbvVirtualQp::~IbvVirtualQp() { - // Always call unregister - the coordinator will check if the pointer matches - // and do nothing if the object was moved - auto 
coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualQp destruction!"; - coordinator->unregisterVirtualQp(virtualQpNum_, this); -} - -folly::Expected IbvVirtualQp::modifyVirtualQp( - ibv_qp_attr* attr, - int attrMask, - const IbvVirtualQpBusinessCard& businessCard) { - // If businessCard is not empty, use it to modify QPs with specific - // dest_qp_num values - if (!businessCard.qpNums_.empty()) { - // Make sure the businessCard has the same number of QPs as physicalQps_ - if (businessCard.qpNums_.size() != physicalQps_.size()) { - return folly::makeUnexpected(Error( - EINVAL, "BusinessCard QP count doesn't match physical QP count")); - } - - // Modify each QP with its corresponding dest_qp_num from the businessCard - for (auto i = 0; i < physicalQps_.size(); i++) { - attr->dest_qp_num = businessCard.qpNums_.at(i); - auto maybeModifyQp = physicalQps_.at(i).modifyQp(attr, attrMask); - if (maybeModifyQp.hasError()) { - return folly::makeUnexpected(maybeModifyQp.error()); - } - } - attr->dest_qp_num = businessCard.notifyQpNum_; - auto maybeModifyQp = notifyQp_.modifyQp(attr, attrMask); - if (maybeModifyQp.hasError()) { - return folly::makeUnexpected(maybeModifyQp.error()); - } - } else { - // If no businessCard provided, modify all QPs with the same attributes - for (auto& qp : physicalQps_) { - auto maybeModifyQp = qp.modifyQp(attr, attrMask); - if (maybeModifyQp.hasError()) { - return folly::makeUnexpected(maybeModifyQp.error()); - } - } - auto maybeModifyQp = notifyQp_.modifyQp(attr, attrMask); - if (maybeModifyQp.hasError()) { - return folly::makeUnexpected(maybeModifyQp.error()); - } - } - return folly::unit; -} - -IbvVirtualQpBusinessCard IbvVirtualQp::getVirtualQpBusinessCard() const { - std::vector qpNums; - qpNums.reserve(physicalQps_.size()); - for (auto& qp : physicalQps_) { - qpNums.push_back(qp.qp()->qp_num); - } - return IbvVirtualQpBusinessCard(std::move(qpNums), notifyQp_.qp()->qp_num); -} - -LoadBalancingScheme IbvVirtualQp::getLoadBalancingScheme() const { - return loadBalancingScheme_; -} - -/*** IbvVirtualQpBusinessCard ***/ - -IbvVirtualQpBusinessCard::IbvVirtualQpBusinessCard( - std::vector qpNums, - uint32_t notifyQpNum) - : qpNums_(std::move(qpNums)), notifyQpNum_(notifyQpNum) {} - -folly::dynamic IbvVirtualQpBusinessCard::toDynamic() const { - folly::dynamic obj = folly::dynamic::object; - folly::dynamic qpNumsArray = folly::dynamic::array; - - // Use fixed-width string formatting to ensure consistent size - // All uint32_t values will be formatted as 10-digit zero-padded strings - for (const auto& qpNum : qpNums_) { - std::string paddedQpNum = fmt::format("{:010d}", qpNum); - qpNumsArray.push_back(paddedQpNum); - } - - obj["qpNums"] = std::move(qpNumsArray); - obj["notifyQpNum"] = fmt::format("{:010d}", notifyQpNum_); - return obj; -} - -folly::Expected -IbvVirtualQpBusinessCard::fromDynamic(const folly::dynamic& obj) { - std::vector qpNums; - - if (obj.count("qpNums") > 0 && obj["qpNums"].isArray()) { - const auto& qpNumsArray = obj["qpNums"]; - qpNums.reserve(qpNumsArray.size()); - - for (const auto& qpNum : qpNumsArray) { - CHECK(qpNum.isString()) << "qp num is not string!"; - try { - uint32_t qpNumValue = - static_cast(std::stoul(qpNum.asString())); - qpNums.push_back(qpNumValue); - } catch (const std::exception& e) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "Invalid QP number string format: {}. 
Exception: {}", - qpNum.asString(), - e.what()))); - } - } - } else { - return folly::makeUnexpected( - Error(EINVAL, "Invalid qpNums array received from remote side")); - } - - uint32_t notifyQpNum = 0; // Default value for backwards compatibility - if (obj.count("notifyQpNum") > 0 && obj["notifyQpNum"].isString()) { - try { - notifyQpNum = - static_cast(std::stoul(obj["notifyQpNum"].asString())); - } catch (const std::exception& e) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "Invalid notifyQpNum string format: {}. Exception: {}", - obj["notifyQpNum"].asString(), - e.what()))); - } - } - - return IbvVirtualQpBusinessCard(std::move(qpNums), notifyQpNum); -} - -std::string IbvVirtualQpBusinessCard::serialize() const { - return folly::toJson(toDynamic()); -} - -folly::Expected -IbvVirtualQpBusinessCard::deserialize(const std::string& jsonStr) { - try { - folly::dynamic obj = folly::parseJson(jsonStr); - return fromDynamic(obj); - } catch (const std::exception& e) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "Failed to parse JSON in IbvVirtualQpBusinessCard Deserialize. Exception: {}", - e.what()))); - } -} - -/*** IbvVirtualCq ***/ - -IbvVirtualCq::IbvVirtualCq(IbvCq&& physicalCq, int maxCqe) - : physicalCq_(std::move(physicalCq)), maxCqe_(maxCqe) { - virtualCqNum_ = - nextVirtualCqNum_.fetch_add(1); // Assign unique virtual CQ number - - // Register the virtual CQ with Coordinator - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualCq construction!"; - coordinator->registerVirtualCq(virtualCqNum_, this); -} - -IbvVirtualCq::IbvVirtualCq(IbvVirtualCq&& other) noexcept { - physicalCq_ = std::move(other.physicalCq_); - pendingSendVirtualWcQue_ = std::move(other.pendingSendVirtualWcQue_); - pendingRecvVirtualWcQue_ = std::move(other.pendingRecvVirtualWcQue_); - maxCqe_ = other.maxCqe_; - virtualWrIdToVirtualWc_ = std::move(other.virtualWrIdToVirtualWc_); - virtualCqNum_ = other.virtualCqNum_; - - // Update coordinator pointer mapping for this virtual CQ after move - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualCq move construction!"; - coordinator->updateVirtualCqPointer(virtualCqNum_, this); -} - -IbvVirtualCq& IbvVirtualCq::operator=(IbvVirtualCq&& other) noexcept { - if (this != &other) { - physicalCq_ = std::move(other.physicalCq_); - pendingSendVirtualWcQue_ = std::move(other.pendingSendVirtualWcQue_); - pendingRecvVirtualWcQue_ = std::move(other.pendingRecvVirtualWcQue_); - maxCqe_ = other.maxCqe_; - virtualWrIdToVirtualWc_ = std::move(other.virtualWrIdToVirtualWc_); - virtualCqNum_ = other.virtualCqNum_; - - // Update coordinator pointer mapping for this virtual CQ after move - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualCq move construction!"; - coordinator->updateVirtualCqPointer(virtualCqNum_, this); - } - return *this; -} - -IbvCq& IbvVirtualCq::getPhysicalCqRef() { - return physicalCq_; -} - -uint32_t IbvVirtualCq::getVirtualCqNum() const { - return virtualCqNum_; -} - -void IbvVirtualCq::enqueSendCq(VirtualWc virtualWc) { - pendingSendVirtualWcQue_.push_back(std::move(virtualWc)); -} - -void IbvVirtualCq::enqueRecvCq(VirtualWc virtualWc) { - pendingRecvVirtualWcQue_.push_back(std::move(virtualWc)); -} - -IbvVirtualCq::~IbvVirtualCq() { - // Always call unregister - the coordinator will check if 
the pointer matches - // and do nothing if the object was moved - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during IbvVirtualCq destruction!"; - coordinator->unregisterVirtualCq(virtualCqNum_, this); -} - -/*** Coordinator ***/ - -std::shared_ptr Coordinator::getCoordinator() { - return coordinatorSingleton.try_get(); -} - -// Register APIs for mapping management -void Coordinator::registerVirtualQp( - uint32_t virtualQpNum, - IbvVirtualQp* virtualQp) { - virtualQpNumToVirtualQp_[virtualQpNum] = virtualQp; -} - -void Coordinator::registerVirtualCq( - uint32_t virtualCqNum, - IbvVirtualCq* virtualCq) { - virtualCqNumToVirtualCq_[virtualCqNum] = virtualCq; -} - -void Coordinator::registerPhysicalQpToVirtualQp( - int physicalQpNum, - uint32_t virtualQpNum) { - physicalQpNumToVirtualQpNum_[physicalQpNum] = virtualQpNum; -} - -void Coordinator::registerVirtualQpToVirtualSendCq( - uint32_t virtualQpNum, - uint32_t virtualSendCqNum) { - virtualQpNumToVirtualSendCqNum_[virtualQpNum] = virtualSendCqNum; -} - -void Coordinator::registerVirtualQpToVirtualRecvCq( - uint32_t virtualQpNum, - uint32_t virtualRecvCqNum) { - virtualQpNumToVirtualRecvCqNum_[virtualQpNum] = virtualRecvCqNum; -} - -void Coordinator::registerVirtualQpWithVirtualCqMappings( - IbvVirtualQp* virtualQp, - uint32_t virtualSendCqNum, - uint32_t virtualRecvCqNum) { - // Extract virtual QP number from the virtual QP object - uint32_t virtualQpNum = virtualQp->getVirtualQpNum(); - - // Register the virtual QP - registerVirtualQp(virtualQpNum, virtualQp); - - // Register all physical QP to virtual QP mappings - for (const auto& qp : virtualQp->getQpsRef()) { - registerPhysicalQpToVirtualQp(qp.qp()->qp_num, virtualQpNum); - } - // Register notify QP - registerPhysicalQpToVirtualQp( - virtualQp->getNotifyQpRef().qp()->qp_num, virtualQpNum); - - // Register virtual QP to virtual CQ relationships - registerVirtualQpToVirtualSendCq(virtualQpNum, virtualSendCqNum); - registerVirtualQpToVirtualRecvCq(virtualQpNum, virtualRecvCqNum); -} - -// Access APIs for testing and internal use -const std::unordered_map& -Coordinator::getVirtualQpMap() const { - return virtualQpNumToVirtualQp_; -} - -const std::unordered_map& -Coordinator::getVirtualCqMap() const { - return virtualCqNumToVirtualCq_; -} - -const std::unordered_map& -Coordinator::getPhysicalQpToVirtualQpMap() const { - return physicalQpNumToVirtualQpNum_; -} - -const std::unordered_map& -Coordinator::getVirtualQpToVirtualSendCqMap() const { - return virtualQpNumToVirtualSendCqNum_; -} - -const std::unordered_map& -Coordinator::getVirtualQpToVirtualRecvCqMap() const { - return virtualQpNumToVirtualRecvCqNum_; -} - -// Update API for move operations - only need to update pointer maps -void Coordinator::updateVirtualQpPointer( - uint32_t virtualQpNum, - IbvVirtualQp* newPtr) { - virtualQpNumToVirtualQp_[virtualQpNum] = newPtr; -} - -void Coordinator::updateVirtualCqPointer( - uint32_t virtualCqNum, - IbvVirtualCq* newPtr) { - virtualCqNumToVirtualCq_[virtualCqNum] = newPtr; -} - -void Coordinator::unregisterVirtualQp( - uint32_t virtualQpNum, - IbvVirtualQp* ptr) { - // Only unregister if the pointer in the map matches the object being - // destroyed. This handles the case where the object was moved and the map was - // already updated with the new pointer. 
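// Illustrative sketch (not part of this change): the pointer-match guard below is
// what keeps the registry consistent when a registered object is moved -- the
// moved-from object's destructor must not erase the entry that now belongs to the
// moved-to object. A minimal standalone reproduction of the pattern, with
// hypothetical names:
#include <cstdint>
#include <unordered_map>

class MiniRegistry {
 public:
  // Called by the move constructor / move assignment of the registered object.
  void update(uint32_t id, void* newOwner) {
    entries_[id] = newOwner;
  }
  // Called by the destructor; erases only if `owner` still owns the entry.
  void unregisterIfOwner(uint32_t id, void* owner) {
    auto it = entries_.find(id);
    if (it == entries_.end() || it->second != owner) {
      return;  // object was moved; the map already points at the new owner
    }
    entries_.erase(it);
  }

 private:
  std::unordered_map<uint32_t, void*> entries_;
};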
- auto it = virtualQpNumToVirtualQp_.find(virtualQpNum); - if (it == virtualQpNumToVirtualQp_.end() || it->second != ptr) { - // Object was moved, map already updated, nothing to do - return; - } - - // Remove entries from all maps related to this virtual QP - virtualQpNumToVirtualQp_.erase(virtualQpNum); - virtualQpNumToVirtualSendCqNum_.erase(virtualQpNum); - virtualQpNumToVirtualRecvCqNum_.erase(virtualQpNum); - - // Remove all physical QP to virtual QP mappings that point to this virtual QP - for (auto it = physicalQpNumToVirtualQpNum_.begin(); - it != physicalQpNumToVirtualQpNum_.end();) { - if (it->second == virtualQpNum) { - it = physicalQpNumToVirtualQpNum_.erase(it); - } else { - ++it; - } - } -} - -void Coordinator::unregisterVirtualCq( - uint32_t virtualCqNum, - IbvVirtualCq* ptr) { - // Only unregister if the pointer in the map matches the object being - // destroyed. This handles the case where the object was moved and the map was - // already updated with the new pointer. - auto it = virtualCqNumToVirtualCq_.find(virtualCqNum); - if (it == virtualCqNumToVirtualCq_.end() || it->second != ptr) { - // Object was moved, map already updated, nothing to do - return; - } - - // Remove the virtual CQ from the pointer map - virtualCqNumToVirtualCq_.erase(virtualCqNum); - - // Remove all virtual QP to virtual send CQ mappings that point to this - // virtual CQ - for (auto it = virtualQpNumToVirtualSendCqNum_.begin(); - it != virtualQpNumToVirtualSendCqNum_.end();) { - if (it->second == virtualCqNum) { - it = virtualQpNumToVirtualSendCqNum_.erase(it); - } else { - ++it; - } - } - - // Remove all virtual QP to virtual recv CQ mappings that point to this - // virtual CQ - for (auto it = virtualQpNumToVirtualRecvCqNum_.begin(); - it != virtualQpNumToVirtualRecvCqNum_.end();) { - if (it->second == virtualCqNum) { - it = virtualQpNumToVirtualRecvCqNum_.erase(it); - } else { - ++it; - } - } -} - -/*** RoceHCA ***/ - -RoceHca::RoceHca(std::string hcaStr, int defaultPort) { - std::string s = std::move(hcaStr); - std::string delim = ":"; - - std::vector hcaStrPair; - folly::split(':', s, hcaStrPair); - if (hcaStrPair.size() == 1) { - this->name = s; - this->port = defaultPort; - } else if (hcaStrPair.size() == 2) { - this->name = hcaStrPair.at(0); - this->port = std::stoi(hcaStrPair.at(1)); - } -} - folly::Expected Mlx5dv::initObj( mlx5dv_obj* obj, uint64_t obj_type) { diff --git a/comms/ctran/ibverbx/Ibverbx.h b/comms/ctran/ibverbx/Ibverbx.h index a5f0a84e..da48957b 100644 --- a/comms/ctran/ibverbx/Ibverbx.h +++ b/comms/ctran/ibverbx/Ibverbx.h @@ -6,9 +6,9 @@ #include #include #include -#include -#include +#include "comms/ctran/ibverbx/IbvCommon.h" +#include "comms/ctran/ibverbx/IbvDevice.h" // IWYU pragma: keep #include "comms/ctran/ibverbx/Ibvcore.h" namespace ibverbx { @@ -17,241 +17,6 @@ namespace ibverbx { class IbvVirtualQp; class Coordinator; -// Default HCA prefix -constexpr std::string_view kDefaultHcaPrefix = ""; -// Default HCA list -const std::vector kDefaultHcaList{}; -// Default port -constexpr int kIbAnyPort = -1; -constexpr int kIbMaxMsgCntPerQp = 100; -constexpr int kIbMaxMsgSizeByte = 100; -constexpr int kIbMaxCqe_ = 100; -constexpr int kNotifyBit = 31; -constexpr uint32_t kSeqNumMask = 0xFFFFFF; // 24 bits - -// Command types for coordinator routing and operations -enum class RequestType { SEND = 0, RECV = 1, SEND_NOTIFY = 2 }; -enum class LoadBalancingScheme { SPRAY = 0, DQPLB = 1 }; - -struct IbvSymbols { - int (*ibv_internal_fork_init)(void) = nullptr; - struct 
ibv_device** (*ibv_internal_get_device_list)(int* num_devices) = - nullptr; - void (*ibv_internal_free_device_list)(struct ibv_device** list) = nullptr; - const char* (*ibv_internal_get_device_name)(struct ibv_device* device) = - nullptr; - struct ibv_context* (*ibv_internal_open_device)(struct ibv_device* device) = - nullptr; - int (*ibv_internal_close_device)(struct ibv_context* context) = nullptr; - int (*ibv_internal_get_async_event)( - struct ibv_context* context, - struct ibv_async_event* event) = nullptr; - void (*ibv_internal_ack_async_event)(struct ibv_async_event* event) = nullptr; - int (*ibv_internal_query_device)( - struct ibv_context* context, - struct ibv_device_attr* device_attr) = nullptr; - int (*ibv_internal_query_port)( - struct ibv_context* context, - uint8_t port_num, - struct ibv_port_attr* port_attr) = nullptr; - int (*ibv_internal_query_gid)( - struct ibv_context* context, - uint8_t port_num, - int index, - union ibv_gid* gid) = nullptr; - int (*ibv_internal_query_qp)( - struct ibv_qp* qp, - struct ibv_qp_attr* attr, - int attr_mask, - struct ibv_qp_init_attr* init_attr) = nullptr; - struct ibv_pd* (*ibv_internal_alloc_pd)(struct ibv_context* context) = - nullptr; - struct ibv_pd* (*ibv_internal_alloc_parent_domain)( - struct ibv_context* context, - struct ibv_parent_domain_init_attr* attr) = nullptr; - int (*ibv_internal_dealloc_pd)(struct ibv_pd* pd) = nullptr; - struct ibv_mr* (*ibv_internal_reg_mr)( - struct ibv_pd* pd, - void* addr, - size_t length, - int access) = nullptr; - struct ibv_mr* (*ibv_internal_reg_mr_iova2)( - struct ibv_pd* pd, - void* addr, - size_t length, - uint64_t iova, - unsigned int access) = nullptr; - struct ibv_mr* (*ibv_internal_reg_dmabuf_mr)( - struct ibv_pd* pd, - uint64_t offset, - size_t length, - uint64_t iova, - int fd, - int access) = nullptr; - int (*ibv_internal_dereg_mr)(struct ibv_mr* mr) = nullptr; - struct ibv_cq* (*ibv_internal_create_cq)( - struct ibv_context* context, - int cqe, - void* cq_context, - struct ibv_comp_channel* channel, - int comp_vector) = nullptr; - struct ibv_cq_ex* (*ibv_internal_create_cq_ex)( - struct ibv_context* context, - struct ibv_cq_init_attr_ex* attr) = nullptr; - int (*ibv_internal_destroy_cq)(struct ibv_cq* cq) = nullptr; - struct ibv_comp_channel* (*ibv_internal_create_comp_channel)( - struct ibv_context* context) = nullptr; - int (*ibv_internal_destroy_comp_channel)(struct ibv_comp_channel* channel) = - nullptr; - int (*ibv_internal_req_notify_cq)(struct ibv_cq* cq, int solicited_only) = - nullptr; - int (*ibv_internal_get_cq_event)( - struct ibv_comp_channel* channel, - struct ibv_cq** cq, - void** cq_context) = nullptr; - void (*ibv_internal_ack_cq_events)(struct ibv_cq* cq, unsigned int nevents) = - nullptr; - struct ibv_qp* (*ibv_internal_create_qp)( - struct ibv_pd* pd, - struct ibv_qp_init_attr* qp_init_attr) = nullptr; - int (*ibv_internal_modify_qp)( - struct ibv_qp* qp, - struct ibv_qp_attr* attr, - int attr_mask) = nullptr; - int (*ibv_internal_destroy_qp)(struct ibv_qp* qp) = nullptr; - const char* (*ibv_internal_event_type_str)(enum ibv_event_type event) = - nullptr; - int (*ibv_internal_query_ece)(struct ibv_qp* qp, struct ibv_ece* ece) = - nullptr; - int (*ibv_internal_set_ece)(struct ibv_qp* qp, struct ibv_ece* ece) = nullptr; - enum ibv_fork_status (*ibv_internal_is_fork_initialized)() = nullptr; - - /* mlx5dv functions */ - int (*mlx5dv_internal_init_obj)(struct mlx5dv_obj* obj, uint64_t obj_type) = - nullptr; - bool (*mlx5dv_internal_is_supported)(struct ibv_device* 
device) = nullptr; - int (*mlx5dv_internal_get_data_direct_sysfs_path)( - struct ibv_context* context, - char* buf, - size_t buf_len) = nullptr; - /* DMA-BUF support */ - struct ibv_mr* (*mlx5dv_internal_reg_dmabuf_mr)( - struct ibv_pd* pd, - uint64_t offset, - size_t length, - uint64_t iova, - int fd, - int access, - int mlx5_access) = nullptr; -}; - -int buildIbvSymbols(IbvSymbols& ibvSymbols); - -struct Error { - Error(); - explicit Error(int errNum); - Error(int errNum, std::string errStr); - - const int errNum{0}; - const std::string errStr; -}; - -struct VirtualWc { - VirtualWc() = default; - ~VirtualWc() = default; - - struct ibv_wc wc{}; - int expectedMsgCnt{0}; - int remainingMsgCnt{0}; - bool sendExtraNotifyImm{ - false}; // Whether to expect an extra notify IMM - // message to be sent for the current virtualWc -}; - -struct VirtualSendWr { - inline VirtualSendWr( - const ibv_send_wr& wr, - int expectedMsgCnt, - int remainingMsgCnt, - bool sendExtraNotifyImm); - VirtualSendWr() = default; - ~VirtualSendWr() = default; - - ibv_send_wr wr{}; // Copy of the work request being posted by the user - std::vector sgList; // Copy of the scatter-gather list - int expectedMsgCnt{0}; // Expected message count resulting from splitting a - // large user message into multiple parts - int remainingMsgCnt{0}; // Number of message segments left to transmit after - // splitting a large user messaget - int offset{ - 0}; // Address offset to be used for the next message send operation - bool sendExtraNotifyImm{false}; // Whether to send an extra notify IMM message - // for the current VirtualSendWr -}; - -struct VirtualRecvWr { - inline VirtualRecvWr( - const ibv_recv_wr& wr, - int expectedMsgCnt, - int remainingMsgCnt); - VirtualRecvWr() = default; - ~VirtualRecvWr() = default; - - ibv_recv_wr wr{}; // Copy of the work request being posted by the user - std::vector sgList; // Copy of the scatter-gather list - int expectedMsgCnt{0}; // Expected message count resulting from splitting a - // large user message into multiple parts - int remainingMsgCnt{0}; // Number of message segments left to transmit after - // splitting a large user messaget - int offset{ - 0}; // Address offset to be used for the next message send operation -}; - -struct VirtualQpRequest { - RequestType type{RequestType::SEND}; - uint64_t wrId{0}; - uint32_t physicalQpNum{0}; - uint32_t immData{0}; -}; - -struct VirtualQpResponse { - uint64_t virtualWrId{0}; - bool useDqplb{false}; - int notifyCount{0}; -}; - -struct VirtualCqRequest { - RequestType type{RequestType::SEND}; - int virtualQpNum{-1}; - int expectedMsgCnt{-1}; - ibv_send_wr* sendWr{nullptr}; - ibv_recv_wr* recvWr{nullptr}; - bool sendExtraNotifyImm{false}; -}; - -class DqplbSeqTracker { - public: - DqplbSeqTracker() = default; - ~DqplbSeqTracker() = default; - - // Explicitly default move constructor and move assignment operator - DqplbSeqTracker(DqplbSeqTracker&&) = default; - DqplbSeqTracker& operator=(DqplbSeqTracker&&) = default; - - // This helper function calculates sender IMM message in DQPLB mode. - inline uint32_t getSendImm(int remainingMsgCnt); - // This helper function processes received IMM message and update - // receivedSeqNums_ map and receiveNext_ field. 
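// Illustrative sketch (not part of this change): one plausible IMM-word layout that
// is consistent with kNotifyBit / kSeqNumMask declared above -- a notify flag in
// the top bit plus a 24-bit wrapping sequence number in the low bits. The actual
// DQPLB packing lives in the inline implementation (getSendImm /
// processReceivedImm) and may differ.
#include <cstdint>

namespace dqplb_example {
constexpr int kNotifyBit = 31;
constexpr uint32_t kSeqNumMask = 0xFFFFFF;  // 24 bits

inline uint32_t packImm(uint32_t seqNum, bool notify) {
  return (seqNum & kSeqNumMask) | (notify ? (1u << kNotifyBit) : 0u);
}
inline bool hasNotify(uint32_t imm) {
  return ((imm >> kNotifyBit) & 0x1u) != 0;
}
inline uint32_t seqNum(uint32_t imm) {
  return imm & kSeqNumMask;
}
}  // namespace dqplb_example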
- inline int processReceivedImm(uint32_t receivedImm); - - private: - int sendNext_{0}; - int receiveNext_{0}; - std::unordered_map receivedSeqNums_; -}; - -std::ostream& operator<<(std::ostream&, Error const&); - /*** ibverbx APIs ***/ folly::Expected ibvInit(); @@ -263,552 +28,6 @@ ibvGetCqEvent(ibv_comp_channel* channel, ibv_cq** cq, void** cq_context); // Acknowledge completion events void ibvAckCqEvents(ibv_cq* cq, unsigned int nevents); -// IbvMr: Memory Region -class IbvMr { - public: - ~IbvMr(); - - // disable copy constructor - IbvMr(const IbvMr&) = delete; - IbvMr& operator=(const IbvMr&) = delete; - - // move constructor - IbvMr(IbvMr&& other) noexcept; - IbvMr& operator=(IbvMr&& other) noexcept; - - ibv_mr* mr() const; - - private: - friend class IbvPd; - - explicit IbvMr(ibv_mr* mr); - - ibv_mr* mr_{nullptr}; -}; - -// Ibv CompletionQueue(CQ) -class IbvCq { - public: - IbvCq() = default; - ~IbvCq(); - - // disable copy constructor - IbvCq(const IbvCq&) = delete; - IbvCq& operator=(const IbvCq&) = delete; - - // move constructor - IbvCq(IbvCq&& other) noexcept; - IbvCq& operator=(IbvCq&& other) noexcept; - - ibv_cq* cq() const; - inline folly::Expected, Error> pollCq(int numEntries); - - // Request notification when the next completion is added to this CQ - folly::Expected reqNotifyCq(int solicited_only) const; - - private: - friend class IbvDevice; - - explicit IbvCq(ibv_cq* cq); - - ibv_cq* cq_{nullptr}; -}; - -// Ibv Virtual Completion Queue (CQ): Provides a virtual CQ abstraction for the -// user. When the user calls IbvVirtualQp::postSend() or -// IbvVirtualQp::postRecv(), they can track the completion of messages posted on -// the Virtual QP through this virtual CQ. -class IbvVirtualCq { - public: - IbvVirtualCq(IbvCq&& cq, int maxCqe); - ~IbvVirtualCq(); - - // disable copy constructor - IbvVirtualCq(const IbvVirtualCq&) = delete; - IbvVirtualCq& operator=(const IbvVirtualCq&) = delete; - - // move constructor - IbvVirtualCq(IbvVirtualCq&& other) noexcept; - IbvVirtualCq& operator=(IbvVirtualCq&& other) noexcept; - - inline folly::Expected, Error> pollCq(int numEntries); - - IbvCq& getPhysicalCqRef(); - uint32_t getVirtualCqNum() const; - - void enqueSendCq(VirtualWc virtualWc); - void enqueRecvCq(VirtualWc virtualWc); - - inline void processRequest(VirtualCqRequest&& request); - - private: - friend class IbvPd; - friend class IbvVirtualQp; - - inline static std::atomic nextVirtualCqNum_{ - 0}; // Static counter for assigning unique virtual CQ numbers - uint32_t virtualCqNum_{ - 0}; // The unique virtual CQ number assigned to instance of IbvVirtualCq - - IbvCq physicalCq_; - int maxCqe_{0}; - std::deque pendingSendVirtualWcQue_; - std::deque pendingRecvVirtualWcQue_; - inline void updateVirtualWcFromPhysicalWc( - const ibv_wc& physicalWc, - VirtualWc* virtualWc); - std::unordered_map virtualWrIdToVirtualWc_; - - // Helper function for IbvVirtualCq::pollCq. - // Continuously polls the underlying physical Completion Queue (CQ) in a loop, - // retrieving all available Completion Queue Entries (CQEs) until none remain. - // For each physical CQE polled, the corresponding virtual CQE entries in the - // virtual CQ are also updated. This function ensures that all ready physical - // CQEs are polled, processed, and reflected in the virtual CQ state. - inline folly::Expected loopPollPhysicalCqUntilEmpty(); - - // Helper function for IbvVirtualCq::pollCq. - // Continuously polls the underlying virtual Completion Queues (CQs) in a - // loop. 
The function collects up to numEntries virtual Completion Queue - // Entries (CQEs), or stops early if there are no more virtual CQEs available - // to poll. Returns a vector containing the polled virtual CQEs. - inline std::vector loopPollVirtualCqUntil(int numEntries); -}; - -// Ibv Queue Pair -class IbvQp { - public: - ~IbvQp(); - - // disable copy constructor - IbvQp(const IbvQp&) = delete; - IbvQp& operator=(const IbvQp&) = delete; - - // move constructor - IbvQp(IbvQp&& other) noexcept; - IbvQp& operator=(IbvQp&& other) noexcept; - - ibv_qp* qp() const; - - folly::Expected modifyQp(ibv_qp_attr* attr, int attrMask); - folly::Expected, Error> queryQp( - int attrMask) const; - - inline uint32_t getQpNum() const; - inline folly::Expected postRecv( - ibv_recv_wr* recvWr, - ibv_recv_wr* recvWrBad); - inline folly::Expected postSend( - ibv_send_wr* sendWr, - ibv_send_wr* sendWrBad); - - void enquePhysicalSendWrStatus(int physicalWrId, int virtualWrId); - void enquePhysicalRecvWrStatus(int physicalWrId, int virtualWrId); - void dequePhysicalSendWrStatus(); - void dequePhysicalRecvWrStatus(); - bool isSendQueueAvailable(int maxMsgCntPerQp) const; - bool isRecvQueueAvailable(int maxMsgCntPerQp) const; - - private: - friend class IbvPd; - friend class IbvVirtualQp; - friend class IbvVirtualCq; - - struct PhysicalSendWrStatus { - PhysicalSendWrStatus(uint64_t physicalWrId, uint64_t virtualWrId) - : physicalWrId(physicalWrId), virtualWrId(virtualWrId) {} - uint64_t physicalWrId{0}; - uint64_t virtualWrId{0}; - }; - struct PhysicalRecvWrStatus { - PhysicalRecvWrStatus(uint64_t physicalWrId, uint64_t virtualWrId) - : physicalWrId(physicalWrId), virtualWrId(virtualWrId) {} - uint64_t physicalWrId{0}; - uint64_t virtualWrId{0}; - }; - explicit IbvQp(ibv_qp* qp); - - ibv_qp* qp_{nullptr}; - std::deque physicalSendWrStatus_; - std::deque physicalRecvWrStatus_; -}; - -// IbvVirtualQpBusinessCard -struct IbvVirtualQpBusinessCard { - explicit IbvVirtualQpBusinessCard( - std::vector qpNums, - uint32_t notifyQpNum = 0); - IbvVirtualQpBusinessCard() = default; - ~IbvVirtualQpBusinessCard() = default; - - // Default copy constructor and assignment operator - IbvVirtualQpBusinessCard(const IbvVirtualQpBusinessCard& other) = default; - IbvVirtualQpBusinessCard& operator=(const IbvVirtualQpBusinessCard& other) = - default; - - // Default move constructor and assignment operator - IbvVirtualQpBusinessCard(IbvVirtualQpBusinessCard&& other) = default; - IbvVirtualQpBusinessCard& operator=(IbvVirtualQpBusinessCard&& other) = - default; - - // Convert to/from folly::dynamic for serialization - folly::dynamic toDynamic() const; - static folly::Expected fromDynamic( - const folly::dynamic& obj); - - // JSON serialization methods - std::string serialize() const; - static folly::Expected deserialize( - const std::string& jsonStr); - - // The qpNums_ vector is ordered: the ith QP in qpNums_ will be - // connected to the ith QP in the remote side's qpNums_ vector. 
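// Illustrative sketch (not part of this change): a typical out-of-band exchange
// that relies on this ordering. Each side serializes its own card, swaps the JSON
// with the peer through a caller-supplied transport (`exchange` below is
// hypothetical), and passes the peer's card to modifyVirtualQp when moving to RTR.
// Assumes the ibverbx declarations in this header and folly/Expected.h.
#include <functional>
#include <string>

inline folly::Expected<folly::Unit, ibverbx::Error> connectVirtualQpExample(
    ibverbx::IbvVirtualQp& qp,
    ibv_qp_attr* rtrAttr,
    int rtrAttrMask,
    const std::function<std::string(const std::string&)>& exchange) {
  const std::string localCard = qp.getVirtualQpBusinessCard().serialize();
  auto remoteCard =
      ibverbx::IbvVirtualQpBusinessCard::deserialize(exchange(localCard));
  if (remoteCard.hasError()) {
    return folly::makeUnexpected(remoteCard.error());
  }
  // dest_qp_num for the i-th physical QP comes from the i-th entry of the
  // peer's qpNums_.
  return qp.modifyVirtualQp(rtrAttr, rtrAttrMask, *remoteCard);
}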
- std::vector qpNums_; - uint32_t notifyQpNum_{0}; -}; - -// Ibv Virtual Queue Pair -class IbvVirtualQp { - public: - ~IbvVirtualQp(); - - // disable copy constructor - IbvVirtualQp(const IbvVirtualQp&) = delete; - IbvVirtualQp& operator=(const IbvVirtualQp&) = delete; - - // move constructor - IbvVirtualQp(IbvVirtualQp&& other) noexcept; - IbvVirtualQp& operator=(IbvVirtualQp&& other) noexcept; - - size_t getTotalQps() const; - const std::vector& getQpsRef() const; - std::vector& getQpsRef(); - const IbvQp& getNotifyQpRef() const; - uint32_t getVirtualQpNum() const; - // If businessCard is not provided, all physical QPs will be updated with the - // universal attributes specified in attr. This is typically used for changing - // the state to INIT or RTS. - // If businessCard is provided, attr.qp_num for each physical QP will be set - // individually to the corresponding qpNum stored in qpNums_ within - // businessCard. This is typically used for changing the state to RTR. - folly::Expected modifyVirtualQp( - ibv_qp_attr* attr, - int attrMask, - const IbvVirtualQpBusinessCard& businessCard = - IbvVirtualQpBusinessCard()); - IbvVirtualQpBusinessCard getVirtualQpBusinessCard() const; - LoadBalancingScheme getLoadBalancingScheme() const; - - inline folly::Expected postSend( - ibv_send_wr* sendWr, - ibv_send_wr* sendWrBad); - - inline folly::Expected postRecv( - ibv_recv_wr* ibvRecvWr, - ibv_recv_wr* badIbvRecvWr); - - inline int findAvailableSendQp(); - inline int findAvailableRecvQp(); - - inline folly::Expected processRequest( - VirtualQpRequest&& request); - - private: -#ifdef IBVERBX_TEST_FRIENDS - IBVERBX_TEST_FRIENDS -#endif - - // updatePhysicalSendWrFromVirtualSendWr is a helper function to update - // physical send work request (ibv_send_wr) from virtual send work request - inline void updatePhysicalSendWrFromVirtualSendWr( - VirtualSendWr& virtualSendWr, - ibv_send_wr* sendWr, - ibv_sge* sendSg); - - friend class IbvPd; - friend class IbvVirtualCq; - - std::deque pendingSendVirtualWrQue_; - std::deque pendingRecvVirtualWrQue_; - - inline static std::atomic nextVirtualQpNum_{ - 0}; // Static counter for assigning unique virtual QP numbers - uint32_t virtualQpNum_{0}; // The unique virtual QP number assigned to - // instance of IbvVirtualQp. - - std::vector physicalQps_; - std::unordered_map qpNumToIdx_; - - int nextSendPhysicalQpIdx_{0}; - int nextRecvPhysicalQpIdx_{0}; - - int maxMsgCntPerQp_{ - -1}; // Maximum number of messages that can be sent on each physical QP. A - // value of -1 indicates there is no limit. 
- int maxMsgSize_{0}; - - uint64_t nextPhysicalWrId_{0}; // ID of the next physical work request to - // be posted on the physical QP - - LoadBalancingScheme loadBalancingScheme_{ - LoadBalancingScheme::SPRAY}; // Load balancing scheme for this virtual QP - - // Spray mode specific fields - std::deque pendingSendNotifyVirtualWrQue_; - IbvQp notifyQp_; - - // DQPLB mode specific fields and functions - DqplbSeqTracker dqplbSeqTracker; - bool dqplbReceiverInitialized_{ - false}; // flag to indicate if dqplb receiver is initialized - inline folly::Expected initializeDqplbReceiver(); - - IbvVirtualQp( - std::vector&& qps, - IbvQp&& notifyQp, - IbvVirtualCq* sendCq, - IbvVirtualCq* recvCq, - int maxMsgCntPerQp = kIbMaxMsgCntPerQp, - int maxMsgSize = kIbMaxMsgSizeByte, - LoadBalancingScheme loadBalancingScheme = LoadBalancingScheme::SPRAY); - - // mapPendingSendQueToPhysicalQp is a helper function to iterate through - // virtualSendWr in the pendingSendVirtualWrQue_, construct physical wrs and - // call postSend on physical QP. If qpIdx is provided, this function will - // postSend physicalWr on qpIdx. If qpIdx is not provided, then the function - // will find an available Qp to postSend the physical work request on. - inline folly::Expected mapPendingSendQueToPhysicalQp( - int qpIdx = -1); - - // postSendNotifyImm is a helper function to send IMM notification message - // after all previous messages are sent in a large message - inline folly::Expected postSendNotifyImm(); - inline folly::Expected mapPendingRecvQueToPhysicalQp( - int qpIdx = -1); - inline folly::Expected postRecvNotifyImm(int qpIdx = -1); -}; - -// Coordinator class responsible for routing commands and responses between -// IbvVirtualQp and IbvVirtualCq. Maintains mappings from physical QP numbers to -// IbvVirtualQp pointers, and from virtual CQ numbers to IbvVirtualCq pointers. -// Acts as a router to forward requests between these two classes. -// -// NOTE: The Coordinator APIs are NOT thread-safe. Users must ensure proper -// synchronization when accessing Coordinator methods from multiple threads. -// Thread-safe support can be added in the future if needed. 
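// Illustrative sketch (not part of this change): until thread-safety is added, a
// multi-threaded caller might serialize every Coordinator call behind one mutex.
// `withCoordinatorExample` is a hypothetical helper, not part of the API.
#include <mutex>
#include <utility>

template <typename Fn>
auto withCoordinatorExample(Fn&& fn) {
  static std::mutex m;
  std::lock_guard<std::mutex> guard(m);
  return std::forward<Fn>(fn)(ibverbx::Coordinator::getCoordinator());
}
// Usage:
// withCoordinatorExample([](auto coord) { return coord->getVirtualQpMap().size(); });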
-class Coordinator { - public: - Coordinator() = default; - ~Coordinator() = default; - - // Disable copy constructor and assignment operator - Coordinator(const Coordinator&) = delete; - Coordinator& operator=(const Coordinator&) = delete; - - // Allow default move constructor and assignment operator - Coordinator(Coordinator&&) = default; - Coordinator& operator=(Coordinator&&) = default; - - inline void submitRequestToVirtualCq(VirtualCqRequest&& request); - inline folly::Expected submitRequestToVirtualQp( - VirtualQpRequest&& request); - - // Register APIs for mapping management - void registerVirtualQp(uint32_t virtualQpNum, IbvVirtualQp* virtualQp); - void registerVirtualCq(uint32_t virtualCqNum, IbvVirtualCq* virtualCq); - void registerPhysicalQpToVirtualQp(int physicalQpNum, uint32_t virtualQpNum); - void registerVirtualQpToVirtualSendCq( - uint32_t virtualQpNum, - uint32_t virtualSendCqNum); - void registerVirtualQpToVirtualRecvCq( - uint32_t virtualQpNum, - uint32_t virtualRecvCqNum); - - // Consolidated registration API for IbvVirtualQp - registers the virtual QP - // along with all its physical QPs and CQ relationships in one call - void registerVirtualQpWithVirtualCqMappings( - IbvVirtualQp* virtualQp, - uint32_t virtualSendCqNum, - uint32_t virtualRecvCqNum); - - // Getter APIs for accessing mappings - inline IbvVirtualCq* getVirtualSendCq(uint32_t virtualQpNum) const; - inline IbvVirtualCq* getVirtualRecvCq(uint32_t virtualQpNum) const; - inline IbvVirtualQp* getVirtualQpByPhysicalQpNum(int physicalQpNum) const; - inline IbvVirtualQp* getVirtualQpById(uint32_t virtualQpNum) const; - inline IbvVirtualCq* getVirtualCqById(uint32_t virtualCqNum) const; - - // Access APIs for testing and internal use - const std::unordered_map& getVirtualQpMap() const; - const std::unordered_map& getVirtualCqMap() const; - const std::unordered_map& getPhysicalQpToVirtualQpMap() const; - const std::unordered_map& getVirtualQpToVirtualSendCqMap() - const; - const std::unordered_map& getVirtualQpToVirtualRecvCqMap() - const; - - // Update API for move operations - only need to update pointer maps - void updateVirtualQpPointer(uint32_t virtualQpNum, IbvVirtualQp* newPtr); - void updateVirtualCqPointer(uint32_t virtualCqNum, IbvVirtualCq* newPtr); - - // Unregister API for cleanup during destruction - void unregisterVirtualQp(uint32_t virtualQpNum, IbvVirtualQp* ptr); - void unregisterVirtualCq(uint32_t virtualCqNum, IbvVirtualCq* ptr); - - static std::shared_ptr getCoordinator(); - - private: - // Map 1: Virtual QP Num -> Virtual QP pointer - std::unordered_map virtualQpNumToVirtualQp_; - - // Map 2: Virtual CQ Num -> Virtual CQ pointer - std::unordered_map virtualCqNumToVirtualCq_; - - // Map 3: Virtual QP Num -> Virtual Send CQ Num (relationship) - std::unordered_map virtualQpNumToVirtualSendCqNum_; - - // Map 4: Virtual QP Num -> Virtual Recv CQ Num (relationship) - std::unordered_map virtualQpNumToVirtualRecvCqNum_; - - // Map 5: Physical QP number -> Virtual QP Num (for routing) - std::unordered_map physicalQpNumToVirtualQpNum_; -}; - -// IbvPd: Protection Domain -class IbvPd { - public: - ~IbvPd(); - - // disable copy constructor - IbvPd(const IbvPd&) = delete; - IbvPd& operator=(const IbvPd&) = delete; - - // move constructor - IbvPd(IbvPd&& other) noexcept; - IbvPd& operator=(IbvPd&& other) noexcept; - - ibv_pd* pd() const; - bool useDataDirect() const; - - folly::Expected - regMr(void* addr, size_t length, ibv_access_flags access) const; - - folly::Expected regDmabufMr( - uint64_t offset, - 
size_t length, - uint64_t iova, - int fd, - ibv_access_flags access) const; - - folly::Expected createQp(ibv_qp_init_attr* initAttr) const; - - // The send_cq and recv_cq fields in initAttr are ignored. - // Instead, initAttr.send_cq and initAttr.recv_cq will be set to the physical - // CQs contained within sendCq and recvCq, respectively. - folly::Expected createVirtualQp( - int totalQps, - ibv_qp_init_attr* initAttr, - IbvVirtualCq* sendCq, - IbvVirtualCq* recvCq, - int maxMsgCntPerQp = kIbMaxMsgCntPerQp, - int maxMsgSize = kIbMaxMsgSizeByte, - LoadBalancingScheme loadBalancingScheme = - LoadBalancingScheme::SPRAY) const; - - private: - friend class IbvDevice; - - IbvPd(ibv_pd* pd, bool dataDirect = false); - - ibv_pd* pd_{nullptr}; - bool dataDirect_{false}; // Relevant only to mlx5 -}; - -// IbvDevice -class IbvDevice { - public: - static folly::Expected, Error> ibvGetDeviceList( - const std::vector& hcaList = kDefaultHcaList, - const std::string& hcaPrefix = std::string(kDefaultHcaPrefix), - int defaultPort = kIbAnyPort); - IbvDevice(ibv_device* ibvDevice, int port); - ~IbvDevice(); - - // disable copy constructor - IbvDevice(const IbvDevice&) = delete; - IbvDevice& operator=(const IbvDevice&) = delete; - - // move constructor - IbvDevice(IbvDevice&& other) noexcept; - IbvDevice& operator=(IbvDevice&& other) noexcept; - - ibv_device* device() const; - ibv_context* context() const; - int port() const; - - folly::Expected allocPd(); - folly::Expected allocParentDomain( - ibv_parent_domain_init_attr* attr); - folly::Expected queryDevice() const; - folly::Expected queryPort(uint8_t portNum) const; - folly::Expected queryGid(uint8_t portNum, int gidIndex) const; - - folly::Expected createCq( - int cqe, - void* cq_context, - ibv_comp_channel* channel, - int comp_vector) const; - - // create Cq with attributes - folly::Expected createCq(ibv_cq_init_attr_ex* attr) const; - - // Create a completion channel for event-driven completion handling - folly::Expected createCompChannel() const; - - // Destroy a completion channel - folly::Expected destroyCompChannel( - ibv_comp_channel* channel) const; - - // When creating an IbvVirtualCq for an IbvVirtualQp, ensure that cqe >= - // (number of QPs * capacity per QP). If send queue and recv queue intend to - // share the same cqe, then ensure cqe >= (2 * number of QPs * capacity per - // QP). Failing to meet this condition may result in lost CQEs. TODO: Enforce - // this requirement in the low-level API. If a higher-level API is introduced - // in the future, ensure this guarantee is handled within Ibverbx when - // creating a IbvVirtualCq for the user. 
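// Illustrative sketch (not part of this change): a worked example of the sizing
// rule above, using hypothetical numbers -- 4 physical QPs with
// maxMsgCntPerQp = 100 outstanding messages each.
constexpr int kExampleTotalQps = 4;
constexpr int kExampleMaxMsgCntPerQp = 100;
// Separate send and recv virtual CQs: each needs at least 4 * 100 = 400 CQEs.
constexpr int kExampleSeparateCqe = kExampleTotalQps * kExampleMaxMsgCntPerQp;
// One virtual CQ shared by send and recv: at least 2 * 4 * 100 = 800 CQEs.
constexpr int kExampleSharedCqe = 2 * kExampleTotalQps * kExampleMaxMsgCntPerQp;
static_assert(
    kExampleSeparateCqe == 400 && kExampleSharedCqe == 800, "sizing example");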
- folly::Expected createVirtualCq( - int cqe, - void* cq_context, - ibv_comp_channel* channel, - int comp_vector); - - folly::Expected isPortActive( - uint8_t portNum, - std::unordered_set linkLayers) const; - folly::Expected findActivePort( - std::unordered_set const& linkLayers) const; - - private: - ibv_device* device_{nullptr}; - ibv_context* context_{nullptr}; - int port_{-1}; - bool dataDirect_{false}; // Relevant only to mlx5 - - static std::vector ibvFilterDeviceList( - int numDevs, - ibv_device** devs, - const std::vector& hcaList = kDefaultHcaList, - const std::string& hcaPrefix = std::string(kDefaultHcaPrefix), - int defaultPort = kIbAnyPort); -}; - -class RoceHca { - public: - RoceHca(std::string hcaStr, int defaultPort); - std::string name; - int port{-1}; -}; - class Mlx5dv { public: static folly::Expected initObj( @@ -816,906 +35,4 @@ class Mlx5dv { uint64_t obj_type); }; -// -// Inline function definitions -// - -// IbvQp inline functions -inline uint32_t IbvQp::getQpNum() const { - XCHECK_NE(qp_, nullptr); - return qp_->qp_num; -} - -inline folly::Expected IbvQp::postRecv( - ibv_recv_wr* recvWr, - ibv_recv_wr* recvWrBad) { - int rc = qp_->context->ops.post_recv(qp_, recvWr, &recvWrBad); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return folly::unit; -} - -inline folly::Expected IbvQp::postSend( - ibv_send_wr* sendWr, - ibv_send_wr* sendWrBad) { - int rc = qp_->context->ops.post_send(qp_, sendWr, &sendWrBad); - if (rc != 0) { - return folly::makeUnexpected(Error(rc)); - } - return folly::unit; -} - -// IbvCq inline functions -inline folly::Expected, Error> IbvCq::pollCq( - int numEntries) { - std::vector wcs(numEntries); - int numPolled = cq_->context->ops.poll_cq(cq_, numEntries, wcs.data()); - if (numPolled < 0) { - wcs.clear(); - return folly::makeUnexpected( - Error(EINVAL, fmt::format("Call to pollCq() returned {}", numPolled))); - } else { - wcs.resize(numPolled); - } - return wcs; -} - -// IbvVirtualCq inline functions -inline folly::Expected, Error> IbvVirtualCq::pollCq( - int numEntries) { - auto maybeLoopPollPhysicalCq = loopPollPhysicalCqUntilEmpty(); - if (maybeLoopPollPhysicalCq.hasError()) { - return folly::makeUnexpected(maybeLoopPollPhysicalCq.error()); - } - - return loopPollVirtualCqUntil(numEntries); -} - -inline folly::Expected -IbvVirtualCq::loopPollPhysicalCqUntilEmpty() { - // Poll from physical CQ one by one and process immediately - while (true) { - // Poll one completion at a time - auto maybePhysicalWcsVector = physicalCq_.pollCq(1); - if (maybePhysicalWcsVector.hasError()) { - return folly::makeUnexpected(maybePhysicalWcsVector.error()); - } - - // If no completions available, break the loop - if (maybePhysicalWcsVector->empty()) { - break; - } - - // Process the single completion immediately - const auto& physicalWc = maybePhysicalWcsVector->front(); - - if (physicalWc.opcode == IBV_WC_RECV || - physicalWc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { - VirtualQpRequest request = { - .type = RequestType::RECV, - .wrId = physicalWc.wr_id, - .physicalQpNum = physicalWc.qp_num}; - if (physicalWc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { - request.immData = physicalWc.imm_data; - } - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) << "Coordinator should not be nullptr during pollCq!"; - auto response = coordinator->submitRequestToVirtualQp(std::move(request)); - if (response.hasError()) { - return folly::makeUnexpected(response.error()); - } - - if (response->useDqplb) { - int processedCount = 0; - for (int i = 0; 
i < pendingRecvVirtualWcQue_.size() && - processedCount < response->notifyCount; - i++) { - if (pendingRecvVirtualWcQue_.at(i).remainingMsgCnt != 0) { - pendingRecvVirtualWcQue_.at(i).remainingMsgCnt = 0; - processedCount++; - } - } - } else { - auto virtualWc = virtualWrIdToVirtualWc_.at(response->virtualWrId); - virtualWc->remainingMsgCnt--; - updateVirtualWcFromPhysicalWc(physicalWc, virtualWc); - } - } else { - // Except for the above two conditions, all other conditions indicate a - // send message, and we should poll from send queue - VirtualQpRequest request = { - .type = RequestType::SEND, - .wrId = physicalWc.wr_id, - .physicalQpNum = physicalWc.qp_num}; - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) << "Coordinator should not be nullptr during pollCq!"; - auto response = coordinator->submitRequestToVirtualQp(std::move(request)); - if (response.hasError()) { - return folly::makeUnexpected(response.error()); - } - - auto virtualWc = virtualWrIdToVirtualWc_.at(response->virtualWrId); - virtualWc->remainingMsgCnt--; - updateVirtualWcFromPhysicalWc(physicalWc, virtualWc); - if (virtualWc->remainingMsgCnt == 1 && virtualWc->sendExtraNotifyImm) { - VirtualQpRequest request = { - .type = RequestType::SEND_NOTIFY, - .wrId = response->virtualWrId, - .physicalQpNum = physicalWc.qp_num}; - - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) - << "Coordinator should not be nullptr during pollCq!"; - auto response = - coordinator->submitRequestToVirtualQp(std::move(request)); - if (response.hasError()) { - return folly::makeUnexpected(response.error()); - } - } - } - } - - return folly::unit; -} - -inline std::vector IbvVirtualCq::loopPollVirtualCqUntil( - int numEntries) { - std::vector wcs; - wcs.reserve(numEntries); - bool virtualSendCqPollComplete = false; - bool virtualRecvCqPollComplete = false; - while (wcs.size() < static_cast(numEntries) && - (!virtualSendCqPollComplete || !virtualRecvCqPollComplete)) { - if (!virtualSendCqPollComplete) { - if (pendingSendVirtualWcQue_.empty() || - pendingSendVirtualWcQue_.front().remainingMsgCnt > 0) { - virtualSendCqPollComplete = true; - } else { - auto vSendCqHead = pendingSendVirtualWcQue_.front(); - virtualWrIdToVirtualWc_.erase(vSendCqHead.wc.wr_id); - wcs.push_back(std::move(vSendCqHead.wc)); - pendingSendVirtualWcQue_.pop_front(); - } - } - - if (!virtualRecvCqPollComplete) { - if (pendingRecvVirtualWcQue_.empty() || - pendingRecvVirtualWcQue_.front().remainingMsgCnt > 0) { - virtualRecvCqPollComplete = true; - } else { - auto vRecvCqHead = pendingRecvVirtualWcQue_.front(); - virtualWrIdToVirtualWc_.erase(vRecvCqHead.wc.wr_id); - wcs.push_back(std::move(vRecvCqHead.wc)); - pendingRecvVirtualWcQue_.pop_front(); - } - } - } - - return wcs; -} - -inline void IbvVirtualCq::updateVirtualWcFromPhysicalWc( - const ibv_wc& physicalWc, - VirtualWc* virtualWc) { - // Updates the vWc status field based on the statuses of all pWc instances. - // If all physicalWc statuses indicate success, returns success. - // If any of the physicalWc statuses indicate an error, return the first - // encountered error code. - // Additionally, log all error statuses for debug purposes. 
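// Illustrative example (hypothetical statuses, not from the original code):
// if a user message was split into three physical chunks whose completions
// arrive with statuses {IBV_WC_SUCCESS, IBV_WC_RETRY_EXC_ERR, IBV_WC_SUCCESS},
// the virtual WC ends up reporting IBV_WC_RETRY_EXC_ERR (the first non-success
// status observed), while byte_len still accumulates the bytes of every chunk.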
- if (physicalWc.status != IBV_WC_SUCCESS) { - if (virtualWc->wc.status == IBV_WC_SUCCESS) { - virtualWc->wc.status = physicalWc.status; - } - - // Log the error - XLOGF( - ERR, - "Physical WC error: status={}, vendor_err={}, qp_num={}, wr_id={}", - physicalWc.status, - physicalWc.vendor_err, - physicalWc.qp_num, - physicalWc.wr_id); - } - - // Update the OP code in virtualWc. Note that for the same user message, the - // opcode must remain consistent, because all sub-messages within that user - // message will be postSend using the same opcode. - virtualWc->wc.opcode = physicalWc.opcode; - - // Update the vendor error in virtualWc. For now, assume that all pWc - // instances will report the same vendor_error across all sub-messages - // within a single user message. - virtualWc->wc.vendor_err = physicalWc.vendor_err; - - virtualWc->wc.src_qp = physicalWc.src_qp; - virtualWc->wc.byte_len += physicalWc.byte_len; - virtualWc->wc.imm_data = physicalWc.imm_data; - virtualWc->wc.wc_flags = physicalWc.wc_flags; - virtualWc->wc.pkey_index = physicalWc.pkey_index; - virtualWc->wc.slid = physicalWc.slid; - virtualWc->wc.sl = physicalWc.sl; - virtualWc->wc.dlid_path_bits = physicalWc.dlid_path_bits; -} - -inline void IbvVirtualCq::processRequest(VirtualCqRequest&& request) { - VirtualWc* virtualWcPtr = nullptr; - uint64_t wrId; - if (request.type == RequestType::SEND) { - wrId = request.sendWr->wr_id; - if (request.sendWr->send_flags & IBV_SEND_SIGNALED || - request.sendWr->opcode == IBV_WR_RDMA_WRITE_WITH_IMM) { - VirtualWc virtualWc{}; - virtualWc.wc.wr_id = request.sendWr->wr_id; - virtualWc.wc.qp_num = request.virtualQpNum; - virtualWc.wc.status = IBV_WC_SUCCESS; - virtualWc.wc.byte_len = 0; - virtualWc.expectedMsgCnt = request.expectedMsgCnt; - virtualWc.remainingMsgCnt = request.expectedMsgCnt; - virtualWc.sendExtraNotifyImm = request.sendExtraNotifyImm; - pendingSendVirtualWcQue_.push_back(std::move(virtualWc)); - virtualWcPtr = &pendingSendVirtualWcQue_.back(); - } - } else { - wrId = request.recvWr->wr_id; - VirtualWc virtualWc{}; - virtualWc.wc.wr_id = request.recvWr->wr_id; - virtualWc.wc.qp_num = request.virtualQpNum; - virtualWc.wc.status = IBV_WC_SUCCESS; - virtualWc.wc.byte_len = 0; - virtualWc.expectedMsgCnt = request.expectedMsgCnt; - virtualWc.remainingMsgCnt = request.expectedMsgCnt; - pendingRecvVirtualWcQue_.push_back(std::move(virtualWc)); - virtualWcPtr = &pendingRecvVirtualWcQue_.back(); - } - virtualWrIdToVirtualWc_[wrId] = virtualWcPtr; -} - -// IbvVirtualQp inline functions -inline folly::Expected -IbvVirtualQp::mapPendingSendQueToPhysicalQp(int qpIdx) { - while (!pendingSendVirtualWrQue_.empty()) { - // Get the front of vSendQ_ and obtain the send information - VirtualSendWr& virtualSendWr = pendingSendVirtualWrQue_.front(); - - // For Send opcodes related to RDMA_WRITE operations, use user selected load - // balancing scheme specified in loadBalancingScheme_. For all other - // opcodes, default to using physical QP 0. - auto availableQpIdx = -1; - if (virtualSendWr.wr.opcode == IBV_WR_RDMA_WRITE || - virtualSendWr.wr.opcode == IBV_WR_RDMA_WRITE_WITH_IMM || - virtualSendWr.wr.opcode == IBV_WR_RDMA_READ) { - // Find an available Qp to send - availableQpIdx = qpIdx == -1 ? findAvailableSendQp() : qpIdx; - qpIdx = -1; // If qpIdx is provided, it indicates that one slot has been - // freed for the corresponding qpIdx. After using this slot, - // reset qpIdx to -1. 
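// Illustrative example (hypothetical sizes): with maxMsgSize_ = 4 MiB, a
// 10 MiB IBV_WR_RDMA_WRITE is posted as three physical writes of 4 MiB,
// 4 MiB and 2 MiB, each chunk going to whichever data QP
// findAvailableSendQp() returns, while non-RDMA opcodes are pinned to
// physical QP 0 below.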
- } else if ( - physicalQps_.at(0).physicalSendWrStatus_.size() < maxMsgCntPerQp_) { - availableQpIdx = 0; - } - if (availableQpIdx == -1) { - break; - } - - // Update the physical send work request with virtual one - ibv_send_wr sendWr_{}; - ibv_sge sendSg_{}; - updatePhysicalSendWrFromVirtualSendWr(virtualSendWr, &sendWr_, &sendSg_); - - // Call ibv_post_send to send the message - ibv_send_wr badSendWr_{}; - auto maybeSend = - physicalQps_.at(availableQpIdx).postSend(&sendWr_, &badSendWr_); - if (maybeSend.hasError()) { - return folly::makeUnexpected(maybeSend.error()); - } - - // Enqueue the send information to physicalQps_ - physicalQps_.at(availableQpIdx) - .physicalSendWrStatus_.emplace_back( - sendWr_.wr_id, virtualSendWr.wr.wr_id); - - // Decide if need to deque the front of vSendQ_ - virtualSendWr.offset += sendWr_.sg_list->length; - virtualSendWr.remainingMsgCnt--; - if (virtualSendWr.remainingMsgCnt == 0) { - pendingSendVirtualWrQue_.pop_front(); - } else if ( - virtualSendWr.remainingMsgCnt == 1 && - virtualSendWr.sendExtraNotifyImm) { - // Move front entry from pendingSendVirtualWrQue_ to - // pendingSendNotifyVirtualWrQue_ - pendingSendNotifyVirtualWrQue_.push_back( - std::move(pendingSendVirtualWrQue_.front())); - pendingSendVirtualWrQue_.pop_front(); - } - } - return folly::unit; -} - -inline int IbvVirtualQp::findAvailableSendQp() { - // maxMsgCntPerQp_ with a value of -1 indicates there is no limit. - if (maxMsgCntPerQp_ == -1) { - auto availableQpIdx = nextSendPhysicalQpIdx_; - nextSendPhysicalQpIdx_ = (nextSendPhysicalQpIdx_ + 1) % physicalQps_.size(); - return availableQpIdx; - } - - for (int i = 0; i < physicalQps_.size(); i++) { - if (physicalQps_.at(nextSendPhysicalQpIdx_).physicalSendWrStatus_.size() < - maxMsgCntPerQp_) { - auto availableQpIdx = nextSendPhysicalQpIdx_; - nextSendPhysicalQpIdx_ = - (nextSendPhysicalQpIdx_ + 1) % physicalQps_.size(); - return availableQpIdx; - } - nextSendPhysicalQpIdx_ = (nextSendPhysicalQpIdx_ + 1) % physicalQps_.size(); - } - return -1; -} - -inline folly::Expected IbvVirtualQp::postSendNotifyImm() { - auto virtualSendWr = pendingSendNotifyVirtualWrQue_.front(); - ibv_send_wr sendWr_{}; - ibv_send_wr badSendWr_{}; - ibv_sge sendSg_{}; - sendWr_.next = nullptr; - sendWr_.sg_list = &sendSg_; - sendWr_.num_sge = 0; - sendWr_.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - sendWr_.send_flags = IBV_SEND_SIGNALED; - sendWr_.wr.rdma.remote_addr = virtualSendWr.wr.wr.rdma.remote_addr; - sendWr_.wr.rdma.rkey = virtualSendWr.wr.wr.rdma.rkey; - sendWr_.imm_data = virtualSendWr.wr.imm_data; - sendWr_.wr_id = nextPhysicalWrId_++; - auto maybeSend = notifyQp_.postSend(&sendWr_, &badSendWr_); - if (maybeSend.hasError()) { - return folly::makeUnexpected(maybeSend.error()); - } - notifyQp_.physicalSendWrStatus_.emplace_back( - sendWr_.wr_id, virtualSendWr.wr.wr_id); - virtualSendWr.remainingMsgCnt = 0; - pendingSendNotifyVirtualWrQue_.pop_front(); - return folly::unit; -} - -inline void IbvVirtualQp::updatePhysicalSendWrFromVirtualSendWr( - VirtualSendWr& virtualSendWr, - ibv_send_wr* sendWr, - ibv_sge* sendSg) { - sendWr->wr_id = nextPhysicalWrId_++; - - auto lenToSend = std::min( - int(virtualSendWr.wr.sg_list->length - virtualSendWr.offset), - maxMsgSize_); - sendSg->addr = virtualSendWr.wr.sg_list->addr + virtualSendWr.offset; - sendSg->length = lenToSend; - sendSg->lkey = virtualSendWr.wr.sg_list->lkey; - sendWr->next = nullptr; - sendWr->sg_list = sendSg; - sendWr->num_sge = 1; - - // Set the opcode to the same as virtual wr, except for 
RDMA_WRITE_WITH_IMM, - // we'll handle the notification message separately - switch (virtualSendWr.wr.opcode) { - case IBV_WR_RDMA_WRITE: - case IBV_WR_RDMA_READ: - sendWr->opcode = virtualSendWr.wr.opcode; - sendWr->send_flags = virtualSendWr.wr.send_flags; - sendWr->wr.rdma.remote_addr = - virtualSendWr.wr.wr.rdma.remote_addr + virtualSendWr.offset; - sendWr->wr.rdma.rkey = virtualSendWr.wr.wr.rdma.rkey; - break; - case IBV_WR_RDMA_WRITE_WITH_IMM: - sendWr->opcode = (loadBalancingScheme_ == LoadBalancingScheme::SPRAY) - ? IBV_WR_RDMA_WRITE - : IBV_WR_RDMA_WRITE_WITH_IMM; - sendWr->send_flags = IBV_SEND_SIGNALED; - sendWr->wr.rdma.remote_addr = - virtualSendWr.wr.wr.rdma.remote_addr + virtualSendWr.offset; - sendWr->wr.rdma.rkey = virtualSendWr.wr.wr.rdma.rkey; - break; - case IBV_WR_SEND: - sendWr->opcode = virtualSendWr.wr.opcode; - sendWr->send_flags = virtualSendWr.wr.send_flags; - break; - - default: - break; - } - - if (sendWr->opcode == IBV_WR_RDMA_WRITE_WITH_IMM && - loadBalancingScheme_ == LoadBalancingScheme::DQPLB) { - sendWr->imm_data = - dqplbSeqTracker.getSendImm(virtualSendWr.remainingMsgCnt); - } -} - -inline folly::Expected IbvVirtualQp::postSend( - ibv_send_wr* sendWr, - ibv_send_wr* sendWrBad) { - // Report error if num_sge is more than 1 - if (sendWr->num_sge > 1) { - return folly::makeUnexpected(Error( - EINVAL, "In IbvVirtualQp::postSend, num_sge > 1 is not supported")); - } - - // Report error if opcode is not supported by Ibverbx virtualQp - switch (sendWr->opcode) { - case IBV_WR_SEND_WITH_IMM: - case IBV_WR_ATOMIC_CMP_AND_SWP: - case IBV_WR_ATOMIC_FETCH_AND_ADD: - return folly::makeUnexpected(Error( - EINVAL, - "In IbvVirtualQp::postSend, opcode IBV_WR_SEND_WITH_IMM, IBV_WR_ATOMIC_CMP_AND_SWP, IBV_WR_ATOMIC_FETCH_AND_ADD are not supported")); - - default: - break; - } - - // Calculate the chunk number for the current message and update sendWqe - bool sendExtraNotifyImm = - (sendWr->opcode == IBV_WR_RDMA_WRITE_WITH_IMM && - loadBalancingScheme_ == LoadBalancingScheme::SPRAY); - int expectedMsgCnt = - (sendWr->sg_list->length + maxMsgSize_ - 1) / maxMsgSize_; - if (sendExtraNotifyImm) { - expectedMsgCnt += 1; // After post send all data messages, will post send - // 1 more notification message on QP 0 - } - - // Submit request to virtualCq to enqueue VirtualWc - VirtualCqRequest request = { - .type = RequestType::SEND, - .virtualQpNum = (int)virtualQpNum_, - .expectedMsgCnt = expectedMsgCnt, - .sendWr = sendWr, - .sendExtraNotifyImm = sendExtraNotifyImm}; - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) << "Coordinator should not be nullptr during postSend!"; - coordinator->submitRequestToVirtualCq(std::move(request)); - - // Set up the send work request with the completion queue entry and enqueue - // Note: virtualWcPtr can be nullptr - this is intentional and supported - // The VirtualSendWr constructor will handle deep copying of sendWr and - // sg_list - pendingSendVirtualWrQue_.emplace_back( - *sendWr, expectedMsgCnt, expectedMsgCnt, sendExtraNotifyImm); - - // Map large messages from vSendQ_ to pQps_ - if (mapPendingSendQueToPhysicalQp().hasError()) { - *sendWrBad = *sendWr; - return folly::makeUnexpected(Error(errno)); - } - - return folly::unit; -} - -inline folly::Expected IbvVirtualQp::processRequest( - VirtualQpRequest&& request) { - VirtualQpResponse response; - // If request.physicalQpNum differs from notifyQpNum, locate the corresponding - // physical qpIdx to process this request. 
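// Illustrative example (hypothetical qp_num values): if the virtual QP owns
// data QPs with qp_num {0x3a, 0x3b, 0x3c} and a notify QP with qp_num 0x3d,
// a completion reporting physicalQpNum 0x3b resolves to qpIdx 1 via
// qpNumToIdx_, while 0x3d resolves to qpIdx -1 (the notify QP).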
- auto qpIdx = request.physicalQpNum == notifyQp_.getQpNum() - ? -1 - : qpNumToIdx_.at(request.physicalQpNum); - // If qpIdx is -1, physicalQp is notifyQp; otherwise, physicalQp is the qpIdx - // entry of physicalQps_ - auto& physicalQp = qpIdx == -1 ? notifyQp_ : physicalQps_.at(qpIdx); - - if (request.type == RequestType::RECV) { - if (physicalQp.physicalRecvWrStatus_.empty()) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "In pollCq, after calling submit command to IbvVirtualQp, \ - physicalRecvWrStatus_ at physicalQp {} is empty!", - request.physicalQpNum))); - } - - auto& physicalRecvWrStatus = physicalQp.physicalRecvWrStatus_.front(); - - if (physicalRecvWrStatus.physicalWrId != request.wrId) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "In pollCq, after calling submit command to IbvVirtualQp, \ - physicalRecvWrStatus.physicalWrId({}) != request.wrId({})", - physicalRecvWrStatus.physicalWrId, - request.wrId))); - } - - response.virtualWrId = physicalRecvWrStatus.virtualWrId; - physicalQp.physicalRecvWrStatus_.pop_front(); - if (loadBalancingScheme_ == LoadBalancingScheme::DQPLB) { - if (postRecvNotifyImm(qpIdx).hasError()) { - return folly::makeUnexpected( - Error(errno, fmt::format("postRecvNotifyImm() failed!"))); - } - response.notifyCount = - dqplbSeqTracker.processReceivedImm(request.immData); - response.useDqplb = true; - } else if (qpIdx != -1) { - if (mapPendingRecvQueToPhysicalQp(qpIdx).hasError()) { - return folly::makeUnexpected(Error( - errno, - fmt::format("mapPendingRecvQueToPhysicalQp({}) failed!", qpIdx))); - } - } - } else if (request.type == RequestType::SEND) { - if (physicalQp.physicalSendWrStatus_.empty()) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "In pollCq, after calling submit command to IbvVirtualQp, \ - physicalSendWrStatus_ at physicalQp {} is empty!", - request.physicalQpNum))); - } - - auto physicalSendWrStatus = physicalQp.physicalSendWrStatus_.front(); - - if (physicalSendWrStatus.physicalWrId != request.wrId) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "In pollCq, after calling submit command to IbvVirtualQp, \ - physicalSendWrStatus.physicalWrId({}) != request.wrId({})", - physicalSendWrStatus.physicalWrId, - request.wrId))); - } - - response.virtualWrId = physicalSendWrStatus.virtualWrId; - physicalQp.physicalSendWrStatus_.pop_front(); - if (qpIdx != -1) { - if (mapPendingSendQueToPhysicalQp(qpIdx).hasError()) { - return folly::makeUnexpected(Error( - errno, - fmt::format("mapPendingSendQueToPhysicalQp({}) failed!", qpIdx))); - } - } - } else if (request.type == RequestType::SEND_NOTIFY) { - if (pendingSendNotifyVirtualWrQue_.empty()) { - return folly::makeUnexpected(Error( - EINVAL, - fmt::format( - "Tried to post send notify IMM message for wrId {} when pendingSendNotifyVirtualWrQue_ is empty", - request.wrId))); - } - - if (pendingSendNotifyVirtualWrQue_.front().wr.wr_id == request.wrId) { - if (postSendNotifyImm().hasError()) { - return folly::makeUnexpected( - Error(errno, fmt::format("postSendNotifyImm() failed!"))); - } - } - } - return response; -} - -// Currently, this function is only invoked to receive messages with opcode -// IBV_WR_SEND. Therefore, we restrict its usage to physical QP 0. -// Note: If Dynamic QP Load Balancing (DQPLB) or other load balancing techniques -// are required in the future, this function can be updated to support more -// advanced usage. 
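// Illustrative example (hypothetical limits): with maxMsgCntPerQp_ = 2 and
// two receive WRs already outstanding on physical QP 0,
// findAvailableRecvQp() returns -1 and the virtual receive WR stays queued
// until a completion frees a slot; with maxMsgCntPerQp_ = -1 it always
// returns QP index 0.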
-inline int IbvVirtualQp::findAvailableRecvQp() { - // maxMsgCntPerQp_ with a value of -1 indicates there is no limit. - auto availableQpIdx = -1; - if (maxMsgCntPerQp_ == -1 || - physicalQps_.at(0).physicalRecvWrStatus_.size() < maxMsgCntPerQp_) { - availableQpIdx = 0; - } - - return availableQpIdx; -} - -inline folly::Expected IbvVirtualQp::postRecvNotifyImm( - int qpIdx) { - auto& qp = qpIdx == -1 ? notifyQp_ : physicalQps_.at(qpIdx); - auto virtualRecvWrId = loadBalancingScheme_ == LoadBalancingScheme::SPRAY - ? pendingRecvVirtualWrQue_.front().wr.wr_id - : -1; - ibv_recv_wr recvWr_{}; - ibv_recv_wr badRecvWr_{}; - ibv_sge recvSg_{}; - recvWr_.next = nullptr; - recvWr_.sg_list = &recvSg_; - recvWr_.num_sge = 0; - recvWr_.wr_id = nextPhysicalWrId_++; - auto maybeRecv = qp.postRecv(&recvWr_, &badRecvWr_); - if (maybeRecv.hasError()) { - return folly::makeUnexpected(maybeRecv.error()); - } - qp.physicalRecvWrStatus_.emplace_back(recvWr_.wr_id, virtualRecvWrId); - - if (loadBalancingScheme_ == LoadBalancingScheme::SPRAY) { - pendingRecvVirtualWrQue_.pop_front(); - } - return folly::unit; -} - -inline folly::Expected -IbvVirtualQp::initializeDqplbReceiver() { - ibv_recv_wr recvWr_{}; - ibv_recv_wr badRecvWr_{}; - ibv_sge recvSg_{}; - recvWr_.next = nullptr; - recvWr_.sg_list = &recvSg_; - recvWr_.num_sge = 0; - for (int i = 0; i < maxMsgCntPerQp_; i++) { - for (int j = 0; j < physicalQps_.size(); j++) { - recvWr_.wr_id = nextPhysicalWrId_++; - auto maybeRecv = physicalQps_.at(j).postRecv(&recvWr_, &badRecvWr_); - if (maybeRecv.hasError()) { - return folly::makeUnexpected(maybeRecv.error()); - } - physicalQps_.at(j).physicalRecvWrStatus_.emplace_back(recvWr_.wr_id, -1); - } - } - - dqplbReceiverInitialized_ = true; - return folly::unit; -} - -inline folly::Expected -IbvVirtualQp::mapPendingRecvQueToPhysicalQp(int qpIdx) { - while (!pendingRecvVirtualWrQue_.empty()) { - VirtualRecvWr& virtualRecvWr = pendingRecvVirtualWrQue_.front(); - - if (virtualRecvWr.wr.num_sge == 0) { - auto maybeRecvNotifyImm = postRecvNotifyImm(); - if (maybeRecvNotifyImm.hasError()) { - return folly::makeUnexpected(maybeRecvNotifyImm.error()); - } - continue; - } - - // If num_sge is > 0, then the receive work request is used to receive - // messages with opcode IBV_WR_SEND. In this scenario, we restrict usage to - // physical QP 0 only. The reason behind is that, IBV_WR_SEND requires a - // strict one-to-one correspondence between send and receive WRs. If Dynamic - // QP Load Balancing (DQPLB) is applied, send and receive WRs may be posted - // to different physical QPs within the QP list. This mismatch can result in - // data being delivered to the wrong address, causing data integrity issues. - auto availableQpIdx = qpIdx != 0 ? findAvailableRecvQp() : qpIdx; - qpIdx = -1; // If qpIdx is provided, it indicates that one slot has been - // freed for the corresponding qpIdx. After using this slot, - // reset qpIdx to -1. 
- if (availableQpIdx == -1) { - break; - } - - // Get the front of vRecvQ_ and obtain the receive information - ibv_recv_wr recvWr_{}; - ibv_recv_wr badRecvWr_{}; - ibv_sge recvSg_{}; - int lenToRecv = 0; - if (virtualRecvWr.wr.num_sge == 1) { - lenToRecv = std::min( - int(virtualRecvWr.wr.sg_list->length - virtualRecvWr.offset), - maxMsgSize_); - recvSg_.addr = virtualRecvWr.wr.sg_list->addr + virtualRecvWr.offset; - recvSg_.length = lenToRecv; - recvSg_.lkey = virtualRecvWr.wr.sg_list->lkey; - - recvWr_.sg_list = &recvSg_; - recvWr_.num_sge = 1; - } else { - recvWr_.sg_list = nullptr; - recvWr_.num_sge = 0; - } - recvWr_.wr_id = nextPhysicalWrId_++; - recvWr_.next = nullptr; - - // Call ibv_post_recv to receive the message - auto maybeRecv = - physicalQps_.at(availableQpIdx).postRecv(&recvWr_, &badRecvWr_); - if (maybeRecv.hasError()) { - return folly::makeUnexpected(maybeRecv.error()); - } - - // Enqueue the receive information to physicalQps_ - physicalQps_.at(availableQpIdx) - .physicalRecvWrStatus_.emplace_back( - recvWr_.wr_id, virtualRecvWr.wr.wr_id); - - // Decide if need to deque the front of vRecvQ_ - if (virtualRecvWr.wr.num_sge == 1) { - virtualRecvWr.offset += lenToRecv; - } - virtualRecvWr.remainingMsgCnt--; - if (virtualRecvWr.remainingMsgCnt == 0) { - pendingRecvVirtualWrQue_.pop_front(); - } - } - return folly::unit; -} - -inline folly::Expected IbvVirtualQp::postRecv( - ibv_recv_wr* recvWr, - ibv_recv_wr* recvWrBad) { - // Report error if num_sge is more than 1 - if (recvWr->num_sge > 1) { - return folly::makeUnexpected(Error(EINVAL)); - } - - int expectedMsgCnt = 1; - - if (recvWr->num_sge == 0) { // recvWr->num_sge == 0 mean it's receiving a - // IMM notification message - expectedMsgCnt = 1; - } else if (recvWr->num_sge == 1) { // Calculate the chunk number for the - // current message and update recvWqe if - // num_sge is 1 - expectedMsgCnt = (recvWr->sg_list->length + maxMsgSize_ - 1) / maxMsgSize_; - } - - // Submit request to virtualCq to enqueue VirtualWc - VirtualCqRequest request = { - .type = RequestType::RECV, - .virtualQpNum = (int)virtualQpNum_, - .expectedMsgCnt = expectedMsgCnt, - .recvWr = recvWr}; - auto coordinator = Coordinator::getCoordinator(); - CHECK(coordinator) << "Coordinator should not be nullptr during postRecv!"; - coordinator->submitRequestToVirtualCq(std::move(request)); - - // Set up the recv work request with the completion queue entry and enqueue - pendingRecvVirtualWrQue_.emplace_back( - *recvWr, expectedMsgCnt, expectedMsgCnt); - - if (loadBalancingScheme_ != LoadBalancingScheme::DQPLB) { - if (mapPendingRecvQueToPhysicalQp().hasError()) { - // For non-DQPLB modes: map messages from pendingRecvVirtualWrQue_ to - // physicalQps_. In DQPLB mode, this mapping is unnecessary because all - // receive notify IMM operations are pre-posted to the QPs before postRecv - // is called. 
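// (Concretely, initializeDqplbReceiver() pre-posts maxMsgCntPerQp_ zero-SGE
// receive WRs on each physical data QP, so incoming RDMA_WRITE_WITH_IMM
// chunks always find a matching receive; in DQPLB mode postRecv only records
// the virtual WR and its expected chunk count.)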
- *recvWrBad = *recvWr; - return folly::makeUnexpected(Error(errno)); - } - } else if (dqplbReceiverInitialized_ == false) { - if (initializeDqplbReceiver().hasError()) { - *recvWrBad = *recvWr; - return folly::makeUnexpected(Error(errno)); - } - } - - return folly::unit; -} - -// Coordinator inline functions -inline IbvVirtualCq* Coordinator::getVirtualSendCq( - uint32_t virtualQpNum) const { - auto it = virtualQpNumToVirtualSendCqNum_.find(virtualQpNum); - if (it == virtualQpNumToVirtualSendCqNum_.end()) { - return nullptr; - } - return getVirtualCqById(it->second); -} - -inline IbvVirtualCq* Coordinator::getVirtualRecvCq( - uint32_t virtualQpNum) const { - auto it = virtualQpNumToVirtualRecvCqNum_.find(virtualQpNum); - if (it == virtualQpNumToVirtualRecvCqNum_.end()) { - return nullptr; - } - return getVirtualCqById(it->second); -} - -inline IbvVirtualQp* Coordinator::getVirtualQpByPhysicalQpNum( - int physicalQpNum) const { - auto it = physicalQpNumToVirtualQpNum_.find(physicalQpNum); - if (it == physicalQpNumToVirtualQpNum_.end()) { - return nullptr; - } - return getVirtualQpById(it->second); -} - -inline IbvVirtualQp* Coordinator::getVirtualQpById( - uint32_t virtualQpNum) const { - auto it = virtualQpNumToVirtualQp_.find(virtualQpNum); - if (it == virtualQpNumToVirtualQp_.end()) { - return nullptr; - } - return it->second; -} - -inline IbvVirtualCq* Coordinator::getVirtualCqById( - uint32_t virtualCqNum) const { - auto it = virtualCqNumToVirtualCq_.find(virtualCqNum); - if (it == virtualCqNumToVirtualCq_.end()) { - return nullptr; - } - return it->second; -} - -inline folly::Expected -Coordinator::submitRequestToVirtualQp(VirtualQpRequest&& request) { - auto virtualQp = getVirtualQpByPhysicalQpNum(request.physicalQpNum); - return virtualQp->processRequest(std::move(request)); -} - -inline void Coordinator::submitRequestToVirtualCq(VirtualCqRequest&& request) { - if (request.type == RequestType::SEND) { - auto virtualCq = getVirtualSendCq(request.virtualQpNum); - virtualCq->processRequest(std::move(request)); - } else { - auto virtualCq = getVirtualRecvCq(request.virtualQpNum); - virtualCq->processRequest(std::move(request)); - } -} - -// VirtualSendWr inline constructor -inline VirtualSendWr::VirtualSendWr( - const ibv_send_wr& wr, - int expectedMsgCnt, - int remainingMsgCnt, - bool sendExtraNotifyImm) - : expectedMsgCnt(expectedMsgCnt), - remainingMsgCnt(remainingMsgCnt), - sendExtraNotifyImm(sendExtraNotifyImm) { - // Make an explicit copy of the ibv_send_wr structure - this->wr = wr; - - // Deep copy the scatter-gather list - if (wr.sg_list != nullptr && wr.num_sge > 0) { - sgList.resize(wr.num_sge); - std::copy(wr.sg_list, wr.sg_list + wr.num_sge, sgList.begin()); - // Update the copied work request to point to our own scatter-gather list - this->wr.sg_list = sgList.data(); - } else { - // Handle case where there's no scatter-gather list - this->wr.sg_list = nullptr; - this->wr.num_sge = 0; - } -} - -// VirtualRecvWr inline constructor -inline VirtualRecvWr::VirtualRecvWr( - const ibv_recv_wr& wr, - int expectedMsgCnt, - int remainingMsgCnt) - : expectedMsgCnt(expectedMsgCnt), remainingMsgCnt(remainingMsgCnt) { - // Make an explicit copy of the ibv_recv_wr structure - this->wr = wr; - - // Deep copy the scatter-gather list - if (wr.sg_list != nullptr && wr.num_sge > 0) { - sgList.resize(wr.num_sge); - std::copy(wr.sg_list, wr.sg_list + wr.num_sge, sgList.begin()); - // Update the copied work request to point to our own scatter-gather list - this->wr.sg_list = sgList.data(); - } 
else {
-    // Handle case where there's no scatter-gather list
-    this->wr.sg_list = nullptr;
-    this->wr.num_sge = 0;
-  }
-}
-
-// DqplbSeqTracker inline functions
-inline uint32_t DqplbSeqTracker::getSendImm(int remainingMsgCnt) {
-  uint32_t immData = sendNext_;
-  sendNext_ = (sendNext_ + 1) % kSeqNumMask;
-  if (remainingMsgCnt == 1) {
-    immData |= (1 << kNotifyBit);
-  }
-  return immData;
-}
-
-inline int DqplbSeqTracker::processReceivedImm(uint32_t immData) {
-  int notifyCount = 0;
-  receivedSeqNums_[immData & kSeqNumMask] = immData & (1U << kNotifyBit);
-  auto it = receivedSeqNums_.find(receiveNext_);
-
-  while (it != receivedSeqNums_.end()) {
-    if (it->second) {
-      notifyCount++;
-    }
-    receivedSeqNums_.erase(it);
-    receiveNext_ = (receiveNext_ + 1) % kSeqNumMask;
-    it = receivedSeqNums_.find(receiveNext_);
-  }
-  return notifyCount;
-}
-
 } // namespace ibverbx
diff --git a/comms/ctran/ibverbx/IbverbxSymbols.cc b/comms/ctran/ibverbx/IbverbxSymbols.cc
new file mode 100644
index 00000000..2ab86169
--- /dev/null
+++ b/comms/ctran/ibverbx/IbverbxSymbols.cc
@@ -0,0 +1,479 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+
+#include "comms/ctran/ibverbx/IbverbxSymbols.h"
+
+#include <dlfcn.h>
+#include <fmt/format.h>
+#include <folly/ScopeGuard.h>
+#include <folly/logging/xlog.h>
+
+namespace ibverbx {
+
+IbvSymbols ibvSymbols;
+
+#define IBVERBS_VERSION "IBVERBS_1.1"
+
+#define MLX5DV_VERSION "MLX5_1.8"
+
+#ifdef IBVERBX_BUILD_RDMA_CORE
+// Wrapper functions to handle type conversions between custom and real types
+struct ibv_device** linked_get_device_list(int* num_devices) {
+  return reinterpret_cast<struct ibv_device**>(
+      ibv_get_device_list(num_devices));
+}
+
+void linked_free_device_list(struct ibv_device** list) {
+  ibv_free_device_list(reinterpret_cast<::ibv_device**>(list));
+}
+
+const char* linked_get_device_name(struct ibv_device* device) {
+  return ibv_get_device_name(reinterpret_cast<::ibv_device*>(device));
+}
+
+struct ibv_context* linked_open_device(struct ibv_device* device) {
+  return reinterpret_cast<struct ibv_context*>(
+      ibv_open_device(reinterpret_cast<::ibv_device*>(device)));
+}
+
+int linked_close_device(struct ibv_context* context) {
+  return ibv_close_device(reinterpret_cast<::ibv_context*>(context));
+}
+
+int linked_query_device(
+    struct ibv_context* context,
+    struct ibv_device_attr* device_attr) {
+  return ibv_query_device(
+      reinterpret_cast<::ibv_context*>(context),
+      reinterpret_cast<::ibv_device_attr*>(device_attr));
+}
+
+int linked_query_port(
+    struct ibv_context* context,
+    uint8_t port_num,
+    struct ibv_port_attr* port_attr) {
+  return ibv_query_port(
+      reinterpret_cast<::ibv_context*>(context),
+      port_num,
+      reinterpret_cast<::ibv_port_attr*>(port_attr));
+}
+
+int linked_query_gid(
+    struct ibv_context* context,
+    uint8_t port_num,
+    int index,
+    union ibv_gid* gid) {
+  return ibv_query_gid(
+      reinterpret_cast<::ibv_context*>(context),
+      port_num,
+      index,
+      reinterpret_cast<::ibv_gid*>(gid));
+}
+
+struct ibv_pd* linked_alloc_pd(struct ibv_context* context) {
+  return reinterpret_cast<struct ibv_pd*>(
+      ibv_alloc_pd(reinterpret_cast<::ibv_context*>(context)));
+}
+
+struct ibv_pd* linked_alloc_parent_domain(
+    struct ibv_context* context,
+    struct ibv_parent_domain_init_attr* attr) {
+  return reinterpret_cast<struct ibv_pd*>(ibv_alloc_parent_domain(
+      reinterpret_cast<::ibv_context*>(context),
+      reinterpret_cast<::ibv_parent_domain_init_attr*>(attr)));
+}
+
+int linked_dealloc_pd(struct ibv_pd* pd) {
+  return ibv_dealloc_pd(reinterpret_cast<::ibv_pd*>(pd));
+}
+
+struct ibv_mr*
+linked_reg_mr(struct ibv_pd* pd, void* addr, size_t length, int access) {
+  return reinterpret_cast<struct ibv_mr*>(
+      ibv_reg_mr(reinterpret_cast<::ibv_pd*>(pd), addr, length, access));
+}
+
+int linked_dereg_mr(struct ibv_mr* mr) {
+  return ibv_dereg_mr(reinterpret_cast<::ibv_mr*>(mr));
+}
+
+struct ibv_cq* linked_create_cq(
+    struct ibv_context* context,
+    int cqe,
+    void* cq_context,
+    struct ibv_comp_channel* channel,
+    int comp_vector) {
+  return reinterpret_cast<struct ibv_cq*>(ibv_create_cq(
+      reinterpret_cast<::ibv_context*>(context),
+      cqe,
+      cq_context,
+      reinterpret_cast<::ibv_comp_channel*>(channel),
+      comp_vector));
+}
+
+struct ibv_cq_ex* linked_create_cq_ex(
+    struct ibv_context* context,
+    struct ibv_cq_init_attr_ex* attr) {
+  return reinterpret_cast<struct ibv_cq_ex*>(ibv_create_cq_ex(
+      reinterpret_cast<::ibv_context*>(context),
+      reinterpret_cast<::ibv_cq_init_attr_ex*>(attr)));
+}
+
+int linked_destroy_cq(struct ibv_cq* cq) {
+  return ibv_destroy_cq(reinterpret_cast<::ibv_cq*>(cq));
+}
+
+struct ibv_qp* linked_create_qp(
+    struct ibv_pd* pd,
+    struct ibv_qp_init_attr* qp_init_attr) {
+  return reinterpret_cast<struct ibv_qp*>(ibv_create_qp(
+      reinterpret_cast<::ibv_pd*>(pd),
+      reinterpret_cast<::ibv_qp_init_attr*>(qp_init_attr)));
+}
+
+int linked_modify_qp(
+    struct ibv_qp* qp,
+    struct ibv_qp_attr* attr,
+    int attr_mask) {
+  return ibv_modify_qp(
+      reinterpret_cast<::ibv_qp*>(qp),
+      reinterpret_cast<::ibv_qp_attr*>(attr),
+      attr_mask);
+}
+
+int linked_destroy_qp(struct ibv_qp* qp) {
+  return ibv_destroy_qp(reinterpret_cast<::ibv_qp*>(qp));
+}
+
+const char* linked_event_type_str(enum ibv_event_type event) {
+  return ibv_event_type_str(static_cast<::ibv_event_type>(event));
+}
+
+int linked_get_async_event(
+    struct ibv_context* context,
+    struct ibv_async_event* event) {
+  return ibv_get_async_event(
+      reinterpret_cast<::ibv_context*>(context),
+      reinterpret_cast<::ibv_async_event*>(event));
+}
+
+void linked_ack_async_event(struct ibv_async_event* event) {
+  ibv_ack_async_event(reinterpret_cast<::ibv_async_event*>(event));
+}
+
+int linked_query_qp(
+    struct ibv_qp* qp,
+    struct ibv_qp_attr* attr,
+    int attr_mask,
+    struct ibv_qp_init_attr* init_attr) {
+  return ibv_query_qp(
+      reinterpret_cast<::ibv_qp*>(qp),
+      reinterpret_cast<::ibv_qp_attr*>(attr),
+      attr_mask,
+      reinterpret_cast<::ibv_qp_init_attr*>(init_attr));
+}
+
+struct ibv_mr* linked_reg_mr_iova2(
+    struct ibv_pd* pd,
+    void* addr,
+    size_t length,
+    uint64_t iova,
+    unsigned int access) {
+  return reinterpret_cast<struct ibv_mr*>(ibv_reg_mr_iova2(
+      reinterpret_cast<::ibv_pd*>(pd), addr, length, iova, access));
+}
+
+struct ibv_mr* linked_reg_dmabuf_mr(
+    struct ibv_pd* pd,
+    uint64_t offset,
+    size_t length,
+    uint64_t iova,
+    int fd,
+    int access) {
+  return reinterpret_cast<struct ibv_mr*>(ibv_reg_dmabuf_mr(
+      reinterpret_cast<::ibv_pd*>(pd), offset, length, iova, fd, access));
+}
+
+int linked_query_ece(struct ibv_qp* qp, struct ibv_ece* ece) {
+  return ibv_query_ece(
+      reinterpret_cast<::ibv_qp*>(qp), reinterpret_cast<::ibv_ece*>(ece));
+}
+
+int linked_set_ece(struct ibv_qp* qp, struct ibv_ece* ece) {
+  return ibv_set_ece(
+      reinterpret_cast<::ibv_qp*>(qp), reinterpret_cast<::ibv_ece*>(ece));
+}
+
+enum ibv_fork_status linked_is_fork_initialized() {
+  return static_cast<enum ibv_fork_status>(ibv_is_fork_initialized());
+}
+
+struct ibv_comp_channel* linked_create_comp_channel(
+    struct ibv_context* context) {
+  return reinterpret_cast<struct ibv_comp_channel*>(
+      ibv_create_comp_channel(reinterpret_cast<::ibv_context*>(context)));
+}
+
+int linked_destroy_comp_channel(struct ibv_comp_channel* channel) {
+  return ibv_destroy_comp_channel(
+      reinterpret_cast<::ibv_comp_channel*>(channel));
+}
+
+int linked_req_notify_cq(struct ibv_cq* cq, int
solicited_only) { + return ibv_req_notify_cq(reinterpret_cast<::ibv_cq*>(cq), solicited_only); +} + +int linked_get_cq_event( + struct ibv_comp_channel* channel, + struct ibv_cq** cq, + void** cq_context) { + return ibv_get_cq_event( + reinterpret_cast<::ibv_comp_channel*>(channel), + reinterpret_cast<::ibv_cq**>(cq), + cq_context); +} + +void linked_ack_cq_events(struct ibv_cq* cq, unsigned int nevents) { + ibv_ack_cq_events(reinterpret_cast<::ibv_cq*>(cq), nevents); +} + +bool linked_mlx5dv_is_supported(struct ibv_device* device) { + return mlx5dv_is_supported(reinterpret_cast<::ibv_device*>(device)); +} + +int linked_mlx5dv_init_obj(mlx5dv_obj* obj, uint64_t obj_type) { + return mlx5dv_init_obj(reinterpret_cast<::mlx5dv_obj*>(obj), obj_type); +} + +int linked_mlx5dv_get_data_direct_sysfs_path( + struct ibv_context* context, + char* buf, + size_t buf_len) { + return mlx5dv_get_data_direct_sysfs_path( + reinterpret_cast<::ibv_context*>(context), buf, buf_len); +} + +struct ibv_mr* linked_mlx5dv_reg_dmabuf_mr( + struct ibv_pd* pd, + uint64_t offset, + size_t length, + uint64_t iova, + int fd, + int access, + int mlx5_access) { + return reinterpret_cast(mlx5dv_reg_dmabuf_mr( + reinterpret_cast<::ibv_pd*>(pd), + offset, + length, + iova, + fd, + access, + mlx5_access)); +} +#endif + +int buildIbvSymbols(IbvSymbols& symbols, const std::string& ibv_path) { +#ifdef IBVERBX_BUILD_RDMA_CORE + // Direct linking mode - use wrapper functions to handle type conversions + symbols.ibv_internal_get_device_list = &linked_get_device_list; + symbols.ibv_internal_free_device_list = &linked_free_device_list; + symbols.ibv_internal_get_device_name = &linked_get_device_name; + symbols.ibv_internal_open_device = &linked_open_device; + symbols.ibv_internal_close_device = &linked_close_device; + symbols.ibv_internal_get_async_event = &linked_get_async_event; + symbols.ibv_internal_ack_async_event = &linked_ack_async_event; + symbols.ibv_internal_query_device = &linked_query_device; + symbols.ibv_internal_query_port = &linked_query_port; + symbols.ibv_internal_query_gid = &linked_query_gid; + symbols.ibv_internal_query_qp = &linked_query_qp; + symbols.ibv_internal_alloc_pd = &linked_alloc_pd; + symbols.ibv_internal_alloc_parent_domain = &linked_alloc_parent_domain; + symbols.ibv_internal_dealloc_pd = &linked_dealloc_pd; + symbols.ibv_internal_reg_mr = &linked_reg_mr; + + symbols.ibv_internal_reg_mr_iova2 = &linked_reg_mr_iova2; + symbols.ibv_internal_reg_dmabuf_mr = &linked_reg_dmabuf_mr; + symbols.ibv_internal_query_ece = &linked_query_ece; + symbols.ibv_internal_set_ece = &linked_set_ece; + symbols.ibv_internal_is_fork_initialized = &linked_is_fork_initialized; + + symbols.ibv_internal_dereg_mr = &linked_dereg_mr; + symbols.ibv_internal_create_cq = &linked_create_cq; + symbols.ibv_internal_create_cq_ex = &linked_create_cq_ex; + symbols.ibv_internal_destroy_cq = &linked_destroy_cq; + symbols.ibv_internal_create_comp_channel = &linked_create_comp_channel; + symbols.ibv_internal_destroy_comp_channel = &linked_destroy_comp_channel; + symbols.ibv_internal_get_cq_event = &linked_get_cq_event; + symbols.ibv_internal_ack_cq_events = &linked_ack_cq_events; + symbols.ibv_internal_create_qp = &linked_create_qp; + symbols.ibv_internal_modify_qp = &linked_modify_qp; + symbols.ibv_internal_destroy_qp = &linked_destroy_qp; + symbols.ibv_internal_fork_init = &ibv_fork_init; + symbols.ibv_internal_event_type_str = &linked_event_type_str; + + // mlx5dv symbols + symbols.mlx5dv_internal_is_supported = &linked_mlx5dv_is_supported; 
+ symbols.mlx5dv_internal_init_obj = &linked_mlx5dv_init_obj; + symbols.mlx5dv_internal_get_data_direct_sysfs_path = + &linked_mlx5dv_get_data_direct_sysfs_path; + symbols.mlx5dv_internal_reg_dmabuf_mr = &linked_mlx5dv_reg_dmabuf_mr; + return 0; +#else + // Dynamic loading mode - use dlopen/dlsym + static void* ibvhandle = nullptr; + static void* mlx5dvhandle = nullptr; + void* tmp; + void** cast; + + // Use folly::ScopedGuard to ensure resources are cleaned up upon failure + auto guard = folly::makeGuard([&]() { + if (ibvhandle != nullptr) { + dlclose(ibvhandle); + } + if (mlx5dvhandle != nullptr) { + dlclose(mlx5dvhandle); + } + symbols = {}; // Reset all function pointers to nullptr + }); + + if (!ibv_path.empty()) { + ibvhandle = dlopen(ibv_path.c_str(), RTLD_NOW); + } + if (!ibvhandle) { + ibvhandle = dlopen("libibverbs.so.1", RTLD_NOW); + if (!ibvhandle) { + XLOG(ERR) << "Failed to open libibverbs.so.1"; + return 1; + } + } + + // Load mlx5dv symbols if available, do not abort if failed + mlx5dvhandle = dlopen("libmlx5.so", RTLD_NOW); + if (!mlx5dvhandle) { + mlx5dvhandle = dlopen("libmlx5.so.1", RTLD_NOW); + if (!mlx5dvhandle) { + XLOG(WARN) + << "Failed to open libmlx5.so[.1]. Advance features like CX-8 Direct-NIC will be disabled."; + } + } + +#define LOAD_SYM(handle, symbol, funcptr, version) \ + { \ + cast = (void**)&funcptr; \ + tmp = dlvsym(handle, symbol, version); \ + if (tmp == nullptr) { \ + XLOG(ERR) << fmt::format( \ + "dlvsym failed on {} - {} version {}", symbol, dlerror(), version); \ + return 1; \ + } \ + *cast = tmp; \ + } + +#define LOAD_SYM_WARN_ONLY(handle, symbol, funcptr, version) \ + { \ + cast = (void**)&funcptr; \ + tmp = dlvsym(handle, symbol, version); \ + if (tmp == nullptr) { \ + XLOG(WARN) << fmt::format( \ + "dlvsym failed on {} - {} version {}, set null", \ + symbol, \ + dlerror(), \ + version); \ + } \ + *cast = tmp; \ + } + +#define LOAD_IBVERBS_SYM(symbol, funcptr) \ + LOAD_SYM(ibvhandle, symbol, funcptr, IBVERBS_VERSION) + +#define LOAD_IBVERBS_SYM_VERSION(symbol, funcptr, version) \ + LOAD_SYM_WARN_ONLY(ibvhandle, symbol, funcptr, version) + +#define LOAD_IBVERBS_SYM_WARN_ONLY(symbol, funcptr) \ + LOAD_SYM_WARN_ONLY(ibvhandle, symbol, funcptr, IBVERBS_VERSION) + +// mlx5 +#define LOAD_MLX5DV_SYM(symbol, funcptr) \ + if (mlx5dvhandle != nullptr) { \ + LOAD_SYM_WARN_ONLY(mlx5dvhandle, symbol, funcptr, MLX5DV_VERSION) \ + } + +#define LOAD_MLX5DV_SYM_VERSION(symbol, funcptr, version) \ + if (mlx5dvhandle != nullptr) { \ + LOAD_SYM_WARN_ONLY(mlx5dvhandle, symbol, funcptr, version) \ + } + + LOAD_IBVERBS_SYM("ibv_get_device_list", symbols.ibv_internal_get_device_list); + LOAD_IBVERBS_SYM( + "ibv_free_device_list", symbols.ibv_internal_free_device_list); + LOAD_IBVERBS_SYM("ibv_get_device_name", symbols.ibv_internal_get_device_name); + LOAD_IBVERBS_SYM("ibv_open_device", symbols.ibv_internal_open_device); + LOAD_IBVERBS_SYM("ibv_close_device", symbols.ibv_internal_close_device); + LOAD_IBVERBS_SYM("ibv_get_async_event", symbols.ibv_internal_get_async_event); + LOAD_IBVERBS_SYM("ibv_ack_async_event", symbols.ibv_internal_ack_async_event); + LOAD_IBVERBS_SYM("ibv_query_device", symbols.ibv_internal_query_device); + LOAD_IBVERBS_SYM("ibv_query_port", symbols.ibv_internal_query_port); + LOAD_IBVERBS_SYM("ibv_query_gid", symbols.ibv_internal_query_gid); + LOAD_IBVERBS_SYM("ibv_query_qp", symbols.ibv_internal_query_qp); + LOAD_IBVERBS_SYM("ibv_alloc_pd", symbols.ibv_internal_alloc_pd); + LOAD_IBVERBS_SYM_WARN_ONLY( + "ibv_alloc_parent_domain", 
symbols.ibv_internal_alloc_parent_domain); + LOAD_IBVERBS_SYM("ibv_dealloc_pd", symbols.ibv_internal_dealloc_pd); + LOAD_IBVERBS_SYM("ibv_reg_mr", symbols.ibv_internal_reg_mr); + // Cherry-pick the ibv_reg_mr_iova2 API from IBVERBS 1.8 + LOAD_IBVERBS_SYM_VERSION( + "ibv_reg_mr_iova2", symbols.ibv_internal_reg_mr_iova2, "IBVERBS_1.8"); + // Cherry-pick the ibv_reg_dmabuf_mr API from IBVERBS 1.12 + LOAD_IBVERBS_SYM_VERSION( + "ibv_reg_dmabuf_mr", symbols.ibv_internal_reg_dmabuf_mr, "IBVERBS_1.12"); + LOAD_IBVERBS_SYM("ibv_dereg_mr", symbols.ibv_internal_dereg_mr); + LOAD_IBVERBS_SYM("ibv_create_cq", symbols.ibv_internal_create_cq); + LOAD_IBVERBS_SYM("ibv_destroy_cq", symbols.ibv_internal_destroy_cq); + LOAD_IBVERBS_SYM("ibv_create_qp", symbols.ibv_internal_create_qp); + LOAD_IBVERBS_SYM("ibv_modify_qp", symbols.ibv_internal_modify_qp); + LOAD_IBVERBS_SYM("ibv_destroy_qp", symbols.ibv_internal_destroy_qp); + LOAD_IBVERBS_SYM("ibv_fork_init", symbols.ibv_internal_fork_init); + LOAD_IBVERBS_SYM("ibv_event_type_str", symbols.ibv_internal_event_type_str); + + LOAD_IBVERBS_SYM_VERSION( + "ibv_create_comp_channel", + symbols.ibv_internal_create_comp_channel, + "IBVERBS_1.0"); + LOAD_IBVERBS_SYM_VERSION( + "ibv_destroy_comp_channel", + symbols.ibv_internal_destroy_comp_channel, + "IBVERBS_1.0"); + LOAD_IBVERBS_SYM_VERSION( + "ibv_get_cq_event", symbols.ibv_internal_get_cq_event, "IBVERBS_1.0"); + LOAD_IBVERBS_SYM_VERSION( + "ibv_ack_cq_events", symbols.ibv_internal_ack_cq_events, "IBVERBS_1.0"); + LOAD_IBVERBS_SYM_VERSION( + "ibv_query_ece", symbols.ibv_internal_query_ece, "IBVERBS_1.10"); + LOAD_IBVERBS_SYM_VERSION( + "ibv_set_ece", symbols.ibv_internal_set_ece, "IBVERBS_1.10"); + LOAD_IBVERBS_SYM_VERSION( + "ibv_is_fork_initialized", + symbols.ibv_internal_is_fork_initialized, + "IBVERBS_1.13"); + + LOAD_MLX5DV_SYM("mlx5dv_is_supported", symbols.mlx5dv_internal_is_supported); + // Cherry-pick the mlx5dv_get_data_direct_sysfs_path API from MLX5 1.2 + LOAD_MLX5DV_SYM_VERSION( + "mlx5dv_init_obj", symbols.mlx5dv_internal_init_obj, "MLX5_1.2"); + // Cherry-pick the mlx5dv_get_data_direct_sysfs_path API from MLX5 1.25 + LOAD_MLX5DV_SYM_VERSION( + "mlx5dv_get_data_direct_sysfs_path", + symbols.mlx5dv_internal_get_data_direct_sysfs_path, + "MLX5_1.25"); + // Cherry-pick the ibv_reg_dmabuf_mr API from MLX5 1.25 + LOAD_MLX5DV_SYM_VERSION( + "mlx5dv_reg_dmabuf_mr", + symbols.mlx5dv_internal_reg_dmabuf_mr, + "MLX5_1.25"); + + // all symbols were loaded successfully, dismiss guard + guard.dismiss(); + return 0; +#endif +} + +} // namespace ibverbx diff --git a/comms/ctran/ibverbx/IbverbxSymbols.h b/comms/ctran/ibverbx/IbverbxSymbols.h new file mode 100644 index 00000000..557319da --- /dev/null +++ b/comms/ctran/ibverbx/IbverbxSymbols.h @@ -0,0 +1,126 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
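// The function pointers declared below are resolved in one of two ways by
// buildIbvSymbols(): when IBVERBX_BUILD_RDMA_CORE is defined they are bound
// to the directly linked rdma-core entry points (mostly via the thin
// linked_* wrappers in IbverbxSymbols.cc), otherwise they are looked up at
// runtime with dlopen()/dlvsym() against an explicit symbol version, e.g.
//   dlvsym(ibvhandle, "ibv_reg_mr_iova2", "IBVERBS_1.8")
// and optional symbols are simply left as nullptr when the lookup fails.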
+ +#pragma once + +#include "comms/ctran/ibverbx/Ibvcore.h" + +#ifdef IBVERBX_BUILD_RDMA_CORE +#include +#include +#endif + +namespace ibverbx { + +struct IbvSymbols { + int (*ibv_internal_fork_init)(void) = nullptr; + struct ibv_device** (*ibv_internal_get_device_list)(int* num_devices) = + nullptr; + void (*ibv_internal_free_device_list)(struct ibv_device** list) = nullptr; + const char* (*ibv_internal_get_device_name)(struct ibv_device* device) = + nullptr; + struct ibv_context* (*ibv_internal_open_device)(struct ibv_device* device) = + nullptr; + int (*ibv_internal_close_device)(struct ibv_context* context) = nullptr; + int (*ibv_internal_get_async_event)( + struct ibv_context* context, + struct ibv_async_event* event) = nullptr; + void (*ibv_internal_ack_async_event)(struct ibv_async_event* event) = nullptr; + int (*ibv_internal_query_device)( + struct ibv_context* context, + struct ibv_device_attr* device_attr) = nullptr; + int (*ibv_internal_query_port)( + struct ibv_context* context, + uint8_t port_num, + struct ibv_port_attr* port_attr) = nullptr; + int (*ibv_internal_query_gid)( + struct ibv_context* context, + uint8_t port_num, + int index, + union ibv_gid* gid) = nullptr; + int (*ibv_internal_query_qp)( + struct ibv_qp* qp, + struct ibv_qp_attr* attr, + int attr_mask, + struct ibv_qp_init_attr* init_attr) = nullptr; + struct ibv_pd* (*ibv_internal_alloc_pd)(struct ibv_context* context) = + nullptr; + struct ibv_pd* (*ibv_internal_alloc_parent_domain)( + struct ibv_context* context, + struct ibv_parent_domain_init_attr* attr) = nullptr; + int (*ibv_internal_dealloc_pd)(struct ibv_pd* pd) = nullptr; + struct ibv_mr* (*ibv_internal_reg_mr)( + struct ibv_pd* pd, + void* addr, + size_t length, + int access) = nullptr; + struct ibv_mr* (*ibv_internal_reg_mr_iova2)( + struct ibv_pd* pd, + void* addr, + size_t length, + uint64_t iova, + unsigned int access) = nullptr; + struct ibv_mr* (*ibv_internal_reg_dmabuf_mr)( + struct ibv_pd* pd, + uint64_t offset, + size_t length, + uint64_t iova, + int fd, + int access) = nullptr; + int (*ibv_internal_dereg_mr)(struct ibv_mr* mr) = nullptr; + struct ibv_cq* (*ibv_internal_create_cq)( + struct ibv_context* context, + int cqe, + void* cq_context, + struct ibv_comp_channel* channel, + int comp_vector) = nullptr; + struct ibv_cq_ex* (*ibv_internal_create_cq_ex)( + struct ibv_context* context, + struct ibv_cq_init_attr_ex* attr) = nullptr; + int (*ibv_internal_destroy_cq)(struct ibv_cq* cq) = nullptr; + struct ibv_comp_channel* (*ibv_internal_create_comp_channel)( + struct ibv_context* context) = nullptr; + int (*ibv_internal_destroy_comp_channel)(struct ibv_comp_channel* channel) = + nullptr; + int (*ibv_internal_get_cq_event)( + struct ibv_comp_channel* channel, + struct ibv_cq** cq, + void** cq_context) = nullptr; + void (*ibv_internal_ack_cq_events)(struct ibv_cq* cq, unsigned int nevents) = + nullptr; + struct ibv_qp* (*ibv_internal_create_qp)( + struct ibv_pd* pd, + struct ibv_qp_init_attr* qp_init_attr) = nullptr; + int (*ibv_internal_modify_qp)( + struct ibv_qp* qp, + struct ibv_qp_attr* attr, + int attr_mask) = nullptr; + int (*ibv_internal_destroy_qp)(struct ibv_qp* qp) = nullptr; + const char* (*ibv_internal_event_type_str)(enum ibv_event_type event) = + nullptr; + int (*ibv_internal_query_ece)(struct ibv_qp* qp, struct ibv_ece* ece) = + nullptr; + int (*ibv_internal_set_ece)(struct ibv_qp* qp, struct ibv_ece* ece) = nullptr; + enum ibv_fork_status (*ibv_internal_is_fork_initialized)() = nullptr; + + /* mlx5dv functions */ + int 
(*mlx5dv_internal_init_obj)(struct mlx5dv_obj* obj, uint64_t obj_type) = + nullptr; + bool (*mlx5dv_internal_is_supported)(struct ibv_device* device) = nullptr; + int (*mlx5dv_internal_get_data_direct_sysfs_path)( + struct ibv_context* context, + char* buf, + size_t buf_len) = nullptr; + /* DMA-BUF support */ + struct ibv_mr* (*mlx5dv_internal_reg_dmabuf_mr)( + struct ibv_pd* pd, + uint64_t offset, + size_t length, + uint64_t iova, + int fd, + int access, + int mlx5_access) = nullptr; +}; + +int buildIbvSymbols(IbvSymbols& ibvSymbols, const std::string& ibv_path = ""); + +} // namespace ibverbx
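For orientation, a minimal usage sketch of the new symbol table follows. It assumes only what the header above declares (buildIbvSymbols() returning 0 on success, the default empty ibv_path, and the ibv_internal_* pointers); the existing code goes through the pre-populated global ibvSymbols and the wrap_ibv_* helpers rather than a local table, so treat this purely as an illustration.

#include <cstdio>

#include "comms/ctran/ibverbx/IbverbxSymbols.h"

int main() {
  ibverbx::IbvSymbols syms;
  if (ibverbx::buildIbvSymbols(syms) != 0) { // empty ibv_path: default library search
    std::fprintf(stderr, "failed to resolve libibverbs symbols\n");
    return 1;
  }

  int numDevices = 0;
  auto* devices = syms.ibv_internal_get_device_list(&numDevices);
  if (devices == nullptr) {
    return 0; // no devices found; nothing to enumerate
  }
  for (int i = 0; i < numDevices; ++i) {
    std::printf(
        "device %d: %s\n", i, syms.ibv_internal_get_device_name(devices[i]));
  }
  syms.ibv_internal_free_device_list(devices);
  return 0;
}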