Skip to content

Commit cdfb492

Browse files
authored
Merge pull request ggml-org#37 from awgu/pt2
Have DDP ignore `freqs_cis` to avoid broadcast
2 parents 9055766 + 25494f9 commit cdfb492

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@
191191

192192
# wrap model into DDP container
193193
if ddp:
194+
# Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
195+
# construction time since NCCL does not support `ComplexFloat`
196+
prefix = "_orig_mod." if compile else ""
197+
model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
194198
model = DDP(model, device_ids=[ddp_local_rank])
195199

196200
# helps estimate an arbitrarily accurate loss over either split using many batches

0 commit comments

Comments
 (0)