Commit fa719e8

[syncBN] Replacing new_group with torch.distributed.group.WORLD avoids creating a new process group in every iteration. This should resolve the issue "Training gets stuck when using SyncBN" (pytorch#105).

1 parent 241dd6c commit fa719e8
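
Why the change matters: torch.distributed.new_group() is itself a collective call that every rank must enter, so calling it inside forward() on every iteration is expensive and, as the linked issue reports, can hang training. torch.distributed.group.WORLD is just a handle to the default group created by init_process_group, so reusing it is free. A minimal sketch of the fallback pattern this commit introduces (the helper name resolve_group is hypothetical, not apex code):

    import torch.distributed as dist

    def resolve_group(process_group=None):
        # Fall back to the default (WORLD) group instead of creating a fresh
        # group on every call; new_group() would have to be entered by all ranks.
        if not process_group:
            process_group = dist.group.WORLD
        world_size = dist.get_world_size(process_group)
        return process_group, world_size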

File tree

4 files changed: +9 -22 lines changed

apex/parallel/__init__.py

Lines changed: 0 additions & 9 deletions
@@ -1,14 +1,5 @@
 import torch
 
-# Backward compatibility hack around
-# https://github.com/pytorch/pytorch/pull/14767
-if hasattr(torch.distributed, 'get_default_group'):
-    group_creator = torch.distributed.get_default_group
-elif hasattr(torch.distributed, 'new_group'):
-    group_creator = torch.distributed.new_group
-else:
-    group_creator = torch.distributed.deprecated.new_group
-
 if hasattr(torch.distributed, 'ReduceOp'):
     ReduceOp = torch.distributed.ReduceOp
 elif hasattr(torch.distributed, 'reduce_op'):

apex/parallel/optimized_sync_batchnorm_kernel.py

Lines changed: 4 additions & 6 deletions
@@ -2,7 +2,7 @@
 from torch.autograd.function import Function
 
 import syncbn
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp
 
 class SyncBatchnormFunction(Function):
 
@@ -16,11 +16,9 @@ def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track
         mean, var, var_biased = syncbn.welford_mean_var(input)
 
         if torch.distributed.is_initialized():
-            if process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
             var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
             mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
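
For context, the preallocated buffer and narrow() views set up at the end of this hunk are the per-rank slots that the local statistics are gathered into; the actual all_gather call lies outside the shown context. A rough sketch of that pattern, assuming torch.distributed.all_gather is used on the resolved group (gather_mean below is a hypothetical helper, not apex code):

    import torch
    import torch.distributed as dist

    def gather_mean(mean, process_group=None):
        # Resolve the group the same way the patched forward() does.
        if not process_group:
            process_group = dist.group.WORLD
        world_size = dist.get_world_size(process_group)
        # One (1, C) slot per rank, all viewing a single (world_size, C) buffer.
        mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
        mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
        dist.all_gather(mean_l, mean.unsqueeze(0), group=process_group)
        return mean_all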

apex/parallel/sync_batchnorm.py

Lines changed: 4 additions & 6 deletions
@@ -3,7 +3,7 @@
 from torch.nn import functional as F
 
 from .sync_batchnorm_kernel import SyncBatchnormFunction
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp
 
 
 class SyncBatchNorm(_BatchNorm):
 
@@ -63,11 +63,9 @@ def forward(self, input):
         else:
             process_group = self.process_group
             world_size = 0
-            if self.process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not self.process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             self.num_batches_tracked += 1
             with torch.no_grad():
                 channel_first_input = input.transpose(0, 1).contiguous()
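
From the user's side nothing changes: constructing the layer without a process_group now falls back to the default WORLD group instead of creating a new group on each forward pass. A usage sketch, assuming the usual one-process-per-GPU NCCL setup and that apex.parallel exports SyncBatchNorm:

    import torch
    import torch.distributed as dist
    from apex.parallel import SyncBatchNorm

    dist.init_process_group(backend="nccl")   # creates the default (WORLD) group
    bn = SyncBatchNorm(64).cuda()             # process_group=None -> torch.distributed.group.WORLD
    out = bn(torch.randn(8, 64, 32, 32, device="cuda"))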

apex/parallel/sync_batchnorm_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 from torch.autograd.function import Function
 
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp
 
 
 class SyncBatchnormFunction(Function):
