
Commit b899918

[WIP] Fix init_weights handling for param/buffer assignment
init_weights is a method a user module can supply to initialize all the parameters and buffers on the module.

Currently, we handle init_weights by reparametrization:
- we replace the original module's state_dict with a version containing parallel_module's states
- then we run init_weights, mutating these states

But if init_weights does something like `self.buf = _init_buf()` instead of `self.buf.copy_(_init_buf())`, we fail to capture this update.

This PR attempts to find these missing updates and copy them back to the parallel_module's states by:
1) assuming that if init_weights did an assignment, it would not create a DTensor, because init_weights and the orig module are supposed to be written in 'single gpu' style
2) finding any non-DTensors in the updated state_dict and converting them to new Replicate() DTensors, following the semantic that the newly assigned value represents the global value for the state
3) copy_-ing into the original state DTensor on the parallel_module, since this handles the conversion from Replicate() to Shard() if needed

TODO:
- verify this fixes the current init correctness problem with llama
- support params (currently only buffers are implemented)
- support nested names (a.b.c); currently only flat names work
- see if there is a better way to detect the assignment (cf. assumption 1 above)

Make hooked setter work for initializing params/buffers

ghstack-source-id: 1c0670b
Pull Request resolved: #66
1 parent 385d06e commit b899918
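For concreteness, here is a minimal standalone sketch of the failure mode described in the commit message; the module M, the buffer shape, and the swapped_in dict are hypothetical stand-ins, not code from this PR. An attribute assignment inside init_weights rebinds self.buf rather than writing into the swapped-in tensor, so the replacement state never sees the new values:

import torch
from torch import nn
from torch.nn.utils import stateless


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(4))

    def init_weights(self):
        # rebinds the attribute instead of mutating the current buffer in place
        self.buf = torch.ones(4)


m = M()
swapped_in = {"buf": torch.zeros(4)}  # stand-in for parallel_module's state
with stateless._reparametrize_module(m, swapped_in):
    m.init_weights()
print(swapped_in["buf"])  # still all zeros: the assignment bypassed the swapped-in tensor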

File tree: 3 files changed (+178, −13 lines)
- autoparallel/api.py
- autoparallel/init_weights.py
- tests/test_api.py


autoparallel/api.py

Lines changed: 20 additions & 13 deletions
@@ -6,6 +6,7 @@
 import copy
 import itertools
 from contextlib import ExitStack
+from types import MethodType
 from typing import Optional
 
 import torch
@@ -22,10 +23,10 @@
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
-from torch.nn.utils import stateless
 
 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
+from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
 from .utils import _get_device_from_mesh
 
@@ -175,6 +176,11 @@ def __init__(
         # in dtype casting and move_to_fake
         model = copy.deepcopy(model)
 
+        # keep a separate copy of the fake orig model to customize for supporting init_weights
+        self.init_weights_model = move_to_fake(
+            copy.deepcopy(model), self.fake_mode, device
+        )
+
         if self.mp_policy is not None:
             apply_dtype_cast(model, self.mp_policy)
 
@@ -431,6 +437,9 @@ def forward(self, *args):
 
         self.parallel_model = AutoParallelModule()
 
+        # We construct an unflattened structure on parallel_mod,
+        # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
+        # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
         for k, v in sharded_param_dict.items():
             _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
 
@@ -439,20 +448,18 @@ def forward(self, *args):
 
         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
+        hook_params_setters(self.init_weights_model, self.parallel_model)
         if hasattr(self.model, "init_weights"):
 
-            def init_weights(*args, **kwargs):
-                with stateless._reparametrize_module(
-                    self.model, {**sharded_param_dict, **sharded_buffer_dict}
-                ):
-                    self.model.init_weights(*args, **kwargs)
+            def init_weights(_self, *args, **kwargs):
+                # this is now a deep-fake-copy of orig mod, so we don't have to use reparametrize
+                return self.init_weights_model.init_weights(*args, **kwargs)
 
-        else:
-            init_weights = None
-
-        # assign an init_weights method onto the output mod.
-        # all it does is sneakily run the original user mod's init_weights method,
-        # but with our new DTensor sharded params attached to the user module.
-        self.parallel_model.init_weights = init_weights
+            # assign an init_weights method onto the output mod.
+            # all it does is sneakily run the original user mod's init_weights method,
+            # but with our new DTensor sharded params attached to the user module.
+            self.parallel_model.init_weights = MethodType(
+                init_weights, self.parallel_model
+            )
 
         return self.parallel_model
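A quick aside on the MethodType binding above: binding the wrapper to parallel_model makes Python pass the parallel module as the first argument (the ignored _self), while the closure still calls init_weights on init_weights_model. A toy, self-contained illustration of the mechanism (names here are made up, not from this repo):

from types import MethodType


class Box:
    pass


def greet(self, name):
    return f"{type(self).__name__} says hi to {name}"


b = Box()
b.greet = MethodType(greet, b)  # instance-level bound method
print(b.greet("init_weights"))  # -> "Box says hi to init_weights"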

autoparallel/init_weights.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Any, Union
+
+import torch
+from torch._dynamo.utils import warn_once
+from torch.distributed.tensor import DTensor
+
+
+def _submod_setattr(model: torch.nn.Module, fqn: str, value: Any):
+    module_path, _, buffer_name = fqn.rpartition(".")
+    submod: torch.nn.Module = model.get_submodule(module_path)
+    setattr(submod, buffer_name, value)
+
+
+def _copy_set_value_to_dtensor(
+    fqn: str, parallel_value: DTensor, set_value: torch.Tensor
+):
+    # We expect the user wrote their module's init_weights in terms of a single-gpu model, so we do not expect
+    # set_value to be a DTensor already (since this would imply init_weights was written in a 'distributed' way),
+    # and we interpret it as a global tensor which we map to a Replicate() DTensor.
+    assert not isinstance(
+        set_value, DTensor
+    ), "Expected local/full tensor from setattr in init_weights, not DTensor."
+
+    # This creates a replicated DTensor
+    new_parallel_value = DTensor.from_local(
+        set_value, device_mesh=parallel_value.device_mesh
+    )
+    if parallel_value.placements != new_parallel_value.placements:
+        warn_once(
+            f"init_weights set a new value for {fqn}, "
+            f"but the existing value is already sharded ({parallel_value.placements=}), "
+            "and it is wasteful to materialize the new value as a global tensor. "
+            "Change init_weights to perform an inplace initialization instead if possible."
+        )
+    with torch.no_grad():
+        parallel_value.copy_(new_parallel_value)
+
+
+def _build_param_property(parallel_model: torch.nn.Module, fqn: str):
+    def getter(self) -> torch.nn.Parameter:
+        param = parallel_model.get_parameter(fqn)
+        return param
+
+    def setter(self, value: Union[torch.Tensor, torch.nn.Parameter]) -> None:
+        parallel_value = parallel_model.get_parameter(fqn)
+        assert isinstance(
+            parallel_value, DTensor
+        ), "Expected parallel_module params to be DTensors"
+        _copy_set_value_to_dtensor(fqn, parallel_value, value)
+
+    return property(getter, setter)
+
+
+def _build_buffer_property(parallel_model: torch.nn.Module, fqn: str):
+    def getter(self) -> torch.Tensor:
+        return parallel_model.get_buffer(fqn)
+
+    def setter(self, value: torch.Tensor) -> None:
+        parallel_value = parallel_model.get_buffer(fqn)
+        assert isinstance(
+            parallel_value, DTensor
+        ), "Expected parallel_module buffers to be DTensors"
+        _copy_set_value_to_dtensor(fqn, parallel_value, value)
+
+    return property(getter, setter)
+
+
+def hook_params_setters(
+    init_weights_model: torch.nn.Module, parallel_model: torch.nn.Module
+) -> None:
+    """
+    Replaces init_weights_model's parameters with hooked properties that let us
+    (a) return a new parameter (from our parallel_mod) instead of the one on the original model,
+        similar to using stateless.reparametrize
+    (b) detect if anyone tries to assign a new value to the parameter, e.g.
+        self.layer.weight = nn.Parameter(torch.randn(10, 10)),
+        which would not be properly captured if relying on parametrization alone
+
+    Assumes init_weights_model is a deepcopy of the user's original model, with all fake params. This way we can
+    modify the model to enable init_weights to work, without affecting the user's original model.
+
+    Adds one 'property' (i.e. getter+setter) obj for each parameter name at the right spot in
+    the module hierarchy. For self.layer.weight, this would install a 'weight' property on the self.layer
+    submodule.
+    """
+    for mod_name, mod in sorted(init_weights_model.named_modules()):
+        params_dict = dict(mod.named_parameters(recurse=False))
+        buffers_dict = dict(mod.named_buffers(recurse=False))
+
+        namespace = {}
+        for p_name in params_dict:
+            fqn = mod_name + "." + p_name
+            namespace[p_name] = _build_param_property(parallel_model, fqn)
+
+        for b_name in buffers_dict:
+            fqn = mod_name + "." + b_name
+            namespace[b_name] = _build_buffer_property(parallel_model, fqn)
+
+        cls = mod.__class__
+        # nn.Module.__setattr__ gets in the way (it registers Parameters/Tensors directly),
+        # so bypass it and let assignments hit the class-level properties above
+        namespace["__setattr__"] = object.__setattr__
+        mod.__class__ = type(f"HookedInit{cls.__name__}", (cls,), namespace)
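The core mechanism hook_params_setters relies on is swapping an instance's __class__ to a dynamically created subclass whose properties (data descriptors) intercept both reads and assignments, while object.__setattr__ bypasses nn.Module's own attribute bookkeeping so the property setter actually fires. Below is a self-contained sketch of just that trick; the captured dict is a toy stand-in for the copy into the parallel module's DTensor, not code from this PR:

import torch
from torch import nn

lin = nn.Linear(4, 4)
captured = {}


def _get_weight(self):
    # fall back to the real parameter until an assignment has been captured
    return captured.get("weight", self._parameters["weight"])


def _set_weight(self, value):
    captured["weight"] = value  # the PR copies into the parallel DTensor here instead


namespace = {
    "weight": property(_get_weight, _set_weight),
    "__setattr__": object.__setattr__,  # skip nn.Module.__setattr__'s Parameter handling
}
lin.__class__ = type(f"HookedInit{nn.Linear.__name__}", (nn.Linear,), namespace)

lin.weight = nn.Parameter(torch.ones(4, 4))  # routed through the property setter
assert torch.equal(captured["weight"], torch.ones(4, 4))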

tests/test_api.py

Lines changed: 52 additions & 0 deletions
@@ -6,6 +6,7 @@
 import pytest
 import torch
 from torch import nn
+from torch.distributed.tensor.placement_types import Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from autoparallel.api import AutoParallel
@@ -62,3 +63,54 @@ def input_fn():
         auto_p.model.get_parameter("linear.weight"), torch._subclasses.FakeTensor
     )
     assert isinstance(auto_p.model.get_buffer("buf"), torch._subclasses.FakeTensor)
+
+
+def test_init(device_mesh_1d):
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+            self.register_buffer("buf", torch.empty(dim))
+
+        def forward(self, x):
+            return self.linear(x) + self.buf
+
+        def init_weights(self):
+            self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
+            with torch.no_grad():
+                self.linear.bias.fill_(98.6)
+            self.buf = torch.arange(dim)
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        parallel_mod = autop.apply_placement(sharding_placement)
+        parallel_mod.to_empty(device="cuda")
+        parallel_mod.init_weights()
+        assert torch.equal(
+            parallel_mod.get_parameter("linear.weight").full_tensor(),
+            torch.full((dim, dim), 9.0, device="cuda"),
+        )
+        assert torch.equal(
+            parallel_mod.get_parameter("linear.bias").full_tensor(),
+            torch.full((dim,), 98.6, device="cuda"),
+        )
+        assert torch.equal(
+            parallel_mod.get_buffer("buf").full_tensor(), torch.arange(dim, device="cuda")
+        )
