From 1df4f21321b2234b29e2662b9a27f548ff71c037 Mon Sep 17 00:00:00 2001 From: tharittk Date: Sun, 17 Aug 2025 04:07:34 -0500 Subject: [PATCH 01/17] Buffered Send/Recv --- pylops_mpi/DistributedArray.py | 61 +++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 979882c0..18470592 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -3,6 +3,7 @@ from typing import Any, List, Optional, Tuple, Union, NewType import numpy as np +import os from mpi4py import MPI from pylops.utils import DTypeLike, NDArray from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils @@ -21,6 +22,10 @@ NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) +if int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)): + is_cuda_aware_mpi = True +else: + is_cuda_aware_mpi = False class Partition(Enum): r"""Enum class @@ -529,34 +534,52 @@ def _allgather_subcomm(self, send_buf, recv_buf=None): return self.sub_comm.allgather(send_buf) self.sub_comm.Allgather(send_buf, recv_buf) - def _send(self, send_buf, dest, count=None, tag=None): - """ Send operation + def _send(self, send_buf, dest, count=None, tag=0): + """Send operation """ if deps.nccl_enabled and self.base_comm_nccl: if count is None: - # assuming sending the whole array count = send_buf.size nccl_send(self.base_comm_nccl, send_buf, dest, count) else: - self.base_comm.send(send_buf, dest, tag) - - def _recv(self, recv_buf=None, source=0, count=None, tag=None): - """ Receive operation - """ - # NCCL must be called with recv_buf. 
Size cannot be inferred from - # other arguments and thus cannot be dynamically allocated - if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None: - if recv_buf is not None: + if is_cuda_aware_mpi or self.engine == "numpy": + # Determine MPI type based on array dtype + mpi_type = MPI._typedict[send_buf.dtype.char] if count is None: - # assuming data will take a space of the whole buffer - count = recv_buf.size - nccl_recv(self.base_comm_nccl, recv_buf, source, count) - return recv_buf + count = send_buf.size + self.base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) else: - raise ValueError("Using recv with NCCL must also supply receiver buffer ") + # Uses CuPy without CUDA-aware MPI + self.base_comm.send(send_buf, dest, tag) + + + def _recv(self, recv_buf=None, source=0, count=None, tag=0): + """Receive operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if recv_buf is None: + raise ValueError("recv_buf must be supplied when using NCCL") + if count is None: + count = recv_buf.size + nccl_recv(self.base_comm_nccl, recv_buf, source, count) + return recv_buf else: - # MPI allows a receiver buffer to be optional and receives as a Python Object - return self.base_comm.recv(source=source, tag=tag) + # NumPy + MPI will benefit from buffered communication regardless of MPI installation + if is_cuda_aware_mpi or self.engine == "numpy": + ncp = get_module(self.engine) + if recv_buf is None: + if count is None: + raise ValueError("Must provide either recv_buf or count for MPI receive") + # Default to int32 works currently because add_ghost_cells() is called + # with recv_buf and is not affected by this branch. 
The int32 is for when + # dimension or shape-related integers are send/recv + recv_buf = ncp.zeros(count, dtype=ncp.int32) + mpi_type = MPI._typedict[recv_buf.dtype.char] + self.base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + recv_buf = self.base_comm.recv(source=source, tag=tag) + return recv_buf def _nccl_local_shapes(self, masked: bool): """Get the the list of shapes of every GPU in the communicator From 647ce658a149a48f12d8c7f938056a05cba6414e Mon Sep 17 00:00:00 2001 From: tharittk Date: Sun, 17 Aug 2025 05:50:26 -0500 Subject: [PATCH 02/17] Buffered Allreduce --- pylops_mpi/DistributedArray.py | 53 +++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 18470592..dd9fd508 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -483,11 +483,19 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) else: - if recv_buf is None: - return self.base_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf + if is_cuda_aware_mpi or self.engine == "numpy": + ncp = get_module(self.engine) + # mpi_type = MPI._typedict[send_buf.dtype.char] + recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) + self.base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return self.base_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): """Allreduce operation with subcommunicator @@ -495,11 +503,19 @@ def 
_allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) else: - if recv_buf is None: - return self.sub_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf + if is_cuda_aware_mpi or self.engine == "numpy": + ncp = get_module(self.engine) + # mpi_type = MPI._typedict[send_buf.dtype.char] + recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) + self.sub_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return self.sub_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.sub_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf def _allgather(self, send_buf, recv_buf=None): """Allgather operation @@ -717,26 +733,29 @@ def _compute_vector_norm(self, local_array: NDArray, recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64)) elif ord == ncp.inf: # Calculate max followed by max reduction - # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly + # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None: + if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: + # CuPy + non-CUDA-aware MPI: This will call non-buffered communication + # which return a list of object - must be copied back to a GPU memory. 
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) - recv_buf = ncp.squeeze(recv_buf, axis=axis) + if self.base_comm_nccl: + recv_buf = ncp.squeeze(recv_buf, axis=axis) elif ord == -ncp.inf: # Calculate min followed by min reduction - # TODO (tharitt): see the comment above in infinity norm + # See the comment above in +infinity norm send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None: + if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN) - recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) - + if self.base_comm_nccl: + recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis)) recv_buf = ncp.power(recv_buf, 1.0 / ord) From 31068f9b65cb3429483646ec993ba151b7a6cb91 Mon Sep 17 00:00:00 2001 From: tharittk Date: Thu, 28 Aug 2025 08:21:22 -0500 Subject: [PATCH 03/17] minor clean up --- pylops_mpi/DistributedArray.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index dd9fd508..9d99fe39 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -485,7 +485,6 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): else: if is_cuda_aware_mpi or self.engine == "numpy": ncp = get_module(self.engine) - # mpi_type = MPI._typedict[send_buf.dtype.char] recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) self.base_comm.Allreduce(send_buf, recv_buf, op) return recv_buf 
@@ -505,7 +504,6 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): else: if is_cuda_aware_mpi or self.engine == "numpy": ncp = get_module(self.engine) - # mpi_type = MPI._typedict[send_buf.dtype.char] recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) self.sub_comm.Allreduce(send_buf, recv_buf, op) return recv_buf @@ -743,6 +741,9 @@ def _compute_vector_norm(self, local_array: NDArray, recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) + # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL + # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. + # There may be a way to unify it - may be something to do with how we allocate the recv_buf. if self.base_comm_nccl: recv_buf = ncp.squeeze(recv_buf, axis=axis) elif ord == -ncp.inf: From ca558fd70c568e21d09ff36673435a0de5b85ee2 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 7 Sep 2025 20:47:35 +0000 Subject: [PATCH 04/17] feat: WIP DistributedMix A new DistributedMix class is created with the aim of simplifying and unifying all comm. calls in both DistributedArray and operators (further hiding away all implementation details). 
--- pylops_mpi/Distributed.py | 45 ++++++++++++++++++++ pylops_mpi/DistributedArray.py | 66 ++++------------------------- pylops_mpi/basicoperators/VStack.py | 16 ++++--- pylops_mpi/utils/deps.py | 4 ++ 4 files changed, 68 insertions(+), 63 deletions(-) create mode 100644 pylops_mpi/Distributed.py diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py new file mode 100644 index 00000000..dccaf6a6 --- /dev/null +++ b/pylops_mpi/Distributed.py @@ -0,0 +1,45 @@ +from typing import Any, NewType + +from mpi4py import MPI +from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils +from pylops_mpi.utils._mpi import mpi_allreduce +from pylops_mpi.utils import deps + +cupy_message = pylops_deps.cupy_import("the DistributedArray module") +nccl_message = deps.nccl_import("the DistributedArray module") + +if nccl_message is None and cupy_message is None: + from pylops_mpi.utils._nccl import ( + nccl_allgather, nccl_allreduce, + nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, + _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv + ) + + +class DistributedMixIn: + r"""Distributed Mixin class + + This class implements all methods associated with communication primitives + from MPI and NCCL. It is mostly charged to identifying which commuicator + to use and whether the buffered or object MPI primitives should be used + (the former in the case of NumPy arrays or CuPy arrays when a CUDA-Aware + MPI installation is available, the latter with CuPy arrays when a CUDA-Aware + MPI installation is not available). 
+ """ + def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + """Allreduce operation + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) + else: + return mpi_allreduce(self.base_comm, send_buf, + recv_buf, self.engine, op) + + def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + """Allreduce operation with subcommunicator + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) + else: + return mpi_allreduce(self.sub_comm, send_buf, + recv_buf, self.engine, op) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 9d99fe39..6fd3ee95 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -3,12 +3,13 @@ from typing import Any, List, Optional, Tuple, Union, NewType import numpy as np -import os from mpi4py import MPI +from pylops_mpi.Distributed import DistributedMixIn from pylops.utils import DTypeLike, NDArray from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops.utils._internal import _value_or_sized_to_tuple from pylops.utils.backend import get_array_module, get_module, get_module_name +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -22,10 +23,6 @@ NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) -if int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)): - is_cuda_aware_mpi = True -else: - is_cuda_aware_mpi = False class Partition(Enum): r"""Enum class @@ -104,7 +101,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = return sub_comm -class DistributedArray: +class DistributedArray(DistributedMixIn): r"""Distributed Numpy Arrays Multidimensional NumPy-like distributed arrays. 
@@ -477,44 +474,6 @@ def _check_mask(self, dist_array): if not np.array_equal(self.mask, dist_array.mask): raise ValueError("Mask of both the arrays must be same") - def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) - else: - if is_cuda_aware_mpi or self.engine == "numpy": - ncp = get_module(self.engine) - recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - else: - # CuPy with non-CUDA-aware MPI - if recv_buf is None: - return self.base_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - - def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation with subcommunicator - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) - else: - if is_cuda_aware_mpi or self.engine == "numpy": - ncp = get_module(self.engine) - recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - else: - # CuPy with non-CUDA-aware MPI - if recv_buf is None: - return self.sub_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - def _allgather(self, send_buf, recv_buf=None): """Allgather operation """ @@ -556,16 +515,9 @@ def _send(self, send_buf, dest, count=None, tag=0): count = send_buf.size nccl_send(self.base_comm_nccl, send_buf, dest, count) else: - if is_cuda_aware_mpi or self.engine == "numpy": - # Determine MPI type based on array dtype - mpi_type = MPI._typedict[send_buf.dtype.char] - if count is None: - count = send_buf.size - self.base_comm.Send([send_buf, count, 
mpi_type], dest=dest, tag=tag) - else: - # Uses CuPy without CUDA-aware MPI - self.base_comm.send(send_buf, dest, tag) - + mpi_send(self.base_comm, + send_buf, dest, count, tag=tag, + engine=self.engine) def _recv(self, recv_buf=None, source=0, count=None, tag=0): """Receive operation @@ -579,7 +531,7 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0): return recv_buf else: # NumPy + MPI will benefit from buffered communication regardless of MPI installation - if is_cuda_aware_mpi or self.engine == "numpy": + if deps.cuda_aware_mpi_enabled or self.engine == "numpy": ncp = get_module(self.engine) if recv_buf is None: if count is None: @@ -734,7 +686,7 @@ def _compute_vector_norm(self, local_array: NDArray, # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: + if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. 
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) @@ -750,7 +702,7 @@ def _compute_vector_norm(self, local_array: NDArray, # Calculate min followed by min reduction # See the comment above in +infinity norm send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: + if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 58581565..f6d5b198 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -15,6 +15,7 @@ Partition, StackedDistributedArray ) +from pylops_mpi.Distributed import DistributedMixIn from pylops_mpi.utils.decorators import reshaped from pylops_mpi.utils import deps @@ -25,7 +26,7 @@ from pylops_mpi.utils._nccl import nccl_allreduce -class MPIVStack(MPILinearOperator): +class MPIVStack(DistributedMixIn, MPILinearOperator): r"""MPI VStack Operator Create a vertical stack of a set of linear operators using MPI. 
Each rank must @@ -141,16 +142,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, + # TODO: consider adding base_comm, base_comm_nccl, engine to the + # input parameters of _allreduce instead of relying on self + self.base_comm, self.base_comm_nccl, self.engine = \ + x.base_comm, x.base_comm_nccl, x.engine + y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, + partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - if deps.nccl_enabled and x.base_comm_nccl: - y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM) - else: - y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) + y[:] = self._allreduce(y1, op=MPI.SUM) return y diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py index 9d983f60..c9dc4aa3 100644 --- a/pylops_mpi/utils/deps.py +++ b/pylops_mpi/utils/deps.py @@ -39,6 +39,10 @@ def nccl_import(message: Optional[str] = None) -> str: return nccl_message +cuda_aware_mpi_enabled: bool = ( + True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1) == 1) else False +) + nccl_enabled: bool = ( True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False ) From 64854bbe88061f097350d687ae3ad15e30a4e8c9 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 7 Sep 2025 20:53:18 +0000 Subject: [PATCH 05/17] feat: added _mpi file with actual mpi comm. 
implementations --- pylops_mpi/utils/_mpi.py | 95 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 pylops_mpi/utils/_mpi.py diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py new file mode 100644 index 00000000..2d08245e --- /dev/null +++ b/pylops_mpi/utils/_mpi.py @@ -0,0 +1,95 @@ +__all__ = [ + # "mpi_allgather", + "mpi_allreduce", + # "mpi_bcast", + # "mpi_asarray", + "mpi_send", + # "mpi_recv", +] + +from typing import Optional + +import numpy as np +from mpi4py import MPI +from pylops.utils.backend import get_module +from pylops_mpi.utils import deps + + +def mpi_allreduce(base_comm: MPI.Comm, + send_buf, recv_buf=None, + engine: Optional[str] = "numpy", + op: MPI.Op = MPI.SUM) -> np.ndarray: + """MPI_Allreduce/allreduce + + Dispatch allreduce routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The data buffer from the local GPU to be reduced. + recv_buf : :obj:`cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + op : :obj:mpi4py.MPI.Op, optional + The reduction operation to apply. Defaults to MPI.SUM. + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all GPUs. 
+ + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + ncp = get_module(engine) + recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) + base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + + +def mpi_send(base_comm: MPI.Comm, + send_buf, dest, count, tag=0, + engine: Optional[str] = "numpy", + ) -> None: + """MPI_Send/send + + Dispatch send routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The array containing data to send. + dest: :obj:`int` + The rank of the destination GPU device. + count : :obj:`int` + Number of elements to send from `send_buf`. + tag : :obj:`int` + Tag of the message to be sent. 
+ engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + # Determine MPI type based on array dtype + mpi_type = MPI._typedict[send_buf.dtype.char] + if count is None: + count = send_buf.size + base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + base_comm.send(send_buf, dest, tag) From 838ed0b98dcbb0afdab5413685bb33608c794ba5 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 7 Sep 2025 21:10:47 +0000 Subject: [PATCH 06/17] feat: moved _send to Distributed --- pylops_mpi/Distributed.py | 14 +++++++++++++- pylops_mpi/DistributedArray.py | 12 ------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index dccaf6a6..7384c40f 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -2,7 +2,7 @@ from mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils -from pylops_mpi.utils._mpi import mpi_allreduce +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -43,3 +43,15 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): else: return mpi_allreduce(self.sub_comm, send_buf, recv_buf, self.engine, op) + + def _send(self, send_buf, dest, count=None, tag=0): + """Send operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if count is None: + count = send_buf.size + nccl_send(self.base_comm_nccl, send_buf, dest, count) + else: + mpi_send(self.base_comm, + send_buf, dest, count, tag=tag, + engine=self.engine) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 6fd3ee95..cac36f6a 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -507,18 +507,6 @@ def _allgather_subcomm(self, 
send_buf, recv_buf=None): return self.sub_comm.allgather(send_buf) self.sub_comm.Allgather(send_buf, recv_buf) - def _send(self, send_buf, dest, count=None, tag=0): - """Send operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if count is None: - count = send_buf.size - nccl_send(self.base_comm_nccl, send_buf, dest, count) - else: - mpi_send(self.base_comm, - send_buf, dest, count, tag=tag, - engine=self.engine) - def _recv(self, recv_buf=None, source=0, count=None, tag=0): """Receive operation """ From ab97e3dbfaaca6ff69dfc600a9aecfe7d5f93a3d Mon Sep 17 00:00:00 2001 From: tharittk Date: Fri, 12 Sep 2025 01:47:41 -0500 Subject: [PATCH 07/17] mpi_recv for MixIn --- pylops_mpi/utils/_mpi.py | 42 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index 2d08245e..c635acc4 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -75,7 +75,7 @@ def mpi_send(base_comm: MPI.Comm, send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` The array containing data to send. dest: :obj:`int` - The rank of the destination GPU device. + The rank of the destination CPU/GPU device. count : :obj:`int` Number of elements to send from `send_buf`. tag : :obj:`int` @@ -93,3 +93,43 @@ def mpi_send(base_comm: MPI.Comm, else: # Uses CuPy without CUDA-aware MPI base_comm.send(send_buf, dest, tag) + +def mpi_recv(base_comm: MPI.Comm, + recv_buf=None, source=0, count=None, tag=0, + engine: Optional[str] = "numpy") -> np.ndarray: + """ MPI_Recv/recv + Dispatch receive routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional + The buffered array to receive data. + source : :obj:`int` + The rank of the sending CPU/GPU device. + count : :obj:`int` + Number of elements to receive. 
+ tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + ncp = get_module(engine) + if recv_buf is None: + if count is None: + raise ValueError("Must provide either recv_buf or count for MPI receive") + # Default to int32 works currently because add_ghost_cells() is called + # with recv_buf and is not affected by this branch. The int32 is for when + # dimension or shape-related integers are send/recv + recv_buf = ncp.zeros(count, dtype=ncp.int32) + mpi_type = MPI._typedict[recv_buf.dtype.char] + base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + recv_buf = base_comm.recv(source=source, tag=tag) + return recv_buf + From dbe1f30e3ab3d1d17f4ef2d551ed18df78e84a7f Mon Sep 17 00:00:00 2001 From: tharittk Date: Fri, 12 Sep 2025 02:53:27 -0500 Subject: [PATCH 08/17] MixIn for allgather. 
--- pylops_mpi/Distributed.py | 55 +++++++++++++-- pylops_mpi/DistributedArray.py | 64 +----------------- pylops_mpi/utils/_mpi.py | 109 ++++++++++++++++++++++++++++-- pylops_mpi/utils/_nccl.py | 89 +----------------------- tests_nccl/test_ncclutils_nccl.py | 7 +- 5 files changed, 163 insertions(+), 161 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 7384c40f..cc86e7d4 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,8 +1,8 @@ -from typing import Any, NewType +from typing import Any, NewType, Tuple from mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils -from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -11,11 +11,9 @@ if nccl_message is None and cupy_message is None: from pylops_mpi.utils._nccl import ( nccl_allgather, nccl_allreduce, - nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, - _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv + nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv ) - class DistributedMixIn: r"""Distributed Mixin class @@ -44,6 +42,36 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): return mpi_allreduce(self.sub_comm, send_buf, recv_buf, self.engine, op) + def _allgather(self, send_buf, recv_buf=None): + """Allgather operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if isinstance(send_buf, (tuple, list, int)): + return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) + else: + send_shapes = self.base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") + raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, 
recv_buf if recv_buf else padded_recv) + return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) + else: + if isinstance(send_buf, (tuple, list, int)): + return self.base_comm.allgather(send_buf) + return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine) + + def _allgather_subcomm(self, send_buf, recv_buf=None): + """Allgather operation with subcommunicator + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + if isinstance(send_buf, (tuple, list, int)): + return nccl_allgather(self.sub_comm, send_buf, recv_buf) + else: + send_shapes = self._allgather_subcomm(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") + raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv) + return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) + else: + return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine) + def _send(self, send_buf, dest, count=None, tag=0): """Send operation """ @@ -55,3 +83,20 @@ def _send(self, send_buf, dest, count=None, tag=0): mpi_send(self.base_comm, send_buf, dest, count, tag=tag, engine=self.engine) + + def _recv(self, recv_buf=None, source=0, count=None, tag=0): + """Receive operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if recv_buf is None: + raise ValueError("recv_buf must be supplied when using NCCL") + if count is None: + count = recv_buf.size + nccl_recv(self.base_comm_nccl, recv_buf, source, count) + return recv_buf + else: + return mpi_recv(self.base_comm, + recv_buf, source, count, tag=tag, + engine=self.engine) + + diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index cac36f6a..d3cb70f0 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -9,14 +9,13 @@ from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops.utils._internal import _value_or_sized_to_tuple 
from pylops.utils.backend import get_array_module, get_module, get_module_name -from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv + from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -474,67 +473,6 @@ def _check_mask(self, dist_array): if not np.array_equal(self.mask, dist_array.mask): raise ValueError("Mask of both the arrays must be same") - def _allgather(self, send_buf, recv_buf=None): - """Allgather operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) - else: - send_shapes = self.base_comm.allgather(send_buf.shape) - (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes) - raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) - return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes) - else: - if recv_buf is None: - return self.base_comm.allgather(send_buf) - self.base_comm.Allgather(send_buf, recv_buf) - return recv_buf - - def _allgather_subcomm(self, send_buf, recv_buf=None): - """Allgather operation with subcommunicator - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.sub_comm, send_buf, recv_buf) - else: - send_shapes = self._allgather_subcomm(send_buf.shape) - (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes) - 
raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv) - return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes) - else: - if recv_buf is None: - return self.sub_comm.allgather(send_buf) - self.sub_comm.Allgather(send_buf, recv_buf) - - def _recv(self, recv_buf=None, source=0, count=None, tag=0): - """Receive operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if recv_buf is None: - raise ValueError("recv_buf must be supplied when using NCCL") - if count is None: - count = recv_buf.size - nccl_recv(self.base_comm_nccl, recv_buf, source, count) - return recv_buf - else: - # NumPy + MPI will benefit from buffered communication regardless of MPI installation - if deps.cuda_aware_mpi_enabled or self.engine == "numpy": - ncp = get_module(self.engine) - if recv_buf is None: - if count is None: - raise ValueError("Must provide either recv_buf or count for MPI receive") - # Default to int32 works currently because add_ghost_cells() is called - # with recv_buf and is not affected by this branch. 
The int32 is for when - # dimension or shape-related integers are send/recv - recv_buf = ncp.zeros(count, dtype=ncp.int32) - mpi_type = MPI._typedict[recv_buf.dtype.char] - self.base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) - else: - # Uses CuPy without CUDA-aware MPI - recv_buf = self.base_comm.recv(source=source, tag=tag) - return recv_buf - def _nccl_local_shapes(self, masked: bool): """Get the the list of shapes of every GPU in the communicator """ diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index c635acc4..33cfe270 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -1,19 +1,100 @@ __all__ = [ - # "mpi_allgather", + "mpi_allgather", "mpi_allreduce", # "mpi_bcast", # "mpi_asarray", "mpi_send", - # "mpi_recv", + "mpi_recv", + "_prepare_allgather_inputs", + "_unroll_allgather_recv" ] -from typing import Optional +from typing import Optional, Tuple import numpy as np from mpi4py import MPI from pylops.utils.backend import get_module from pylops_mpi.utils import deps +# TODO: return type annotation for both cupy and numpy +def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): + r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) + + Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. + Therefore, padding is required when the array is not evenly partitioned across + all the ranks. The padding is applied such that the each dimension of the sending buffers + is equal to the max size of that dimension across all ranks. + + Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size + + Parameters + ---------- + send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like + The data buffer from the local GPU to be sent for allgather. 
+ send_buf_shapes: :obj:`list` + A list of shapes for each GPU send_buf (used to calculate padding size) + engine : :obj:`str` + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + send_buf: :obj:`cupy.ndarray` + A buffer containing the data and padded elements to be sent by this rank. + recv_buf : :obj:`cupy.ndarray` + An empty, padded buffer to gather data from all GPUs. + """ + ncp = get_module(engine) + sizes_each_dim = list(zip(*send_buf_shapes)) + send_shape = tuple(map(max, sizes_each_dim)) + pad_size = [ + (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) + ] + + send_buf = ncp.pad( + send_buf, pad_size, mode="constant", constant_values=0 + ) + + ndev = len(send_buf_shapes) + recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) + + return send_buf, recv_buf + + +def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: + r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL) + + Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays + Each GPU may send array with a different shape, so the return type has to be a list of array + instead of the concatenated array. 
+ + Parameters + ---------- + recv_buf: :obj:`cupy.ndarray` or array-like + The data buffer returned from nccl_allgather call + padded_send_buf_shape: :obj:`tuple`:int + The size of send_buf after padding used in nccl_allgather + send_buf_shapes: :obj:`list` + A list of original shapes for each GPU send_buf prior to padding + + Returns + ------- + chunks: :obj:`list` + A list of `cupy.ndarray` from each GPU with the padded element removed + """ + ndev = len(send_buf_shapes) + # extract an individual array from each device + chunk_size = np.prod(padded_send_buf_shape) + chunks = [ + recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) + ] + + # Remove padding from each array: the padded value may appear somewhere + # in the middle of the flat array and thus the reshape and slicing for each dimension is required + for i in range(ndev): + slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) + chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] + + return chunks def mpi_allreduce(base_comm: MPI.Comm, send_buf, recv_buf=None, @@ -57,7 +138,27 @@ def mpi_allreduce(base_comm: MPI.Comm, # For MIN and MAX which require recv_buf base_comm.Allreduce(send_buf, recv_buf, op) return recv_buf - + + +def mpi_allgather(base_comm: MPI.Comm, + send_buf, recv_buf=None, + engine: Optional[str] = "numpy", + ) -> np.ndarray: + + if deps.cuda_aware_mpi_enabled or engine == "numpy": + send_shapes = base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) + recv_buffer_to_use = recv_buf if recv_buf else padded_recv + base_comm.Allgather(padded_send, recv_buffer_to_use) + return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) + + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allgather(send_buf) + base_comm.Allgather(send_buf, recv_buf) + return recv_buf + def mpi_send(base_comm: MPI.Comm, send_buf, dest, count, tag=0, diff 
--git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 19c09922..0eb6cde1 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -1,6 +1,4 @@ __all__ = [ - "_prepare_nccl_allgather_inputs", - "_unroll_nccl_allgather_recv", "_nccl_sync", "initialize_nccl_comm", "nccl_split", @@ -13,12 +11,11 @@ ] from enum import IntEnum -from typing import Tuple from mpi4py import MPI import os -import numpy as np import cupy as cp import cupy.cuda.nccl as nccl +from pylops_mpi.utils._mpi import _prepare_allgather_inputs, _unroll_allgather_recv cupy_to_nccl_dtype = { "float32": nccl.NCCL_FLOAT32, @@ -69,86 +66,6 @@ def _nccl_sync(): return cp.cuda.runtime.deviceSynchronize() - -def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]: - r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) - - NCCL's allGather requires the sending buffer to have the same size for every device. - Therefore, padding is required when the array is not evenly partitioned across - all the ranks. The padding is applied such that the each dimension of the sending buffers - is equal to the max size of that dimension across all ranks. - - Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size - - Parameters - ---------- - send_buf : :obj:`cupy.ndarray` or array-like - The data buffer from the local GPU to be sent for allgather. - send_buf_shapes: :obj:`list` - A list of shapes for each GPU send_buf (used to calculate padding size) - - Returns - ------- - send_buf: :obj:`cupy.ndarray` - A buffer containing the data and padded elements to be sent by this rank. - recv_buf : :obj:`cupy.ndarray` - An empty, padded buffer to gather data from all GPUs. 
- """ - sizes_each_dim = list(zip(*send_buf_shapes)) - send_shape = tuple(map(max, sizes_each_dim)) - pad_size = [ - (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) - ] - - send_buf = cp.pad( - send_buf, pad_size, mode="constant", constant_values=0 - ) - - # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred - ndev = len(send_buf_shapes) - recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) - - return send_buf, recv_buf - - -def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: - """Unrolll recv_buf after NCCL allgather (nccl_allgather) - - Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays - Each GPU may send array with a different shape, so the return type has to be a list of array - instead of the concatenated array. - - Parameters - ---------- - recv_buf: :obj:`cupy.ndarray` or array-like - The data buffer returned from nccl_allgather call - padded_send_buf_shape: :obj:`tuple`:int - The size of send_buf after padding used in nccl_allgather - send_buf_shapes: :obj:`list` - A list of original shapes for each GPU send_buf prior to padding - - Returns - ------- - chunks: :obj:`list` - A list of `cupy.ndarray` from each GPU with the padded element removed - """ - - ndev = len(send_buf_shapes) - # extract an individual array from each device - chunk_size = np.prod(padded_send_buf_shape) - chunks = [ - recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) - ] - - # Remove padding from each array: the padded value may appear somewhere - # in the middle of the flat array and thus the reshape and slicing for each dimension is required - for i in range(ndev): - slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) - chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] - - return chunks - - def mpi_op_to_nccl(mpi_op) -> NcclOp: """ Map MPI reduction 
operation to NCCL equivalent @@ -363,9 +280,9 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray: Global array gathered from all GPUs and concatenated along `axis`. """ - send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes) + send_buf, recv_buf = _prepare_allgather_inputs(local_array, local_shapes, engine="cupy") nccl_allgather(nccl_comm, send_buf, recv_buf) - chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes) + chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, local_shapes) # combine back to single global array return cp.concatenate(chunks, axis=axis) diff --git a/tests_nccl/test_ncclutils_nccl.py b/tests_nccl/test_ncclutils_nccl.py index 21b28ca3..52502afc 100644 --- a/tests_nccl/test_ncclutils_nccl.py +++ b/tests_nccl/test_ncclutils_nccl.py @@ -8,7 +8,8 @@ from numpy.testing import assert_allclose import pytest -from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv +from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather +from pylops_mpi.utils._mpi import _prepare_allgather_inputs, _unroll_allgather_recv np.random.seed(42) @@ -83,9 +84,9 @@ def test_allgather_differentsize_withrecbuf(par): # Gathered array send_shapes = MPI.COMM_WORLD.allgather(local_array.shape) - send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, send_shapes) + send_buf, recv_buf = _prepare_allgather_inputs(local_array, send_shapes, engine="cupy") recv_buf = nccl_allgather(nccl_comm, send_buf, recv_buf) - chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, send_shapes) + chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, send_shapes) gathered_array = cp.concatenate(chunks) # Compare with global array created in rank0 From a08924bf18389a666d4e96b1eef845d76c6e2b46 Mon Sep 17 00:00:00 2001 From: tharittk Date: Fri, 12 Sep 2025 03:07:05 -0500 Subject: [PATCH 09/17] fix flake8 --- 
pylops_mpi/Distributed.py | 16 +++++-------- pylops_mpi/DistributedArray.py | 6 ++--- pylops_mpi/utils/_mpi.py | 41 +++++++++++++++++----------------- pylops_mpi/utils/_nccl.py | 1 + 4 files changed, 30 insertions(+), 34 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index cc86e7d4..8876bfc1 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,5 +1,3 @@ -from typing import Any, NewType, Tuple - from mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv @@ -10,10 +8,10 @@ if nccl_message is None and cupy_message is None: from pylops_mpi.utils._nccl import ( - nccl_allgather, nccl_allreduce, - nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv + nccl_allgather, nccl_allreduce, nccl_send, nccl_recv ) + class DistributedMixIn: r"""Distributed Mixin class @@ -30,7 +28,7 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) else: - return mpi_allreduce(self.base_comm, send_buf, + return mpi_allreduce(self.base_comm, send_buf, recv_buf, self.engine, op) def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): @@ -39,7 +37,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) else: - return mpi_allreduce(self.sub_comm, send_buf, + return mpi_allreduce(self.sub_comm, send_buf, recv_buf, self.engine, op) def _allgather(self, send_buf, recv_buf=None): @@ -96,7 +94,5 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0): return recv_buf else: return mpi_recv(self.base_comm, - recv_buf, source, count, tag=tag, - 
engine=self.engine) - - + recv_buf, source, count, tag=tag, + engine=self.engine) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index d3cb70f0..faa780ac 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -15,7 +15,7 @@ nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split + from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -613,14 +613,14 @@ def _compute_vector_norm(self, local_array: NDArray, # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64) if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: - # CuPy + non-CUDA-aware MPI: This will call non-buffered communication + # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL - # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. + # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. # There may be a way to unify it - may be something to do with how we allocate the recv_buf. 
if self.base_comm_nccl: recv_buf = ncp.squeeze(recv_buf, axis=axis) diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index 33cfe270..e3520c94 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -9,13 +9,14 @@ "_unroll_allgather_recv" ] -from typing import Optional, Tuple +from typing import Optional import numpy as np from mpi4py import MPI from pylops.utils.backend import get_module from pylops_mpi.utils import deps + # TODO: return type annotation for both cupy and numpy def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) @@ -33,7 +34,7 @@ def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): The data buffer from the local GPU to be sent for allgather. send_buf_shapes: :obj:`list` A list of shapes for each GPU send_buf (used to calculate padding size) - engine : :obj:`str` + engine : :obj:`str` Engine used to store array (``numpy`` or ``cupy``) Returns @@ -96,20 +97,21 @@ def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> return chunks + def mpi_allreduce(base_comm: MPI.Comm, - send_buf, recv_buf=None, + send_buf, recv_buf=None, engine: Optional[str] = "numpy", op: MPI.Op = MPI.SUM) -> np.ndarray: - """MPI_Allreduce/allreduce - - Dispatch allreduce routine based on type of input and availability of + """MPI_Allreduce/allreduce + + Dispatch allreduce routine based on type of input and availability of CUDA-Aware MPI Parameters ---------- base_comm : :obj:`MPI.Comm` Base MPI Communicator. - send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` The data buffer from the local GPU to be reduced. recv_buf : :obj:`cupy.ndarray`, optional The buffer to store the result of the reduction. 
If None, @@ -121,10 +123,10 @@ def mpi_allreduce(base_comm: MPI.Comm, Returns ------- - recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` A buffer containing the result of the reduction, broadcasted to all GPUs. - + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": ncp = get_module(engine) @@ -141,9 +143,8 @@ def mpi_allreduce(base_comm: MPI.Comm, def mpi_allgather(base_comm: MPI.Comm, - send_buf, recv_buf=None, - engine: Optional[str] = "numpy", - ) -> np.ndarray: + send_buf, recv_buf=None, + engine: Optional[str] = "numpy") -> np.ndarray: if deps.cuda_aware_mpi_enabled or engine == "numpy": send_shapes = base_comm.allgather(send_buf.shape) @@ -165,15 +166,15 @@ def mpi_send(base_comm: MPI.Comm, engine: Optional[str] = "numpy", ) -> None: """MPI_Send/send - - Dispatch send routine based on type of input and availability of + + Dispatch send routine based on type of input and availability of CUDA-Aware MPI Parameters ---------- base_comm : :obj:`MPI.Comm` Base MPI Communicator. - send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` The array containing data to send. dest: :obj:`int` The rank of the destination CPU/GPU device. @@ -183,7 +184,6 @@ def mpi_send(base_comm: MPI.Comm, Tag of the message to be sent. 
engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - """ if deps.cuda_aware_mpi_enabled or engine == "numpy": # Determine MPI type based on array dtype @@ -195,11 +195,12 @@ def mpi_send(base_comm: MPI.Comm, # Uses CuPy without CUDA-aware MPI base_comm.send(send_buf, dest, tag) + def mpi_recv(base_comm: MPI.Comm, - recv_buf=None, source=0, count=None, tag=0, - engine: Optional[str] = "numpy") -> np.ndarray: + recv_buf=None, source=0, count=None, tag=0, + engine: Optional[str] = "numpy") -> np.ndarray: """ MPI_Recv/recv - Dispatch receive routine based on type of input and availability of + Dispatch receive routine based on type of input and availability of CUDA-Aware MPI Parameters @@ -216,7 +217,6 @@ def mpi_recv(base_comm: MPI.Comm, Tag of the message to be sent. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - """ if deps.cuda_aware_mpi_enabled or engine == "numpy": ncp = get_module(engine) @@ -233,4 +233,3 @@ def mpi_recv(base_comm: MPI.Comm, # Uses CuPy without CUDA-aware MPI recv_buf = base_comm.recv(source=source, tag=tag) return recv_buf - diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 0eb6cde1..5f297531 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -66,6 +66,7 @@ def _nccl_sync(): return cp.cuda.runtime.deviceSynchronize() + def mpi_op_to_nccl(mpi_op) -> NcclOp: """ Map MPI reduction operation to NCCL equivalent From b8bcd295c946967923d823693ec0447f4a8c3ef3 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:54:55 +0000 Subject: [PATCH 10/17] feat: added _bcast to DistributedMixIn and added comms as input for all methods --- pylops_mpi/Distributed.py | 54 +++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 8876bfc1..7e940b84 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,6 +1,6 @@ from 
mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils -from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -8,7 +8,7 @@ if nccl_message is None and cupy_message is None: from pylops_mpi.utils._nccl import ( - nccl_allgather, nccl_allreduce, nccl_send, nccl_recv + nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv ) @@ -22,39 +22,45 @@ class DistributedMixIn: MPI installation is available, the latter with CuPy arrays when a CUDA-Aware MPI installation is not available). """ - def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + def _allreduce(self, base_comm, base_comm_nccl, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + engine="numpy"): """Allreduce operation """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) + if deps.nccl_enabled and base_comm_nccl is not None: + return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op) else: - return mpi_allreduce(self.base_comm, send_buf, - recv_buf, self.engine, op) + return mpi_allreduce(base_comm, send_buf, + recv_buf, engine, op) - def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + def _allreduce_subcomm(self, sub_comm, base_comm_nccl, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + engine="numpy"): """Allreduce operation with subcommunicator """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) + if deps.nccl_enabled and base_comm_nccl is not None: + return nccl_allreduce(sub_comm, send_buf, recv_buf, op) else: - 
return mpi_allreduce(self.sub_comm, send_buf, - recv_buf, self.engine, op) + return mpi_allreduce(sub_comm, send_buf, + recv_buf, engine, op) - def _allgather(self, send_buf, recv_buf=None): + def _allgather(self, base_comm, base_comm_nccl, + send_buf, recv_buf=None, + engine="numpy"): """Allgather operation """ - if deps.nccl_enabled and self.base_comm_nccl: + if deps.nccl_enabled and base_comm_nccl is not None: if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) + return nccl_allgather(base_comm_nccl, send_buf, recv_buf) else: - send_shapes = self.base_comm.allgather(send_buf.shape) + send_shapes = base_comm.allgather(send_buf.shape) (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") - raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) + raw_recv = nccl_allgather(base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) else: if isinstance(send_buf, (tuple, list, int)): - return self.base_comm.allgather(send_buf) - return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine) + return base_comm.allgather(send_buf) + return mpi_allgather(base_comm, send_buf, recv_buf, engine) def _allgather_subcomm(self, send_buf, recv_buf=None): """Allgather operation with subcommunicator @@ -70,6 +76,16 @@ def _allgather_subcomm(self, send_buf, recv_buf=None): else: return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine) + def _bcast(self, local_array, index, value): + """BCast operation + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + nccl_bcast(self.base_comm_nccl, local_array, index, value) + else: + # self.local_array[index] = self.base_comm.bcast(value) + mpi_bcast(self.base_comm, self.rank, self.local_array, index, value, + engine=self.engine) + def _send(self, send_buf, dest, count=None, tag=0): """Send 
operation """ From f362436909db7c856e6fcde26f538455c7131879 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:56:18 +0000 Subject: [PATCH 11/17] feat: adapted all comm calls in DistributedArray to new method signatures --- pylops_mpi/DistributedArray.py | 49 +++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index faa780ac..75b66c7e 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -15,7 +15,7 @@ nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split + from pylops_mpi.utils._nccl import nccl_asarray, nccl_split from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -204,10 +204,7 @@ def __setitem__(self, index, value): the specified index positions. """ if self.partition is Partition.BROADCAST: - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - nccl_bcast(self.base_comm_nccl, self.local_array, index, value) - else: - self.local_array[index] = self.base_comm.bcast(value) + self._bcast(self.local_array, index, value) else: self.local_array[index] = value @@ -343,7 +340,9 @@ def local_shapes(self): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return self._nccl_local_shapes(False) else: - return self._allgather(self.local_shape) + return self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_shape) @property def sub_comm(self): @@ -383,7 +382,10 @@ def asarray(self, masked: bool = False): if masked: final_array = self._allgather_subcomm(self.local_array) else: - final_array = self._allgather(self.local_array) + final_array = self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_array, + engine=self.engine) return np.concatenate(final_array, axis=self.axis) @classmethod @@ -433,6 +435,7 @@ def to_dist(cls, x: 
NDArray, else: slices = [slice(None)] * x.ndim local_shapes = np.append([0], dist_array._allgather( + base_comm, base_comm_nccl, dist_array.local_shape[axis])) sum_shapes = np.cumsum(local_shapes) slices[axis] = slice(sum_shapes[dist_array.rank], @@ -480,7 +483,9 @@ def _nccl_local_shapes(self, masked: bool): if masked: all_tuples = self._allgather_subcomm(self.local_shape).get() else: - all_tuples = self._allgather(self.local_shape).get() + all_tuples = self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_shape).get() # NCCL returns the flat array that packs every tuple as 1-dimensional array # unpack each tuple from each rank tuple_len = len(self.local_shape) @@ -578,7 +583,9 @@ def dot(self, dist_array): y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \ if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array # Flatten the local arrays and calculate dot product - return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten())) + return self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.dot(x.local_array.flatten(), y.local_array.flatten()), + engine=self.engine) def _compute_vector_norm(self, local_array: NDArray, axis: int, ord: Optional[int] = None): @@ -606,7 +613,9 @@ def _compute_vector_norm(self, local_array: NDArray, raise ValueError(f"norm-{ord} not possible for vectors") elif ord == 0: # Count non-zero then sum reduction - recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64)) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64), + engine=self.engine) elif ord == ncp.inf: # Calculate max followed by max reduction # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly @@ -615,10 +624,14 @@ def _compute_vector_norm(self, local_array: NDArray, if 
self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. - recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf.get(), recv_buf.get(), + op=MPI.MAX, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf, recv_buf, op=MPI.MAX, + engine=self.engine) # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. # There may be a way to unify it - may be something to do with how we allocate the recv_buf. 
@@ -629,14 +642,20 @@ def _compute_vector_norm(self, local_array: NDArray, # See the comment above in +infinity norm send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: - recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf.get(), recv_buf.get(), + op=MPI.MIN, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf, recv_buf, + op=MPI.MIN, engine=self.engine) if self.base_comm_nccl: recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis)) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis), + engine=self.engine) recv_buf = ncp.power(recv_buf, 1.0 / ord) return recv_buf From 693f0786dd10db78486136f5942eb7d958bb06e8 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:56:48 +0000 Subject: [PATCH 12/17] feat: adapted all comm calls in VStack to new method signatures --- pylops_mpi/basicoperators/VStack.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index f6d5b198..174e9739 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -6,7 +6,6 @@ from pylops import LinearOperator from pylops.utils import DTypeLike from pylops.utils.backend import get_module -from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi import ( MPILinearOperator, @@ -17,13 +16,6 @@ ) from 
pylops_mpi.Distributed import DistributedMixIn from pylops_mpi.utils.decorators import reshaped -from pylops_mpi.utils import deps - -cupy_message = pylops_deps.cupy_import("the VStack module") -nccl_message = deps.nccl_import("the VStack module") - -if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allreduce class MPIVStack(DistributedMixIn, MPILinearOperator): @@ -142,19 +134,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - # TODO: consider adding base_comm, base_comm_nccl, engine to the - # input parameters of _allreduce instead of relying on self - self.base_comm, self.base_comm_nccl, self.engine = \ - x.base_comm, x.base_comm_nccl, x.engine - y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, + y = DistributedArray(global_shape=self.shape[1], + base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, - engine=x.engine, dtype=self.dtype) + engine=x.engine, + dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - y[:] = self._allreduce(y1, op=MPI.SUM) + y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, + y1, op=MPI.SUM, engine=x.engine) return y From c852fc41bb05e2d0d961940321a75fd1bdd626f9 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:57:21 +0000 Subject: [PATCH 13/17] feat: adapted all comm calls in Fredholm1 to new method signatures --- pylops_mpi/signalprocessing/Fredholm1.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pylops_mpi/signalprocessing/Fredholm1.py b/pylops_mpi/signalprocessing/Fredholm1.py index 6ccd9d21..2969e3c9 100644 --- a/pylops_mpi/signalprocessing/Fredholm1.py +++ b/pylops_mpi/signalprocessing/Fredholm1.py @@ -128,7 +128,8 @@ def 
_matvec(self, x: DistributedArray) -> DistributedArray: for isl in range(self.nsls[self.rank]): y1[isl] = ncp.dot(self.G[isl], x[isl]) # gather results - y[:] = ncp.vstack(y._allgather(y1)).ravel() + y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, + engine=y.engine)).ravel() return y def _rmatvec(self, x: NDArray) -> NDArray: @@ -165,5 +166,6 @@ def _rmatvec(self, x: NDArray) -> NDArray: y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj() # gather results - y[:] = ncp.vstack(y._allgather(y1)).ravel() + y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, + engine=y.engine)).ravel() return y From 78d753847aa878439eaee220763d73609a47616d Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:58:13 +0000 Subject: [PATCH 14/17] feat: moved methods shared by _mpi and _nccl to _common --- pylops_mpi/utils/_common.py | 92 ++++++++++++++++++++++++++++ pylops_mpi/utils/_mpi.py | 118 ++++++++---------------------------- pylops_mpi/utils/_nccl.py | 2 +- 3 files changed, 117 insertions(+), 95 deletions(-) create mode 100644 pylops_mpi/utils/_common.py diff --git a/pylops_mpi/utils/_common.py b/pylops_mpi/utils/_common.py new file mode 100644 index 00000000..ab149b5c --- /dev/null +++ b/pylops_mpi/utils/_common.py @@ -0,0 +1,92 @@ +__all__ = [ + "_prepare_allgather_inputs", + "_unroll_allgather_recv" +] + +from typing import Optional + +import numpy as np +from mpi4py import MPI +from pylops.utils.backend import get_module +from pylops_mpi.utils import deps + + +# TODO: return type annotation for both cupy and numpy +def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): + r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) + + Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. + Therefore, padding is required when the array is not evenly partitioned across + all the ranks. 
The padding is applied such that each dimension of the sending buffers + is equal to the max size of that dimension across all ranks. + + Similarly, each receiver buffer (recv_buf) is created with size equal to :math:`n_rank \cdot send_buf.size` + + Parameters + ---------- + send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like + The data buffer from the local GPU to be sent for allgather. + send_buf_shapes: :obj:`list` + A list of shapes for each GPU send_buf (used to calculate padding size) + engine : :obj:`str` + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + send_buf: :obj:`cupy.ndarray` + A buffer containing the data and padded elements to be sent by this rank. + recv_buf : :obj:`cupy.ndarray` + An empty, padded buffer to gather data from all GPUs. + """ + ncp = get_module(engine) + sizes_each_dim = list(zip(*send_buf_shapes)) + send_shape = tuple(map(max, sizes_each_dim)) + pad_size = [ + (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) + ] + + send_buf = ncp.pad( + send_buf, pad_size, mode="constant", constant_values=0 + ) + + ndev = len(send_buf_shapes) + recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) + + return send_buf, recv_buf + + +def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: + r"""Unroll recv_buf after Buffered Allgather (MPI and NCCL) + + Remove the padded elements in recv_buf, extract an individual array from each device and return them as a list of arrays. + Each GPU may send array with a different shape, so the return type has to be a list of arrays + instead of the concatenated array.
+ + Parameters + ---------- + recv_buf: :obj:`cupy.ndarray` or array-like + The data buffer returned from nccl_allgather call + padded_send_buf_shape: :obj:`tuple`:int + The size of send_buf after padding used in nccl_allgather + send_buf_shapes: :obj:`list` + A list of original shapes for each GPU send_buf prior to padding + + Returns + ------- + chunks: :obj:`list` + A list of `cupy.ndarray` from each GPU with the padded element removed + """ + ndev = len(send_buf_shapes) + # extract an individual array from each device + chunk_size = np.prod(padded_send_buf_shape) + chunks = [ + recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) + ] + + # Remove padding from each array: the padded value may appear somewhere + # in the middle of the flat array and thus the reshape and slicing for each dimension is required + for i in range(ndev): + slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) + chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] + + return chunks diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index e3520c94..89304b8c 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -1,12 +1,10 @@ __all__ = [ "mpi_allgather", "mpi_allreduce", - # "mpi_bcast", + "mpi_bcast", # "mpi_asarray", "mpi_send", "mpi_recv", - "_prepare_allgather_inputs", - "_unroll_allgather_recv" ] from typing import Optional @@ -15,87 +13,26 @@ from mpi4py import MPI from pylops.utils.backend import get_module from pylops_mpi.utils import deps +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv -# TODO: return type annotation for both cupy and numpy -def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): - r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) - - Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. - Therefore, padding is required when the array is not evenly partitioned across - all the ranks. 
The padding is applied such that the each dimension of the sending buffers - is equal to the max size of that dimension across all ranks. - - Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size - - Parameters - ---------- - send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like - The data buffer from the local GPU to be sent for allgather. - send_buf_shapes: :obj:`list` - A list of shapes for each GPU send_buf (used to calculate padding size) - engine : :obj:`str` - Engine used to store array (``numpy`` or ``cupy``) - - Returns - ------- - send_buf: :obj:`cupy.ndarray` - A buffer containing the data and padded elements to be sent by this rank. - recv_buf : :obj:`cupy.ndarray` - An empty, padded buffer to gather data from all GPUs. - """ - ncp = get_module(engine) - sizes_each_dim = list(zip(*send_buf_shapes)) - send_shape = tuple(map(max, sizes_each_dim)) - pad_size = [ - (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) - ] - - send_buf = ncp.pad( - send_buf, pad_size, mode="constant", constant_values=0 - ) - - ndev = len(send_buf_shapes) - recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) - - return send_buf, recv_buf - - -def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: - r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL) - - Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays - Each GPU may send array with a different shape, so the return type has to be a list of array - instead of the concatenated array. 
- - Parameters - ---------- - recv_buf: :obj:`cupy.ndarray` or array-like - The data buffer returned from nccl_allgather call - padded_send_buf_shape: :obj:`tuple`:int - The size of send_buf after padding used in nccl_allgather - send_buf_shapes: :obj:`list` - A list of original shapes for each GPU send_buf prior to padding - - Returns - ------- - chunks: :obj:`list` - A list of `cupy.ndarray` from each GPU with the padded element removed - """ - ndev = len(send_buf_shapes) - # extract an individual array from each device - chunk_size = np.prod(padded_send_buf_shape) - chunks = [ - recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) - ] +def mpi_allgather(base_comm: MPI.Comm, + send_buf, recv_buf=None, + engine: Optional[str] = "numpy") -> np.ndarray: - # Remove padding from each array: the padded value may appear somewhere - # in the middle of the flat array and thus the reshape and slicing for each dimension is required - for i in range(ndev): - slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) - chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] + if deps.cuda_aware_mpi_enabled or engine == "numpy": + send_shapes = base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) + recv_buffer_to_use = recv_buf if recv_buf else padded_recv + base_comm.Allgather(padded_send, recv_buffer_to_use) + return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) - return chunks + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allgather(send_buf) + base_comm.Allgather(send_buf, recv_buf) + return recv_buf def mpi_allreduce(base_comm: MPI.Comm, @@ -142,23 +79,16 @@ def mpi_allreduce(base_comm: MPI.Comm, return recv_buf -def mpi_allgather(base_comm: MPI.Comm, - send_buf, recv_buf=None, - engine: Optional[str] = "numpy") -> np.ndarray: - +def mpi_bcast(base_comm: MPI.Comm, + rank, local_array, index, value, + engine: 
Optional[str] = "numpy") -> np.ndarray: if deps.cuda_aware_mpi_enabled or engine == "numpy": - send_shapes = base_comm.allgather(send_buf.shape) - (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) - recv_buffer_to_use = recv_buf if recv_buf else padded_recv - base_comm.Allgather(padded_send, recv_buffer_to_use) - return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) - + if rank == 0: + local_array[index] = value + base_comm.Bcast(local_array[index]) else: # CuPy with non-CUDA-aware MPI - if recv_buf is None: - return base_comm.allgather(send_buf) - base_comm.Allgather(send_buf, recv_buf) - return recv_buf + local_array[index] = base_comm.bcast(value) def mpi_send(base_comm: MPI.Comm, diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 5f297531..cac5b61c 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -15,7 +15,7 @@ import os import cupy as cp import cupy.cuda.nccl as nccl -from pylops_mpi.utils._mpi import _prepare_allgather_inputs, _unroll_allgather_recv +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv cupy_to_nccl_dtype = { "float32": nccl.NCCL_FLOAT32, From 0138e3aaa9e3374e1701d86d2d6fa2e408823bda Mon Sep 17 00:00:00 2001 From: tharittk Date: Thu, 9 Oct 2025 02:37:09 -0500 Subject: [PATCH 15/17] fix env flag precedence bug --- pylops_mpi/utils/deps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py index c9dc4aa3..f0279ceb 100644 --- a/pylops_mpi/utils/deps.py +++ b/pylops_mpi/utils/deps.py @@ -40,7 +40,7 @@ def nccl_import(message: Optional[str] = None) -> str: cuda_aware_mpi_enabled: bool = ( - True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1) == 1) else False + True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1)) == 1 else False ) nccl_enabled: bool = ( From ec883711a7e0369c496ff5c8dcbc8035da5d9bf9 Mon Sep 17 00:00:00 2001 From: tharittk 
Date: Thu, 9 Oct 2025 02:46:02 -0500 Subject: [PATCH 16/17] fix flake8 --- pylops_mpi/Distributed.py | 10 +++++----- pylops_mpi/DistributedArray.py | 16 ++++++++-------- pylops_mpi/basicoperators/VStack.py | 8 ++++---- pylops_mpi/utils/_common.py | 3 --- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 7e940b84..7e616a3a 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -22,8 +22,8 @@ class DistributedMixIn: MPI installation is available, the latter with CuPy arrays when a CUDA-Aware MPI installation is not available). """ - def _allreduce(self, base_comm, base_comm_nccl, - send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + def _allreduce(self, base_comm, base_comm_nccl, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation """ @@ -33,7 +33,7 @@ def _allreduce(self, base_comm, base_comm_nccl, return mpi_allreduce(base_comm, send_buf, recv_buf, engine, op) - def _allreduce_subcomm(self, sub_comm, base_comm_nccl, + def _allreduce_subcomm(self, sub_comm, base_comm_nccl, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation with subcommunicator @@ -44,7 +44,7 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl, return mpi_allreduce(sub_comm, send_buf, recv_buf, engine, op) - def _allgather(self, base_comm, base_comm_nccl, + def _allgather(self, base_comm, base_comm_nccl, send_buf, recv_buf=None, engine="numpy"): """Allgather operation @@ -85,7 +85,7 @@ def _bcast(self, local_array, index, value): # self.local_array[index] = self.base_comm.bcast(value) mpi_bcast(self.base_comm, self.rank, self.local_array, index, value, engine=self.engine) - + def _send(self, send_buf, dest, count=None, tag=0): """Send operation """ diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 75b66c7e..da7712d7 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ 
-341,7 +341,7 @@ def local_shapes(self): return self._nccl_local_shapes(False) else: return self._allgather(self.base_comm, - self.base_comm_nccl, + self.base_comm_nccl, self.local_shape) @property @@ -383,7 +383,7 @@ def asarray(self, masked: bool = False): final_array = self._allgather_subcomm(self.local_array) else: final_array = self._allgather(self.base_comm, - self.base_comm_nccl, + self.base_comm_nccl, self.local_array, engine=self.engine) return np.concatenate(final_array, axis=self.axis) @@ -484,7 +484,7 @@ def _nccl_local_shapes(self, masked: bool): all_tuples = self._allgather_subcomm(self.local_shape).get() else: all_tuples = self._allgather(self.base_comm, - self.base_comm_nccl, + self.base_comm_nccl, self.local_shape).get() # NCCL returns the flat array that packs every tuple as 1-dimensional array # unpack each tuple from each rank @@ -625,12 +625,12 @@ def _compute_vector_norm(self, local_array: NDArray, # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf.get(), recv_buf.get(), + send_buf.get(), recv_buf.get(), op=MPI.MAX, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf, recv_buf, op=MPI.MAX, + send_buf, recv_buf, op=MPI.MAX, engine=self.engine) # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. 
@@ -643,18 +643,18 @@ def _compute_vector_norm(self, local_array: NDArray, send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf.get(), recv_buf.get(), + send_buf.get(), recv_buf.get(), op=MPI.MIN, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf, recv_buf, + send_buf, recv_buf, op=MPI.MIN, engine=self.engine) if self.base_comm_nccl: recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis), + ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis), engine=self.engine) recv_buf = ncp.power(recv_buf, 1.0 / ord) return recv_buf diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 174e9739..de66c342 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -135,8 +135,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) y = DistributedArray(global_shape=self.shape[1], - base_comm=x.base_comm, - base_comm_nccl=x.base_comm_nccl, + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) @@ -144,8 +144,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, - y1, op=MPI.SUM, engine=x.engine) + y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, + y1, op=MPI.SUM, 
engine=x.engine) return y diff --git a/pylops_mpi/utils/_common.py b/pylops_mpi/utils/_common.py index ab149b5c..895265df 100644 --- a/pylops_mpi/utils/_common.py +++ b/pylops_mpi/utils/_common.py @@ -3,12 +3,9 @@ "_unroll_allgather_recv" ] -from typing import Optional import numpy as np -from mpi4py import MPI from pylops.utils.backend import get_module -from pylops_mpi.utils import deps # TODO: return type annotation for both cupy and numpy From 02ba45bd872609490431c689dbf5ce16fe375abb Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Wed, 15 Oct 2025 21:50:19 +0000 Subject: [PATCH 17/17] doc: added details about cuda-aware mpi in doc --- docs/source/gpu.rst | 10 +++++++- docs/source/installation.rst | 44 +++++++++++++++++++++++++++++++----- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 9a1af651..52839069 100644 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -11,7 +11,7 @@ This library must be installed *before* PyLops-mpi is installed. .. note:: - Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. + Set the environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system, otherwise you will get an error when importing PyLops. @@ -22,6 +22,14 @@ can handle both scenarios. Note that, since most operators in PyLops-mpi are thi some of the operators in PyLops that lack a GPU implementation cannot be used also in PyLops-mpi when working with cupy arrays. +.. note:: + + By default when using ``cupy`` arrays, PyLops-MPI will try to use methods in MPI4Py that communicate memory buffers. + However, this requires a CUDA-Aware MPI installation. 
If your MPI installation is not CUDA-Aware, set the + environment variable ``PYLOPS_MPI_CUDA_AWARE=0`` to force PyLops-MPI to use methods in MPI4Py that communicate + general Python objects (this will incur a loss of performance!). + + Moreover, PyLops-MPI also supports the Nvidia's Collective Communication Library (NCCL) for highly-optimized collective operations, such as AllReduce, AllGather, etc. This allows PyLops-MPI users to leverage the proprietary technology like NVLink that might be available in their infrastructure for fast data communication. diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d0aafe88..e1d7faf3 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -15,7 +15,13 @@ The minimal set of dependencies for the PyLops-MPI project is: * `MPI4py `_ * `PyLops `_ -Additionally, to use the NCCL engine, the following additional +Additionally, to use the CUDA-aware MPI engine, the following additional +dependencies are required: + +* `CuPy `_ +* CUDA-aware MPI + +Similarly, to use the NCCL engine, the following additional dependencies are required: * `CuPy `_ @@ -27,12 +33,18 @@ if this is not possible, some of the dependencies must be installed prior to ins Download and Install MPI ======================== -Visit the official MPI website to download an appropriate MPI implementation for your system. -Follow the installation instructions provided by the MPI vendor. +Visit the official website of your MPI vendor of choice to download an appropriate MPI +implementation for your system: + +* `Open MPI `_ +* `MPICH `_ +* `Intel MPI `_ +* ... -* `Open MPI `_ -* `MPICH `_ -* `Intel MPI `_ +Alternatively, the conda-forge community provides ready-to-use binary packages for four MPI implementations +(see `MPI4Py documentation `_ for more +details). In this case, you can defer the installation to the stage when the conda environment for your project +is created - see below for more details. 
Verify MPI Installation ======================= @@ -42,6 +54,17 @@ After installing MPI, verify its installation by opening a terminal and running >> mpiexec --version +Install CUDA-Aware MPI (optional) +================================= +To be able to achieve the best performance when using PyLops-MPI with CuPy arrays, a CUDA-Aware version of +MPI must be installed. + +For `Open MPI`, the conda-forge package has built-in CUDA support, as long as a pre-installed CUDA is detected. +Run the following `commands `_ +for diagnostics. + +For the other MPI implementations, refer to their specific documentation. + Install NCCL (optional) ======================= To obtain highly-optimized performance on GPU clusters, PyLops-MPI also supports the Nvidia's collective communication calls @@ -103,6 +126,15 @@ For a ``conda`` environment, run This will create and activate an environment called ``pylops_mpi``, with all required and optional dependencies. +If you want to also install MPI as part of the creation process of the conda environment, +modify the ``environment-dev.yml`` file by adding one of ``openmpi``, ``mpich``, ``impi_rt``, or ``msmpi`` +just above ``mpi4py``. Note that only ``openmpi`` provides a CUDA-Aware MPI installation. + +If you want to leverage CUDA-Aware MPI but prefer to use another MPI installation, you must +either switch to a `Pip`-based installation (see below), or move ``mpi4py`` into the ``pip`` +section of the ``environment-dev.yml`` file and export the variable ``MPICC`` pointing to +the path of your CUDA-Aware MPI installation. + If you want to enable `NCCL `_ in PyLops-MPI, run this instead .. code-block:: bash