
Commit badffa7

Fix init_weights handling for param/buffer assignment (#52)
init_weights is a method a user module can supply to initialize all of the parameters and buffers on the module. Currently we handle init_weights via reparametrization: we replace the original module's state_dict with a version containing parallel_module's states, then run init_weights, mutating those states. But if init_weights does something like self.buf = _init_buf() instead of self.buf.copy_(_init_buf()), we fail to capture the update.

This PR interposes the init_weights function differently:
1. It adds a new deepcopied 'init_weights_model' to AutoParallel, so we can freely mutate its class without affecting the original model.
2. It mutates the class of init_weights_model to
   * add property objects (getter + setter) for each parameter FQN in the module tree, and
   * bypass nn.Module.__setattr__ so the property setters actually run.
3. Each getter returns the corresponding parameter from the parallel module instead of the original module.
4. Each setter additionally wraps 'value' in a new replicated DTensor and copies it into the existing DTensor in the parallel module.

Verification run: `tbm FSDP_eager:torchtitan-64-whc-p3s1bn autop_initweights_eager:torchtitan-64-whc-qthbz6 autop_initweights_eager_rerun:torchtitan-64-whc-d2bddf torchtitan-64-whc-qthbz6`
1 parent cd3ecba commit badffa7
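The class-swap mechanism in steps 2-4 can be illustrated in isolation. The sketch below is not the repository's code: `hook_weight_attr` and the two `nn.Linear` stand-ins are hypothetical, but the trick is the same one `hook_params_setters` uses, namely swapping the module's class for a subclass that carries a property per parameter name and bypasses `nn.Module.__setattr__`, so reads resolve to the parallel module's parameter and assignments get copied into it.

```python
import torch
from torch import nn


def hook_weight_attr(init_copy: nn.Module, parallel_mod: nn.Module, name: str) -> None:
    # install a property for `name` on a throwaway subclass of init_copy's class
    def getter(self):
        # reads resolve to the parallel module's parameter
        return getattr(parallel_mod, name)

    def setter(self, value):
        # assignments like `self.weight = nn.Parameter(...)` land here and are
        # copied into the existing parallel parameter instead of replacing it
        with torch.no_grad():
            getattr(parallel_mod, name).copy_(value)

    cls = init_copy.__class__
    namespace = {
        name: property(getter, setter),
        # nn.Module.__setattr__ would otherwise intercept the assignment before
        # the property setter ever runs
        "__setattr__": object.__setattr__,
    }
    init_copy.__class__ = type(f"Hooked{cls.__name__}", (cls,), namespace)


parallel = nn.Linear(4, 4)   # stands in for the sharded parallel module
init_copy = nn.Linear(4, 4)  # stands in for the deepcopied init_weights_model
hook_weight_attr(init_copy, parallel, "weight")

init_copy.weight = nn.Parameter(torch.ones(4, 4))  # routed into `parallel`
assert torch.equal(parallel.weight, torch.ones(4, 4))
```

In the actual change this is done per parameter/buffer FQN across the whole module tree, and the setter additionally wraps the value in a replicated DTensor before copying it into the sharded one.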

File tree

3 files changed: +183 -13 lines changed

autoparallel/api.py

Lines changed: 20 additions & 13 deletions
@@ -6,6 +6,7 @@
 import copy
 import itertools
 from contextlib import ExitStack
+from types import MethodType
 from typing import Optional
 
 import torch
@@ -22,10 +23,10 @@
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
-from torch.nn.utils import stateless
 
 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
+from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
 from .utils import _get_device_from_mesh
 
@@ -175,6 +176,11 @@ def __init__(
         # in dtype casting and move_to_fake
         model = copy.deepcopy(model)
 
+        # keep a separate copy of the fake orig model to customize for supporting init_weights
+        self.init_weights_model = move_to_fake(
+            copy.deepcopy(model), self.fake_mode, device
+        )
+
         if self.mp_policy is not None:
             apply_dtype_cast(model, self.mp_policy)
 
@@ -429,6 +435,9 @@ def forward(self, *args):
 
         self.parallel_model = AutoParallelModule()
 
+        # We construct an unflattened structure on parallel_mod,
+        # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
+        # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
         for k, v in sharded_param_dict.items():
             _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
 
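For context on the comment above: `_assign_attr` is a torch.export internal, but the unflattening it is described as doing can be sketched with a hypothetical helper (`assign_param_by_fqn` is made up for illustration) that creates empty container modules along the FQN path and registers the value at the leaf.

```python
import torch
from torch import nn


def assign_param_by_fqn(root: nn.Module, fqn: str, value: nn.Parameter) -> None:
    *path, leaf = fqn.split(".")
    mod = root
    for name in path:
        if not hasattr(mod, name):
            mod.add_module(name, nn.Module())  # create an empty container module
        mod = getattr(mod, name)
    mod.register_parameter(leaf, value)  # stash the value at the right spot


root = nn.Module()
assign_param_by_fqn(root, "layers.0.weight", nn.Parameter(torch.zeros(2, 2)))
assert isinstance(root.get_parameter("layers.0.weight"), nn.Parameter)
```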

@@ -437,20 +446,18 @@ def forward(self, *args):
 
         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
+        hook_params_setters(self.init_weights_model, self.parallel_model)
         if hasattr(self.model, "init_weights"):
 
-            def init_weights(*args, **kwargs):
-                with stateless._reparametrize_module(
-                    self.model, {**sharded_param_dict, **sharded_buffer_dict}
-                ):
-                    self.model.init_weights(*args, **kwargs)
+            def init_weights(_self, *args, **kwargs):
+                # this is now a deep-fake-copy of orig mod, so we don't have to use reparametrize
+                return self.init_weights_model.init_weights(*args, **kwargs)
 
-        else:
-            init_weights = None
-
-        # assign an init_weights method onto the output mod.
-        # all it does is sneakily run the original user mod's init_weights method,
-        # but with our new DTensor sharded params attached to the user module.
-        self.parallel_model.init_weights = init_weights
+            # assign an init_weights method onto the output mod.
+            # all it does is sneakily run the original user mod's init_weights method,
+            # but with our new DTensor sharded params attached to the user module.
+            self.parallel_model.init_weights = MethodType(
+                init_weights, self.parallel_model
+            )
 
         return self.parallel_model
autoparallel/init_weights.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Any, Union
+
+import torch
+from torch._dynamo.utils import warn_once
+from torch.distributed.tensor import DTensor
+
+
+def _submod_setattr(model: torch.nn.Module, fqn: str, value: Any):
+    module_path, _, buffer_name = fqn.rpartition(".")
+    submod: torch.nn.Module = model.get_submodule(module_path)
+    setattr(submod, buffer_name, value)
+
+
+def _copy_set_value_to_dtensor(
+    fqn: str, parallel_value: DTensor, set_value: torch.Tensor
+):
+    # We expect the user wrote their module's init_weights in terms of a single-gpu model, so we do not expect
+    # set_value to be a DTensor already (since this would imply init_weights was written in a 'distributed' way),
+    # and we interpret it as a global tensor which we map to a Replicated DTensor.
+    assert not isinstance(
+        set_value, DTensor
+    ), "Expected local/full tensor from setattr in init_weights, not DTensor."
+
+    # This creates a replicated DTensor
+    new_parallel_value = DTensor.from_local(
+        set_value, device_mesh=parallel_value.device_mesh
+    )
+    if parallel_value.placements != new_parallel_value.placements:
+        # no harm done if the parallel value is replicated, e.g. freqs_cis in llama3, but it would be
+        # noticeably wasteful if we do this for all the sharded parameters.
+        warn_once(
+            f"init_weights set a new value for {fqn}, "
+            f"but the existing value is already sharded ({parallel_value.placements=}, "
+            "and it is wasteful to materialize the new value as a global tensor. "
+            "Change init_weights to perform an inplace initialization instead if possible."
+        )
+    with torch.no_grad():
+        # This ensures that we faithfully redistribute the replicated new_parallel_value into whatever placement
+        # the autoparallel engine decided for parallel_value. Note: this should in general be comm free, since it
+        # would be going from Replicate -> Shard.
+        parallel_value.copy_(new_parallel_value)
+
+
+def _build_param_property(parallel_model: torch.nn.Module, fqn: str):
+    def getter(self) -> torch.nn.Parameter:
+        param = parallel_model.get_parameter(fqn)
+        return param
+
+    def setter(self, value: Union[torch.Tensor, torch.nn.Parameter]) -> None:
+        parallel_value = parallel_model.get_parameter(fqn)
+        assert isinstance(
+            parallel_value, DTensor
+        ), "Expected parallel_module params to be DTensors"
+        _copy_set_value_to_dtensor(fqn, parallel_value, value)
+
+    return property(getter, setter)
+
+
+def _build_buffer_property(parallel_model: torch.nn.Module, fqn: str):
+    def getter(self) -> torch.Tensor:
+        return parallel_model.get_buffer(fqn)
+
+    def setter(self, value: torch.Tensor) -> None:
+        parallel_value = parallel_model.get_buffer(fqn)
+        assert isinstance(
+            parallel_value, DTensor
+        ), "Expected parallel_module params to be DTensors"
+        _copy_set_value_to_dtensor(fqn, parallel_value, value)
+
+    return property(getter, setter)
+
+
+def hook_params_setters(
+    init_weights_model: torch.nn.Module, parallel_model: torch.nn.Module
+) -> None:
+    """
+    Replaces init_weights_model's parameters with hooked properties that let us
+    (a) return a new parameter (from our parallel_mod) instead of the one on the original model,
+        similar to using stateless.reparametrize
+    (b) also, detect if anyone tries to assign a new value to the parameter, e.g.
+        self.layer.weight = nn.Parameter(torch.randn(10, 10))
+        would not be properly captured if relying on parametrization alone
+
+    Assumes init_weights_model is a deepcopy of the user's original model, with all fake params. This way we can
+    modify the model to enable init_weights to work, without affecting the user's original model.
+
+    Adds one 'property' (e.g. getter+setter) obj for each parameter name at the right spot in
+    the module hierarchy. For self.layer.weight, this would install a 'weight' property on the self.layer
+    submodule.
+    """
+    for mod_name, mod in sorted(init_weights_model.named_modules()):
+        params_dict = dict(mod.named_parameters(recurse=False))
+        buffers_dict = dict(mod.named_buffers(recurse=False))
+
+        namespace = {}
+        for p_name in params_dict:
+            fqn = mod_name + "." + p_name
+            namespace[p_name] = _build_param_property(parallel_model, fqn)
+
+        for b_name in buffers_dict:
+            fqn = mod_name + "." + b_name
+            namespace[b_name] = _build_buffer_property(parallel_model, fqn)
+
+        cls = mod.__class__
+        # nn.Module.__setattr__ gets in the way
+        namespace["__setattr__"] = object.__setattr__
+        mod.__class__ = type(f"HookedInit{cls.__name__}", (cls,), namespace)

tests/test_api.py

Lines changed: 52 additions & 0 deletions
@@ -6,6 +6,7 @@
 import pytest
 import torch
 from torch import nn
+from torch.distributed.tensor.placement_types import Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from autoparallel.api import AutoParallel
@@ -62,3 +63,54 @@ def input_fn():
         auto_p.model.get_parameter("linear.weight"), torch._subclasses.FakeTensor
     )
     assert isinstance(auto_p.model.get_buffer("buf"), torch._subclasses.FakeTensor)
+
+
+def test_init(device_mesh_1d):
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+            self.register_buffer("buf", torch.empty(dim))
+
+        def forward(self, x):
+            return self.linear(x) + self.buf
+
+        def init_weights(self):
+            self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
+            with torch.no_grad():
+                self.linear.bias.fill_(98.6)
+            self.buf = torch.arange(dim)
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        parallel_mod = autop.apply_placement(sharding_placement)
+        parallel_mod.to_empty(device="cuda")
+        parallel_mod.init_weights()
+        assert torch.equal(
+            parallel_mod.get_parameter("linear.weight").full_tensor(),
+            torch.full((dim, dim), 9.0, device="cuda"),
+        )
+        assert torch.equal(
+            parallel_mod.get_parameter("linear.bias").full_tensor(),
+            torch.full((dim,), 98.6, device="cuda"),
+        )
+        assert torch.equal(
+            parallel_mod.get_buffer("buf").full_tensor(), torch.arange(dim, device="cuda")
+        )
