 import os
 import threading
 import time
+import warnings
 from functools import wraps

 import torch
 import torch.nn as nn
 import torch.distributed.autograd as dist_autograd
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.optim as optim
-from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef

 from torchvision.models.resnet import Bottleneck

+# Suppress warnings that can't be fixed from user code
+warnings.filterwarnings("ignore",
+                        message="You are using a Backend .* as a ProcessGroup. This usage is deprecated",
+                        category=UserWarning)
+warnings.filterwarnings("ignore",
+                        message="networkx backend defined more than once: nx-loopback",
+                        category=RuntimeWarning)
+

 #########################################################
 #           Define Model Parallel ResNet50              #
@@ -185,15 +193,34 @@ def parameter_rrefs(self):
 image_h = 128


+def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
+    """Create a torch.compile'd SGD optimizer for this worker's parameters."""
+    params = [p.to_here() for p in param_rrefs]
+    opt = optim.SGD(params, lr=lr)
+    opt.step = torch.compile(opt.step)
+    return opt
+
+
 def run_master(split_size):

     # put the two model parts on worker1 and worker2 respectively
     model = DistResNet50(split_size, ["worker1", "worker2"])
     loss_fn = nn.MSELoss()
-    opt = DistributedOptimizer(
-        optim.SGD,
-        model.parameter_rrefs(),
-        lr=0.05,
+
+    # Get parameter RRefs for each model shard
+    p1_param_rrefs = model.p1_rref.remote().parameter_rrefs().to_here()
+    p2_param_rrefs = model.p2_rref.remote().parameter_rrefs().to_here()
+
+    # Create optimizers on remote workers
+    opt1_rref = rpc.remote(
+        "worker1",
+        create_optimizer_for_remote_params,
+        args=("worker1", p1_param_rrefs)
+    )
+    opt2_rref = rpc.remote(
+        "worker2",
+        create_optimizer_for_remote_params,
+        args=("worker2", p2_param_rrefs)
     )

     one_hot_indices = torch.LongTensor(batch_size) \
@@ -213,7 +240,12 @@ def run_master(split_size):
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
-            opt.step(context_id)
+
+            opt1_rref.rpc_sync().step()
+            opt2_rref.rpc_sync().step()
+
+            opt1_rref.rpc_sync().zero_grad()
+            opt2_rref.rpc_sync().zero_grad()


 def run_worker(rank, world_size, num_split):
@@ -245,6 +277,9 @@ def run_worker(rank, world_size, num_split):


 if __name__ == "__main__":
+    # Suppress torch compile profiler warnings
+    os.environ['TORCH_LOGS'] = '-dynamo'
+
     world_size = 3
     for num_split in [1, 2, 4, 8]:
         tik = time.time()
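
A note on the RRef helper used for the new optimizer calls: rref.rpc_sync() returns a proxy object, so opt1_rref.rpc_sync().step() runs step() on the worker that owns the optimizer and blocks until it finishes, while rref.remote() does the same but hands back another RRef (as in model.p1_rref.remote().parameter_rrefs()). The following is a minimal, self-contained sketch of that proxy pattern; the Counter class, worker names, and port are illustrative only and not part of this commit.

import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


class Counter:
    """Toy remote object; it lives on whichever worker creates it via rpc.remote()."""

    def __init__(self):
        self.value = 0

    def step(self):
        self.value += 1
        return self.value


def run(rank, world_size):
    # Illustrative rendezvous settings (assumed values, adjust as needed).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # Instantiate Counter on worker1; rank 0 only holds an RRef to it.
        counter_rref = rpc.remote("worker1", Counter)
        # rpc_sync() builds a proxy: step() executes on worker1 and the call
        # blocks for the result, just like opt1_rref.rpc_sync().step() above.
        print(counter_rref.rpc_sync().step())  # 1
        print(counter_rref.rpc_sync().step())  # 2
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)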