Skip to content

Commit abfd789

Browse files
committed
[WIP] Fix init_weights handling for param/buffer assignment
init_weights is a method a user module could supply, which initializes all the parameters and buffers on the module.

Currently, we handle init_weights by reparametrization:
- we replace the original module's state_dict with a version containing parallel_module's states
- then we run init_weights, mutating these states

But if init_weights does something like `self.buf = _init_buf()` instead of something like `self.buf.copy_(_init_buf())`, we fail to capture this update.

This PR attempts to find these missing updates and then copy them back to the parallel_module's states by:
1) assuming that if init_weights did an assignment, it would not create a DTensor, because init_weights and the orig module are supposed to be written in 'single gpu' style.
2) finding any non-DTensors in the updated state_dict and converting them to new Replicate() DTensors, following the semantic that the newly assigned value should represent the global value for the state.
3) copy_-ing into the original state DTensor on the parallel_module, since this handles the case of converting Replicate() to Shard() if needed.

TODO:
- verify this fixes the current init correctness problem with llama
- support params (currently only implemented for buffers)
- support nested names (a.b.c); currently only flat names work
- see if there is a better way to detect the assignment (e.g. the assumption in 1) above)

Make hooked setter work for initializing params/buffers

ghstack-source-id: 1c0670b
Pull Request resolved: #66
1 parent 385d06e commit abfd789

File tree

3 files changed

+146
-13
lines changed

3 files changed

+146
-13
lines changed

autoparallel/api.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import copy
77
import itertools
88
from contextlib import ExitStack
9+
from types import MethodType
910
from typing import Optional
1011

1112
import torch
@@ -22,10 +23,10 @@
2223
from torch.distributed.tensor import DeviceMesh
2324
from torch.export._unlift import _assign_attr
2425
from torch.export.unflatten import _AttrKind
25-
from torch.nn.utils import stateless
2626

2727
from .apply_sharding import apply_sharding_to_model
2828
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
29+
from .init_weights import hook_params_setters
2930
from .optimize_sharding import ShardingOptimizer
3031
from .utils import _get_device_from_mesh
3132

@@ -175,6 +176,11 @@ def __init__(
175176
# in dtype casting and move_to_fake
176177
model = copy.deepcopy(model)
177178

179+
# keep a separate copy of the fake orig model to customize for supporting init_weights
180+
self.init_weights_model = move_to_fake(
181+
copy.deepcopy(model), self.fake_mode, device
182+
)
183+
178184
if self.mp_policy is not None:
179185
apply_dtype_cast(model, self.mp_policy)
180186

@@ -431,6 +437,10 @@ def forward(self, *args):
431437

432438
self.parallel_model = AutoParallelModule()
433439

440+
# We construct an unflattened structure on parallel_mod,
441+
# e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
442+
# create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
443+
# though, what happened to the original 'flattened' parameters on parallel_mod -- did we delete those?
434444
for k, v in sharded_param_dict.items():
435445
_assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
436446

@@ -439,20 +449,18 @@ def forward(self, *args):
439449

440450
# Right now we require a convention that the user model provides an init_weights method,
441451
# although we could snoop for other methods too.
452+
hook_params_setters(self.init_weights_model, self.parallel_model)
442453
if hasattr(self.model, "init_weights"):
443454

444-
def init_weights(*args, **kwargs):
445-
with stateless._reparametrize_module(
446-
self.model, {**sharded_param_dict, **sharded_buffer_dict}
447-
):
448-
self.model.init_weights(*args, **kwargs)
455+
def init_weights(_self, *args, **kwargs):
456+
# this is now a deep-fake-copy of orig mod, so we don't have to use reparametrize
457+
return self.init_weights_model.init_weights(*args, **kwargs)
449458

450-
else:
451-
init_weights = None
452-
453-
# assign an init_weights method onto the output mod.
454-
# all it does is sneakily run the original user mod's init_weights method,
455-
# but with our new DTensor sharded params attached to the user module.
456-
self.parallel_model.init_weights = init_weights
459+
# assign an init_weights method onto the output mod.
460+
# all it does is sneakily run the original user mod's init_weights method,
461+
# on a fake deep copy whose params/buffers are hooked to parallel_model's DTensor sharded states.
462+
self.parallel_model.init_weights = MethodType(
463+
init_weights, self.parallel_model
464+
)
457465

458466
return self.parallel_model

autoparallel/init_weights.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import torch
6+
from torch.distributed.tensor import DTensor
7+
8+
9+
def _submod_setattr(model, fqn, value):
10+
module_path, _, buffer_name = fqn.rpartition(".")
11+
submod: torch.nn.Module = model.get_submodule(module_path)
12+
setattr(submod, buffer_name, value)
13+
14+
15+
def _build_param_property(parallel_model, fqn):
16+
def getter(self, _fqn=fqn):
17+
param = parallel_model.get_parameter(_fqn)
18+
return param
19+
20+
def setter(self, value):
21+
orig_value = parallel_model.get_parameter(fqn)
22+
new_value = DTensor.from_local(value, device_mesh=orig_value.device_mesh)
23+
if isinstance(orig_value, torch.nn.Parameter):
24+
new_value = torch.nn.Parameter(new_value)
25+
_submod_setattr(parallel_model, fqn, new_value)
26+
27+
return property(getter, setter)
28+
29+
30+
def _build_buffer_property(parallel_model, fqn):
31+
def getter(self):
32+
return parallel_model.get_buffer(fqn)
33+
34+
def setter(self, value):
35+
orig_value = parallel_model.get_buffer(fqn)
36+
new_value = DTensor.from_local(value, device_mesh=orig_value.device_mesh)
37+
_submod_setattr(parallel_model, fqn, new_value)
38+
39+
return property(getter, setter)
40+
41+
42+
def hook_params_setters(model, parallel_model):
    """
    Replaces model's parameters and buffers with hooked properties that let us
    (a) return the state from parallel_model instead of the one on the original model,
    similar to using stateless.reparametrize
    (b) also, detect if anyone tries to assign a new value to the state, e.g.
    self.layer.weight = nn.Parameter(torch.randn(10, 10))
    would not be properly captured if relying on parametrization alone

    Adds one 'property' (i.e. getter+setter) obj for each parameter/buffer name at the right
    spot in the module hierarchy. For self.layer.weight, this would install a 'weight'
    property on the self.layer submodule.

    Every module in ``model`` is mutated in place: its ``__class__`` is swapped for a
    dynamically created ``HookedInit<OrigName>`` subclass carrying the properties.

    Args:
        model: module whose parameter/buffer attribute access should be redirected.
        parallel_model: module holding the real states, addressed by the same dotted names.

    Returns:
        The same ``model`` object that was passed in.
    """
    for mod_name, mod in sorted(model.named_modules()):
        # The root module's name is "", so only join with "." for real
        # submodules; otherwise root-level states would get a malformed
        # leading-dot name such as ".buf".
        prefix = f"{mod_name}." if mod_name else ""

        namespace = {}
        for p_name, _ in mod.named_parameters(recurse=False):
            namespace[p_name] = _build_param_property(parallel_model, prefix + p_name)

        for b_name, _ in mod.named_buffers(recurse=False):
            namespace[b_name] = _build_buffer_property(parallel_model, prefix + b_name)

        cls = mod.__class__
        # nn.Module.__setattr__ gets in the way: fall back to plain object
        # attribute assignment so the properties' setters are actually invoked.
        namespace["__setattr__"] = object.__setattr__
        mod.__class__ = type(f"HookedInit{cls.__name__}", (cls,), namespace)

    return model

tests/test_api.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
import torch
88
from torch import nn
9+
from torch.distributed.tensor.placement_types import Shard
910
from torch.testing._internal.distributed.fake_pg import FakeStore
1011

1112
from autoparallel.api import AutoParallel
@@ -62,3 +63,54 @@ def input_fn():
6263
auto_p.model.get_parameter("linear.weight"), torch._subclasses.FakeTensor
6364
)
6465
assert isinstance(auto_p.model.get_buffer("buf"), torch._subclasses.FakeTensor)
66+
67+
68+
def test_init(device_mesh_1d):
69+
dim = 128
70+
71+
class Model(nn.Module):
72+
def __init__(self, dim):
73+
super().__init__()
74+
self.linear = nn.Linear(dim, dim)
75+
self.register_buffer("buf", torch.empty(dim))
76+
77+
def forward(self, x):
78+
return self.linear(x) + self.buf
79+
80+
def init_weights(self):
81+
self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
82+
with torch.no_grad():
83+
self.linear.bias.fill_(98.6)
84+
self.buf = torch.arange(dim)
85+
86+
def input_fn():
87+
b = 512
88+
inputs = (torch.rand(b, dim, device="cuda"),)
89+
return inputs
90+
91+
with torch.device("meta"):
92+
model = Model(dim)
93+
with AutoParallel(
94+
model,
95+
input_fn,
96+
device_mesh_1d,
97+
) as autop:
98+
x_sharding = (Shard(0),)
99+
autop.add_input_constraints([x_sharding])
100+
sharding_placement = autop.optimize_placement()
101+
102+
# AutoParallel produces a module with meta-DTensor parameters that need to be initialized
103+
parallel_mod = autop.apply_placement(sharding_placement)
104+
parallel_mod.to_empty(device="cuda")
105+
parallel_mod.init_weights()
106+
assert torch.equal(
107+
parallel_mod.get_parameter("linear.weight").full_tensor(),
108+
torch.full((dim, dim), 9.0, device="cuda"),
109+
)
110+
assert torch.equal(
111+
parallel_mod.get_parameter("linear.bias").full_tensor(),
112+
torch.full((dim,), 98.6, device="cuda"),
113+
)
114+
assert torch.equal(
115+
parallel_mod.get_buffer("buf").full_tensor(), torch.arange(dim, device="cuda")
116+
)

0 commit comments

Comments
 (0)