Skip to content

Commit 29b5a46

Browse files
committed
Fix the wait communication op so it handles both collective and point-to-point (p2p) operations
1 parent ccd00dc commit 29b5a46

File tree

4 files changed

+22
-32
lines changed

4 files changed

+22
-32
lines changed

et_replay/comm/backend/base_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self) -> None:
6161
self.global_rank = -1
6262
self.backendFuncs = {}
6363
self.collective = ""
64-
self.collectiveId = 0
64+
self.wait_obj_key = (0, 0, False) # (pg_id, req_id, is_p2p)
6565
self.pt2pt = ""
6666
self.src_rank = -1
6767
self.dst_rank = -1
@@ -102,7 +102,7 @@ def __init__(self) -> None:
102102
self.dataSize = 0
103103
self.numElements = 0
104104
self.waitObj = []
105-
self.waitObjIds = {} # mapping of reqID to future of async collectives
105+
self.waitObjIds = {} # mapping of (pg_id, req_id, is_p2p) to future of async collectives
106106

107107
self.ipTensor_split_pair = []
108108
self.opTensor_split_pair = []

et_replay/comm/backend/pytorch_dist_backend.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -624,28 +624,14 @@ def complete_accel_ops(self, collectiveArgs, devSync=True):
624624
if devSync:
625625
self.device_sync(collectiveArgs)
626626

627-
# retFlag not used
628-
def complete_single_op(self, collectiveArgs, retFlag=False):
629-
"""only wait on the first op in the queue"""
630-
if len(collectiveArgs.waitObj) > 0:
631-
waitReq = collectiveArgs.waitObj.pop(0)
632-
if waitReq is not None:
633-
waitReq.wait()
634-
635-
# to ensure GPU collective is completed
636-
self.device_sync(collectiveArgs)
637-
638627
def wait(self, collectiveArgs, retFlag=False):
639-
# for backwards compatibility, use old wait functionality.
640-
if len(collectiveArgs.waitObjIds) == 0:
641-
self.complete_single_op(collectiveArgs)
642-
return
643-
644-
"""wait on op with the matching reqID"""
645-
if collectiveArgs.collectiveId in collectiveArgs.waitObjIds:
646-
waitObj = collectiveArgs.waitObjIds[collectiveArgs.collectiveId]
647-
if waitObj is not None:
648-
waitObj.wait()
628+
# wait on op with the matching (pg_id, req_id, is_p2p)
629+
if collectiveArgs.wait_obj_key in collectiveArgs.waitObjIds:
630+
work = collectiveArgs.waitObjIds.pop(collectiveArgs.wait_obj_key)
631+
for i,w in enumerate(collectiveArgs.waitObj):
632+
if w is work:
633+
collectiveArgs.waitObj.pop(i)
634+
work.wait()
649635

650636
def barrier(self, collectiveArgs, name="dummy", retFlag=False):
651637
my_dev = self.get_device()

et_replay/comm/commsTraceParser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,12 @@ def _parse_comms_op_node( # noqa: C901
122122
comm_nodes = (
123123
node for node in in_trace.nodes.values() if node.name == "record_param_comms"
124124
)
125+
is_seq = lambda x: isinstance(x, list) and len(x) == 2 and isinstance(x[0], int) and isinstance(x[1], bool)
125126
for node in comm_nodes:
126127
# according to macro RECORD_PARAM_COMMS and RECORD_PARAM_COMMS_DATA in torch/csrc/distributed/c10d/ParamCommsUtils.hpp
127-
# ["wait", "barrier", "init"] record 1st element as seq, others record starting from input tensor
128-
index_base = 0 if isinstance(node.inputs[0], int) else 1
128+
# ["wait", "barrier", "init"] record the 1st element as seq, where seq's 1st element is the sequence number (int) and its 2nd element is isP2P (bool)
129+
# others record starting from input tensor
130+
index_base = 0 if is_seq(node.inputs[0]) else 1
129131
req_id = node.inputs[index_base]
130132
recorded_rank = node.inputs[index_base + 2]
131133

et_replay/tools/comm_replay.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(self):
114114
self.shrink = False
115115
self.max_msg_cnt = 0 # 0 means no limit
116116
self.num_msg = 0
117-
self.is_blocking = True
117+
self.is_blocking = False
118118
self.do_warm_up = False
119119
self.reuse_tensors = False
120120

@@ -802,9 +802,11 @@ def runComms(
802802
description=f"# PARAM replay {self.replayIter}:" + curBlockStack,
803803
):
804804
if collName in self.backendFuncs.collectiveFunc.keys():
805-
# record collectiveID for wait ops
806-
if curComm.req is not None:
807-
self.collectiveArgs.collectiveId = curComm.req
805+
# record wait_obj_key for wait ops
806+
if curComm.req is not None and curComm.pgId is not None:
807+
self.collectiveArgs.wait_obj_key = (curComm.pgId, curComm.req[0], curComm.req[1])
808+
else:
809+
self.collectiveArgs.wait_obj_key = None
808810

809811
# handle point-to-point separately
810812
if collName in supportedP2pOps:
@@ -832,10 +834,10 @@ def runComms(
832834
if self.is_blocking:
833835
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
834836

835-
# if nonblocking, then store the pair {reqID, future} so that we can wait on it later
837+
# if nonblocking, then store the pair {(pg_id, reqID, isP2P), future} so that we can wait on it later
836838
# check if req id is recorded in trace for backwards compatibility
837-
if curComm.req is not None and not self.is_blocking and collName != "wait":
838-
self.collectiveArgs.waitObjIds[curComm.req] = retObj
839+
if not self.is_blocking and collName != "wait" and self.collectiveArgs.wait_obj_key is not None:
840+
self.collectiveArgs.waitObjIds[self.collectiveArgs.wait_obj_key] = retObj
839841

840842
# For non-blocking, latency and global_latency are the same
841843
global_latency = latency = collTimer.getTimeUS()

0 commit comments

Comments
 (0)