# p2p.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import io
import pickle
import re
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten
from .stage_manager import PipelineStageManager
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: Union[int, torch.Tensor]) -> Any:
    """Unpickle an object from a byte tensor, retargeting CUDA device tags first.

    Before unpickling, every ``cuda:<d>`` device tag (single decimal digit) found in the
    byte stream is rewritten in place to point at the current CUDA device.

    Args:
        tensor (:class:`torch.Tensor`): CPU ``uint8`` tensor holding the pickled bytes
            (``tensor.numpy()`` is called on it).
        tensor_size (int or 0-dim integer tensor): number of leading bytes that hold the
            real payload; anything after it is ignored.

    Returns:
        Any: the unpickled object.
    """
    buf = tensor.numpy().tobytes()[:tensor_size]
    if b"cuda" in buf:
        # Only genuine single-digit device tags ("cuda:3") are patched. A bare "cuda"
        # substring (e.g. inside an unrelated string) or a multi-digit ordinal must be
        # left alone: overwriting a byte there would corrupt the pickle stream.
        device_tags = list(re.finditer(rb"cuda:(\d)(?!\d)", buf))
        if device_tags:
            device_index = torch.cuda.current_device()
            # The patch replaces exactly one byte, so it can only express ordinals 0-9;
            # pickle length prefixes forbid growing the tag in place.
            if 0 <= device_index <= 9:
                buf_array = bytearray(buf)
                digit = ord("0") + device_index
                # There might be more than one tensor (hence more than one tag) in the stream.
                for tag in device_tags:
                    buf_array[tag.start(1)] = digit
                buf = bytes(buf_array)
    # NOTE(security): pickle executes arbitrary code while loading. This helper is only
    # meant for bytes received from peer ranks of the same job; never feed it external data.
    return pickle.Unpickler(io.BytesIO(buf)).load()
def check_for_nccl_backend(group):
    """Tell whether ``group`` (or the default process group) runs on the NCCL backend.

    Args:
        group: process group to inspect; a falsy value selects the default group.

    Returns:
        bool: True only if NCCL is available and the unwrapped group reports the NCCL backend.
    """
    backend_pg = group or c10d._get_default_group()
    # ``_ProcessGroupWrapper`` is only referenced when Gloo is available, so the unwrapping
    # is gated on that flag. Nested wrappers are not expected but are peeled off anyway.
    if c10d._GLOO_AVAILABLE:
        while isinstance(backend_pg, c10d._ProcessGroupWrapper):
            backend_pg = backend_pg.wrapped_pg
    if not c10d.is_nccl_available():
        return False
    return backend_pg.name() == c10d.Backend.NCCL
# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
def _broadcast_object_list(
    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
):
    """Broadcast a list of picklable objects from rank ``src`` to the other ranks of ``group``.

    This is a modified version of ``broadcast_object_list`` in ``torch.distributed``. The only
    difference is that a received object is moved to the correct device after being unpickled.
    On rank ``src`` the objects in ``object_list`` are sent; on every other rank
    ``object_list`` is overwritten in place with the objects received from ``src``.

    Args:
        object_list (List[Any]): objects to broadcast (src) / placeholders to fill (others).
        src (int): source rank; it is compared against ``dist.get_rank()``.
        group (ProcessGroup): process group to broadcast in.
        device (torch.device, str or int, optional): device used for the broadcast. Defaults
            to the current CUDA device for an NCCL group and to CPU otherwise.

    Raises:
        ValueError: if ``device`` is not a CUDA device while ``group`` uses NCCL.
    """
    if c10d._rank_not_in_group(group):
        c10d._warn_not_in_group("broadcast_object_list")
        return

    is_nccl_backend = _check_for_nccl_backend(group)

    if device is not None:
        # The signature allows str/int; normalize so that ``.type`` is always available.
        current_device = torch.device(device)
        if is_nccl_backend and current_device.type != "cuda":
            raise ValueError("device type must be cuda for nccl backend")
    elif is_nccl_backend:
        current_device = torch.device("cuda", torch.cuda.current_device())
    else:
        current_device = torch.device("cpu")

    my_rank = dist.get_rank()
    # Serialize object_list elements to tensors on src rank.
    if my_rank == src:
        # ``c10d._object_to_tensor`` is private and its signature changed across releases.
        torch_version = Version(torch.__version__)
        if torch_version >= Version("2.3.0"):
            tensor_list, size_list = zip(
                *[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
            )
        elif torch_version >= Version("1.13.0"):
            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
        else:
            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
        object_sizes_tensor = torch.cat(size_list)
    else:
        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)

    if is_nccl_backend:
        object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
    c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)

    # Concatenate and broadcast serialized object tensors
    if my_rank == src:
        object_tensor = torch.cat(tensor_list)
    else:
        object_tensor = torch.empty(  # type: ignore[call-overload]
            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
        )

    if is_nccl_backend:
        object_tensor = object_tensor.to(current_device)

    c10d.broadcast(object_tensor, src=src, group=group, async_op=False)

    # Deserialize objects using their stored sizes (receivers only).
    if my_rank != src:
        offset = 0
        # Pull the sizes to host once so slicing uses plain ints instead of 0-dim tensors.
        for i, obj_size in enumerate(object_sizes_tensor.tolist()):
            obj_view = object_tensor[offset : offset + obj_size]
            obj_view = obj_view.type(torch.uint8)
            if obj_view.device != torch.device("cpu"):
                obj_view = obj_view.cpu()
            offset += obj_size
            # unpickle
            unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)

            # The unpickled tensor may sit on another device index; move it to the current one.
            if (
                isinstance(unpickle_object, torch.Tensor)
                and unpickle_object.device.index != torch.cuda.current_device()
            ):
                unpickle_object = unpickle_object.cuda()

            object_list[i] = unpickle_object
def _check_for_nccl_backend(group):
    """Return whether ``group`` (default group if falsy) is backed by NCCL.

    The group is unwrapped from any ``_ProcessGroupWrapper`` layers before its backend
    name is compared with ``Backend.NCCL``.
    """
    target_pg = group or c10d._get_default_group()
    # Gate the wrapper check on Gloo availability; support repeated wrapping just in case.
    if c10d._GLOO_AVAILABLE:
        while isinstance(target_pg, c10d._ProcessGroupWrapper):
            target_pg = target_pg.wrapped_pg
    if c10d.is_nccl_available():
        return target_pg.name() == c10d.Backend.NCCL
    return False
def _check_for_hccl_backend(group):
    """Return whether ``group`` (default group if falsy) is backed by HCCL.

    ``torch.distributed.is_hccl_available`` and ``Backend.HCCL`` are not looked up
    unconditionally: when the running PyTorch does not expose them, the answer is
    simply ``False`` instead of an ``AttributeError``.

    Args:
        group: process group to inspect; a falsy value selects the default group.

    Returns:
        bool: True only if HCCL is reported available and the unwrapped group names HCCL
        as its backend.
    """
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
        # It is not expected for PG to be wrapped many times, but support it just in case
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg
    is_hccl_available = getattr(torch.distributed, "is_hccl_available", None)
    if is_hccl_available is None or not is_hccl_available():
        return False
    # Fall back to the literal backend name if the enum member has not been registered.
    return pg.name() == getattr(c10d.Backend, "HCCL", "hccl")
def _check_device(group):
    """Pick the device matching the backend of ``group``.

    Returns:
        Tuple[torch.device, bool]: the device tensors must live on for communication in
        ``group`` (current CUDA device for NCCL, an ``npu`` device for HCCL, CPU otherwise),
        and a flag that is True when the backend is NCCL or HCCL.
    """
    uses_nccl = _check_for_nccl_backend(group)
    uses_hccl = _check_for_hccl_backend(group)
    if uses_nccl:
        device = torch.device("cuda", torch.cuda.current_device())
    elif uses_hccl:
        # NOTE(review): the NPU ordinal is taken from ``torch.cuda.current_device()``;
        # presumably torch.cuda is redirected to the NPU runtime here. TODO confirm.
        device = torch.device("npu", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    return device, uses_nccl or uses_hccl
# Everything needed to allocate a receive buffer for one tensor.
TensorMetadata = namedtuple("TensorMetadata", "shape dtype requires_grad")
# Description of a flattened object: its pytree spec, the metadata of its tensor leaves,
# and the non-tensor leaves together with their positions in the flattened list.
P2PMetadata = namedtuple("P2PMetadata", "tree_spec tensor_metadata non_tensor_obj_idx non_tensor_objs")
def create_send_metadata(
    object: Any, strict: bool = True, return_tensor: bool = False
) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
    """Build the :class:`P2PMetadata` describing ``object`` for a later send.

    The object is flattened with ``tree_flatten``; tensor leaves are described by a
    :class:`TensorMetadata` each, all other leaves are stored verbatim together with
    their index in the flattened list.

    Args:
        object (Any): object needed to be sent
        strict (bool, optional): if True, assert that the object holds tensors only
            (the precondition for fast send)
        return_tensor (bool, optional): whether to also return the tensor leaves

    Returns:
        The metadata, or ``(metadata, tensor_leaves)`` when ``return_tensor`` is True.
    """
    leaves, tree_spec = tree_flatten(object)

    tensor_objs = [leaf for leaf in leaves if isinstance(leaf, torch.Tensor)]
    tensor_metadata = [TensorMetadata(t.shape, t.dtype, t.requires_grad) for t in tensor_objs]

    other_leaves = [(pos, leaf) for pos, leaf in enumerate(leaves) if not isinstance(leaf, torch.Tensor)]
    non_tensor_obj_idx = [pos for pos, _ in other_leaves]
    non_tensor_objs = [leaf for _, leaf in other_leaves]

    assert not strict or len(non_tensor_objs) == 0, "Only support tensor for fast send"

    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
    if return_tensor:
        return metadata, tensor_objs
    return metadata
def _filling_ops_queue(
    obj: Union[torch.Tensor, List[torch.Tensor]],
    comm_op: Callable,
    comm_rank: int,
    ops_queue: List,
    group: ProcessGroup,
):
    """Append one ``dist.P2POp`` per tensor in ``obj`` to ``ops_queue``.

    Args:
        obj: a single tensor or an iterable of tensors to communicate.
        comm_op: the point-to-point primitive handed to ``dist.P2POp``.
        comm_rank: peer rank of the operation.
        ops_queue: list that receives the new ops (mutated in place).
        group: process group the operation runs in.
    """
    if not isinstance(obj, torch.Tensor):
        for item in obj:
            assert isinstance(item, torch.Tensor)
            _filling_ops_queue(item, comm_op, comm_rank, ops_queue, group)
        return

    # The op is built on a contiguous version of the tensor.
    ops_queue.append(dist.P2POp(comm_op, obj.contiguous(), comm_rank, group))
def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
    """Allocate one uninitialized tensor per metadata entry on ``current_device``.

    Each buffer takes its shape, dtype and ``requires_grad`` flag from the matching entry.
    """
    return [
        torch.empty(meta.shape, requires_grad=meta.requires_grad, device=current_device, dtype=meta.dtype)
        for meta in tensor_metadata
    ]
def _batch_send_recv_tensor(
    send_tensor_list: Optional[List[torch.Tensor]],
    recv_tensor_metadata: Optional[List[TensorMetadata]],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
    recv_group: Optional[ProcessGroup],
    current_device: Any,
    overlap_p2p: bool = True,
    send_first: bool = True,
) -> Tuple[Optional[List[torch.Tensor]], List]:
    """Send and/or receive tensor payloads in one ``dist.batch_isend_irecv`` call.

    Args:
        send_tensor_list: tensors to send, or None to skip sending.
        recv_tensor_metadata: metadata of the tensors to receive, or None to skip receiving.
        send_dst: destination rank of the send (None disables sending).
        recv_src: source rank of the receive (None disables receiving).
        send_group: process group of the send; required when sending.
        recv_group: process group of the receive; required when receiving.
        current_device: device the receive buffers are allocated on.
        overlap_p2p: if True, return the outstanding requests instead of waiting on them.
        send_first: whether the send ops are queued before the receive ops.

    Returns:
        ``(buffers, handles)``. ``buffers`` is the list of receive buffers, or None when
        nothing was queued at all. ``handles`` holds the pending requests when
        ``overlap_p2p`` is True and is empty otherwise.
    """
    buffer_recv = None
    if recv_tensor_metadata is not None:
        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)

    is_send = send_dst is not None and send_tensor_list is not None
    is_recv = recv_src is not None and buffer_recv is not None

    # Describe both directions once, then order them as requested.
    plan = []
    if is_send:
        plan.append((send_tensor_list, dist.isend, send_dst, send_group))
    if is_recv:
        plan.append((buffer_recv, dist.irecv, recv_src, recv_group))
    if not send_first:
        plan.reverse()

    ops = []
    for tensors, comm_op, peer, comm_group in plan:
        assert comm_group is not None
        _filling_ops_queue(tensors, comm_op, peer, ops, comm_group)

    if not ops:
        # Nothing was queued (this includes empty tensor lists): report no buffers.
        return None, []

    reqs = dist.batch_isend_irecv(ops)
    if overlap_p2p:
        # The caller is responsible for waiting on the returned requests.
        return buffer_recv, reqs
    for req in reqs:
        req.wait()
    return buffer_recv, []
def _send_recv_serialization_object(
    object: Optional[P2PMetadata],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
    recv_group: Optional[ProcessGroup],
    current_device: Any,
    is_nccl_backend: bool,
    send_first: bool = True,
) -> Optional[P2PMetadata]:
    """Exchange pickled objects with peer ranks in two blocking rounds (size, then payload).

    Args:
        object: object to send; nothing is sent when it or ``send_dst`` is None.
        send_dst: destination rank of the send.
        recv_src: source rank of the receive; nothing is received when it is None.
        send_group: process group of the send.
        recv_group: process group of the receive.
        current_device: device the communication tensors are moved to.
        is_nccl_backend: whether the tensors must be moved to ``current_device``.
        send_first: whether the send op is queued before the receive op in each round.

    Returns:
        The unpickled object received from ``recv_src``, or None if nothing was received.
    """

    def _exchange(outgoing, incoming):
        # Queue whichever directions are active in the requested order, then block on them.
        steps = []
        if outgoing is not None:
            steps.append((outgoing, dist.isend, send_dst, send_group))
        if incoming is not None:
            steps.append((incoming, dist.irecv, recv_src, recv_group))
        if not send_first:
            steps.reverse()
        pending = []
        for tensor, comm_op, peer, comm_group in steps:
            _filling_ops_queue(tensor, comm_op, peer, pending, comm_group)
        if len(pending) > 0:
            for req in dist.batch_isend_irecv(pending):
                req.wait()  # This blocks the compute stream in torch

    # Serialize the outgoing object; the private helper's signature depends on the torch version.
    payload, payload_size = None, None
    if object is not None and send_dst is not None:
        if Version(torch.__version__) >= Version("2.3.0"):
            payload, payload_size = c10d._object_to_tensor(object, device=current_device, group=send_group)
        elif Version(torch.__version__) >= Version("1.13.0"):
            payload, payload_size = c10d._object_to_tensor(object, device=current_device)
        else:
            payload, payload_size = c10d._object_to_tensor(object)

        if is_nccl_backend:
            payload_size = payload_size.to(current_device)
            payload = payload.to(current_device)

    incoming_size = None
    if recv_src is not None:
        incoming_size = torch.empty(1, dtype=torch.long)
        if is_nccl_backend:
            incoming_size = incoming_size.to(current_device)

    # Round 1: byte counts.
    _exchange(payload_size, incoming_size)

    incoming = None
    if incoming_size is not None:
        incoming = torch.empty(incoming_size.item(), dtype=torch.uint8)
        if is_nccl_backend:
            incoming = incoming.to(current_device)

    # Round 2: pickled bytes.
    _exchange(payload, incoming)

    if incoming is None:
        return None

    incoming = incoming.type(torch.uint8)
    if incoming.device != torch.device("cpu"):
        incoming = incoming.cpu()

    received = _cuda_safe_tensor_to_object(incoming, incoming_size.item())

    # A bare tensor that landed on another device index is moved to the current one.
    if isinstance(received, torch.Tensor) and received.device.index != torch.cuda.current_device():
        received = received.cuda()

    return received
def _communicate(
    object: Any,
    send_dst: Optional[int],
    recv_src: Optional[int],
    overlap_p2p: bool,
    send_group: Optional[ProcessGroup] = None,
    recv_group: Optional[ProcessGroup] = None,
    send_metadata: bool = True,
    metadata_recv: Optional[P2PMetadata] = None,
    send_first: Optional[bool] = None,
) -> Tuple[Any, List]:
    """
    Send and receive object from send_dst and recv_src respectively

    Args:
        object (Any): object needed to be sent
        send_dst (int): rank of the destination
        recv_src (int): rank of the source
        overlap_p2p (bool): whether to overlap p2p communication with computation
        send_group (ProcessGroup, optional): process group of sender
        recv_group (ProcessGroup, optional): process group of receiver
        send_metadata (bool, optional): whether to send metadata
        metadata_recv (P2PMetadata, optional): metadata of the object to be received
        send_first (bool, optional): whether send ops are queued before receive ops;
            None is treated as True

    Returns:
        Tuple[Any, List]: the received object (None when nothing is received) and the list
        of outstanding communication handles returned by the tensor exchange.
    """
    assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
    assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
    assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
    assert (
        metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
    ), "metadata_recv should not contain non-tensor objects"

    # Resolve the optional ordering flag once; both exchanges below share it.
    send_first = True if send_first is None else send_first

    metadata_send, tensor_objs = None, None
    if object is not None:
        # NOTE: if object contains non-tensor objects, we have to send metadata
        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
    else:
        send_metadata = False

    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
    current_send_device, is_send_nccl_backend = _check_device(send_group)
    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)

    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend

    assert current_send_device == current_recv_device
    current_device = current_send_device

    if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
        # Send and receive metadata
        _metadata_recv = _send_recv_serialization_object(
            object=metadata_send,
            send_dst=send_dst if send_metadata else None,
            recv_src=recv_src if metadata_recv is None else None,
            send_group=send_group if send_metadata else None,
            recv_group=recv_group if metadata_recv is None else None,
            current_device=current_device,
            is_nccl_backend=is_nccl_backend,
            send_first=send_first,
        )
        assert (
            metadata_recv is None or _metadata_recv is None
        ), "You shouldn't receive metadata when using the cached metadata"
        metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv

    # Send and receive data
    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
    recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
        tensor_objs,
        recv_tensor_metadata,
        send_dst,
        recv_src,
        send_group,
        recv_group,
        current_device,
        overlap_p2p=overlap_p2p,
        send_first=send_first,
    )

    if metadata_recv is None:
        return None, wait_handles

    assert isinstance(metadata_recv, P2PMetadata)

    if recv_tensor_objs is None:
        recv_tensor_objs = []

    # Re-insert the non-tensor leaves at their original flattened positions. Iterating
    # (instead of popping from the front) leaves the received metadata untouched.
    for idx, non_tensor_obj in zip(metadata_recv.non_tensor_obj_idx, metadata_recv.non_tensor_objs):
        recv_tensor_objs.insert(idx, non_tensor_obj)
    recv_object = tree_unflatten(recv_tensor_objs, metadata_recv.tree_spec)
    return recv_object, wait_handles
def _p2p_comm(
    tensor_send_next: torch.Tensor,
    recv_prev: bool,
    peer: int,
    group: ProcessGroup,
    comm_dtype: torch.dtype = torch.float16,
):
    """
    Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.

    The exchange runs in two blocking rounds with the same ``peer``: first the shapes,
    then the tensors themselves.

    Args:
        tensor_send_next (torch.Tensor): tensor to be sent to the peer; None disables sending
        recv_prev (bool): whether to receive a tensor from the peer
        peer (int): rank of the peer
        group (ProcessGroup): process group
        comm_dtype (torch.dtype): dtype of the receive buffer

    Returns:
        torch.Tensor: the received tensor, or None when ``recv_prev`` is falsy
    """

    def _exchange(outgoing, incoming):
        # Batch the optional isend/irecv pair (send queued first) and wait for completion.
        batch = []
        if outgoing is not None:
            batch.append(dist.P2POp(dist.isend, outgoing, peer=peer, group=group))
        if incoming is not None:
            batch.append(dist.P2POp(dist.irecv, incoming, peer=peer, group=group))
        if len(batch) > 0:
            for req in dist.batch_isend_irecv(batch):
                req.wait()

    # Round 1: shapes.
    shape_out = None
    if tensor_send_next is not None:
        shape_out = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
    shape_in = None
    if recv_prev:
        # NOTE(review): the shape buffer has exactly 3 elements, so the incoming tensor is
        # presumably always 3-dimensional. TODO confirm with callers.
        shape_in = torch.empty(3, device=torch.cuda.current_device(), dtype=torch.int64)

    _exchange(shape_out, shape_in)

    # Round 2: data.
    tensor_recv_prev = None
    if recv_prev:
        tensor_recv_prev = torch.empty(shape_in.tolist(), device=torch.cuda.current_device(), dtype=comm_dtype)

    _exchange(tensor_send_next, tensor_recv_prev)
    return tensor_recv_prev
class PipelineP2PCommunication:
    """Point-to-point send/receive helpers between neighbouring pipeline stages.

    Peer ranks and the process group are obtained from the stage manager; the actual
    communication is delegated to ``_communicate`` / ``_p2p_comm``.
    """

    def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
        # Source of prev/next ranks and of the p2p process group.
        self.stage_manager = stage_manager
        # Forwarded to ``_communicate`` by the one-directional and same-direction calls.
        self.overlap_p2p = overlap_p2p

    def recv_forward(
        self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
    ) -> Tuple[Any, List]:
        """Receive the forward output of the previous stage as the input of this stage.

        Args:
            prev_rank (int, optional): The rank to receive from; defaults to the previous stage.
            metadata_recv (P2PMetadata, optional): Cached metadata of the object to be received.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        src = self.stage_manager.get_prev_rank() if prev_rank is None else prev_rank
        received, wait_handles = _communicate(
            object=None,
            send_dst=None,
            recv_src=src,
            recv_group=self.stage_manager.get_p2p_process_group(),
            metadata_recv=metadata_recv,
            overlap_p2p=self.overlap_p2p,
        )
        return received, wait_handles

    def recv_backward(
        self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
    ) -> Tuple[Any, List]:
        """Receive the gradient from the next stage as the output gradient of this stage.

        Args:
            next_rank (int, optional): The rank to receive from; defaults to the next stage.
            metadata_recv (P2PMetadata, optional): Cached metadata of the object to be received.

        Returns:
            Any: The received gradient tensor or tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        src = self.stage_manager.get_next_rank() if next_rank is None else next_rank
        received, wait_handles = _communicate(
            object=None,
            send_dst=None,
            recv_src=src,
            recv_group=self.stage_manager.get_p2p_process_group(),
            metadata_recv=metadata_recv,
            overlap_p2p=self.overlap_p2p,
        )
        return received, wait_handles

    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
        """Send the forward output of this stage to the next stage.

        Args:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient; defaults to the next stage.
            send_metadata (bool, optional): Whether to send metadata.

        Returns:
            List: List of handles for the communication requests, if overlap is enabled.
        """
        dst = self.stage_manager.get_next_rank() if next_rank is None else next_rank
        _, wait_handles = _communicate(
            object=output_object,
            send_dst=dst,
            recv_src=None,
            send_group=self.stage_manager.get_p2p_process_group(),
            send_metadata=send_metadata,
            overlap_p2p=self.overlap_p2p,
        )
        return wait_handles

    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
        """Send the gradient of this stage's input to the previous stage.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient; defaults to the previous stage.
            send_metadata (bool, optional): Whether to send metadata.

        Returns:
            List: List of handles for the communication requests, if overlap is enabled.
        """
        dst = self.stage_manager.get_prev_rank() if prev_rank is None else prev_rank
        _, wait_handles = _communicate(
            object=input_object,
            send_dst=dst,
            recv_src=None,
            send_group=self.stage_manager.get_p2p_process_group(),
            send_metadata=send_metadata,
            overlap_p2p=self.overlap_p2p,
        )
        return wait_handles

    def send_forward_recv_forward(
        self,
        output_object: Any,
        is_send: bool,
        is_recv: bool,
        send_first: bool,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
    ) -> Tuple[Any, List]:
        """Send the forward output to the next stage and receive the forward input from the previous stage.

        Args:
            output_object (Any): Object to be sent.
            is_send (bool): Whether to send to the next pipeline stage.
            is_recv (bool): Whether to receive from the previous pipeline stage.
            send_first (bool): Whether to send before receive.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        dst = self.stage_manager.get_next_rank() if is_send else None
        src = self.stage_manager.get_prev_rank() if is_recv else None
        p2p_group = self.stage_manager.get_p2p_process_group()
        # Each direction that is switched off gets neither a group nor metadata.
        return _communicate(
            object=output_object,
            send_dst=dst,
            recv_src=src,
            send_group=p2p_group if is_send else None,
            recv_group=p2p_group if is_recv else None,
            send_metadata=send_metadata if is_send else False,
            metadata_recv=metadata_recv if is_recv else None,
            send_first=send_first,
            overlap_p2p=self.overlap_p2p,
        )

    def send_backward_recv_backward(
        self,
        input_object: Any,
        is_send: bool,
        is_recv: bool,
        send_first: bool,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
    ) -> Tuple[Any, List]:
        """Send the gradient to the previous stage and receive the gradient from the next stage.

        Args:
            input_object (Any): Object to be sent.
            is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
            is_recv (bool): Whether to receive the gradient tensor from the next pipeline stage.
            send_first (bool): Whether to send before receive.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.

        Returns:
            Any: The received gradient tensor or tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        dst = self.stage_manager.get_prev_rank() if is_send else None
        src = self.stage_manager.get_next_rank() if is_recv else None
        p2p_group = self.stage_manager.get_p2p_process_group()
        # Each direction that is switched off gets neither a group nor metadata.
        return _communicate(
            object=input_object,
            send_dst=dst,
            recv_src=src,
            send_group=p2p_group if is_send else None,
            recv_group=p2p_group if is_recv else None,
            send_metadata=send_metadata if is_send else False,
            metadata_recv=metadata_recv if is_recv else None,
            send_first=send_first,
            overlap_p2p=self.overlap_p2p,
        )

    def send_forward_recv_backward(
        self,
        input_object: Any,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_first: Optional[bool] = None,
    ) -> Tuple[Any, List]:
        """Send an object to the next stage and receive one back from that same stage.

        This call always passes ``overlap_p2p=False``, regardless of ``self.overlap_p2p``.

        Args:
            input_object (Any): Object to be sent.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): Cached metadata of the object to be received.
            send_first (bool, optional): Whether to send before receive.

        Returns:
            Any: The received tensor or tensor list.
            List: List of handles for the communication requests.
        """
        peer = self.stage_manager.get_next_rank()
        p2p_group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            object=input_object,
            send_dst=peer,
            recv_src=peer,
            send_group=p2p_group,
            recv_group=p2p_group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_first=send_first,
            overlap_p2p=False,
        )

    def send_backward_recv_forward(
        self,
        input_object: Any,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_first: Optional[bool] = None,
    ) -> Tuple[Any, List]:
        """Send an object to the previous stage and receive one back from that same stage.

        This call always passes ``overlap_p2p=False``, regardless of ``self.overlap_p2p``.

        Args:
            input_object (Any): Object to be sent.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): Cached metadata of the object to be received.
            send_first (bool, optional): Whether to send before receive.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests.
        """
        peer = self.stage_manager.get_prev_rank()
        p2p_group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            object=input_object,
            send_dst=peer,
            recv_src=peer,
            send_group=p2p_group,
            recv_group=p2p_group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_first=send_first,
            overlap_p2p=False,
        )

    def p2p_communicate(
        self,
        output_object: Any,
        recv_pre: bool,
        next_rank: Optional[int] = None,
        comm_dtype: torch.dtype = torch.float16,
    ) -> Any:
        """
        Exchange a tensor with a peer stage through ``_p2p_comm`` (built on torch ``P2POp``).

        Args:
            output_object (Any): Object to be sent.
            recv_pre (bool): Whether to also receive a tensor from the peer.
            next_rank (int, optional): The rank of the peer; defaults to the next stage.
            comm_dtype (torch.dtype): dtype forwarded to ``_p2p_comm``.

        Returns:
            Any: whatever ``_p2p_comm`` returns (the received tensor, if any).
        """
        peer = self.stage_manager.get_next_rank() if next_rank is None else next_rank
        p2p_group = self.stage_manager.get_p2p_process_group()
        return _p2p_comm(output_object, recv_pre, peer, p2p_group, comm_dtype)