Skip to content

Commit ccd00dc

Browse files
committed
remove function extractCommsInfo() for deprecatd basic trace
1 parent 57a97dc commit ccd00dc

File tree

1 file changed

+12
-63
lines changed

1 file changed

+12
-63
lines changed

et_replay/tools/comm_replay.py

Lines changed: 12 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919

2020
from et_replay.comm import comms_utils
21+
from et_replay.comm import commsTraceParser
2122
from et_replay.comm.backend.base_backend import supportedP2pOps
2223
from et_replay.comm.comms_utils import (
2324
bootstrap_info_holder,
@@ -1547,69 +1548,17 @@ def readTrace(self, remotePath: str, rank: int) -> None:
15471548
self.readRawTrace(remotePath=remotePath, rank=rank)
15481549

15491550
# Convert trace to comms trace.
1550-
try:
1551-
from et_replay.comm import commsTraceParser
1552-
except ImportError:
1553-
logger.info("FB internals not present, using base parser.")
1554-
self.comms_trace = extractCommsInfo(self.comms_trace)
1555-
else:
1556-
self.comms_trace = commsTraceParser.parseTrace(
1557-
self.comms_trace,
1558-
self.trace_type,
1559-
(
1560-
self.trace_file
1561-
if not os.path.isdir(self.trace_file)
1562-
else f"{self.trace_file}/rank-{rank}.json"
1563-
),
1564-
rank,
1565-
self.backendFuncs.get_world_size(),
1566-
)
1567-
1568-
1569-
def extractCommsInfo(in_trace: List[Dict]) -> List[commsArgs]:
1570-
"""
1571-
Convert Basic Trace to comms trace format.
1572-
"""
1573-
# print("in extract comms info")
1574-
# exit(1)
1575-
newCommsTrace = []
1576-
for cnt, curComm in enumerate(in_trace):
1577-
newComm = commsArgs()
1578-
newComm.comms = paramToCommName(curComm["comms"].lower())
1579-
logger.info(f"in extract comms info of {newComm.comms}: {curComm}")
1580-
newComm.id = cnt
1581-
if "req" in curComm:
1582-
newComm.req = curComm["req"]
1583-
if "startTime_ns" in curComm:
1584-
newComm.startTimeNs = curComm["startTime_ns"]
1585-
if "markers" in curComm:
1586-
newComm.markerStack = curComm["markers"]
1587-
if "world_size" in curComm:
1588-
newComm.worldSize = curComm["world_size"]
1589-
if "root" in curComm:
1590-
newComm.root = curComm["root"]
1591-
if "pg_id" in curComm:
1592-
newComm.pgId = curComm["pg_id"]
1593-
if "global_ranks" in curComm:
1594-
newComm.groupRanks = curComm["global_ranks"]
1595-
1596-
if newComm.comms not in ("wait", "barrier", "init"):
1597-
newComm.inMsgSize = curComm["in_msg_size"]
1598-
newComm.outMsgSize = curComm["out_msg_size"]
1599-
newComm.dtype = curComm["dtype"]
1600-
1601-
if newComm.comms in ("all_to_allv"):
1602-
newComm.inSplit = curComm["in_split"]
1603-
newComm.outSplit = curComm["out_split"]
1604-
1605-
if newComm.comms in supportedP2pOps:
1606-
newComm.src_rank = curComm["src_rank"]
1607-
newComm.dst_rank = curComm["dst_rank"]
1608-
newComm.batch_p2p = curComm["use_batch"]
1609-
1610-
newCommsTrace.append(newComm)
1611-
1612-
return newCommsTrace
1551+
self.comms_trace = commsTraceParser.parseTrace(
1552+
self.comms_trace,
1553+
self.trace_type,
1554+
(
1555+
self.trace_file
1556+
if not os.path.isdir(self.trace_file)
1557+
else f"{self.trace_file}/rank-{rank}.json"
1558+
),
1559+
rank,
1560+
self.backendFuncs.get_world_size(),
1561+
)
16131562

16141563

16151564
def main() -> None:

0 commit comments

Comments
 (0)