|
18 | 18 | import torch |
19 | 19 |
|
20 | 20 | from et_replay.comm import comms_utils |
| 21 | +from et_replay.comm import commsTraceParser |
21 | 22 | from et_replay.comm.backend.base_backend import supportedP2pOps |
22 | 23 | from et_replay.comm.comms_utils import ( |
23 | 24 | bootstrap_info_holder, |
@@ -1547,69 +1548,17 @@ def readTrace(self, remotePath: str, rank: int) -> None: |
1547 | 1548 | self.readRawTrace(remotePath=remotePath, rank=rank) |
1548 | 1549 |
|
1549 | 1550 | # Convert trace to comms trace. |
1550 | | - try: |
1551 | | - from et_replay.comm import commsTraceParser |
1552 | | - except ImportError: |
1553 | | - logger.info("FB internals not present, using base parser.") |
1554 | | - self.comms_trace = extractCommsInfo(self.comms_trace) |
1555 | | - else: |
1556 | | - self.comms_trace = commsTraceParser.parseTrace( |
1557 | | - self.comms_trace, |
1558 | | - self.trace_type, |
1559 | | - ( |
1560 | | - self.trace_file |
1561 | | - if not os.path.isdir(self.trace_file) |
1562 | | - else f"{self.trace_file}/rank-{rank}.json" |
1563 | | - ), |
1564 | | - rank, |
1565 | | - self.backendFuncs.get_world_size(), |
1566 | | - ) |
1567 | | - |
1568 | | - |
def extractCommsInfo(in_trace: List[Dict]) -> List[commsArgs]:
    """
    Convert Basic Trace to comms trace format.

    Args:
        in_trace: trace entries as dicts. Every entry must have a "comms"
            key. Depending on the resolved op name, further keys are
            required (see Raises).

    Returns:
        One commsArgs per input entry, in input order, with ``id`` set to the
        entry's index in ``in_trace``.

    Raises:
        KeyError: if an entry lacks "comms"; if an op other than
            wait/barrier/init lacks "in_msg_size", "out_msg_size" or "dtype";
            if an all_to_allv entry lacks "in_split" or "out_split"; or if an
            op listed in supportedP2pOps lacks "src_rank", "dst_rank" or
            "use_batch".
    """
    # Optional trace keys and the commsArgs attribute each is copied to when
    # present in the entry.
    optional_fields = (
        ("req", "req"),
        ("startTime_ns", "startTimeNs"),
        ("markers", "markerStack"),
        ("world_size", "worldSize"),
        ("root", "root"),
        ("pg_id", "pgId"),
        ("global_ranks", "groupRanks"),
    )

    newCommsTrace = []
    for cnt, curComm in enumerate(in_trace):
        newComm = commsArgs()
        newComm.comms = paramToCommName(curComm["comms"].lower())
        logger.info(f"in extract comms info of {newComm.comms}: {curComm}")
        newComm.id = cnt

        for key, attr in optional_fields:
            if key in curComm:
                setattr(newComm, attr, curComm[key])

        # Every op except these three carries message sizes and a dtype.
        if newComm.comms not in ("wait", "barrier", "init"):
            newComm.inMsgSize = curComm["in_msg_size"]
            newComm.outMsgSize = curComm["out_msg_size"]
            newComm.dtype = curComm["dtype"]

        # Exact comparison on purpose: the previous `in ("all_to_allv")` had
        # no trailing comma, so it was a substring test against a plain string
        # and also matched names such as "all_to_all".
        if newComm.comms == "all_to_allv":
            newComm.inSplit = curComm["in_split"]
            newComm.outSplit = curComm["out_split"]

        if newComm.comms in supportedP2pOps:
            newComm.src_rank = curComm["src_rank"]
            newComm.dst_rank = curComm["dst_rank"]
            newComm.batch_p2p = curComm["use_batch"]

        newCommsTrace.append(newComm)

    return newCommsTrace
| 1551 | + self.comms_trace = commsTraceParser.parseTrace( |
| 1552 | + self.comms_trace, |
| 1553 | + self.trace_type, |
| 1554 | + ( |
| 1555 | + self.trace_file |
| 1556 | + if not os.path.isdir(self.trace_file) |
| 1557 | + else f"{self.trace_file}/rank-{rank}.json" |
| 1558 | + ), |
| 1559 | + rank, |
| 1560 | + self.backendFuncs.get_world_size(), |
| 1561 | + ) |
1613 | 1562 |
|
1614 | 1563 |
|
1615 | 1564 | def main() -> None: |
|
0 commit comments