
Conversation

@wconstab
Contributor

@wconstab wconstab commented Jul 26, 2025

init_weights is a method a user module could supply, which initializes all the parameters and buffers on the module.

Currently, we handle init_weights by reparametrization:

  • we replace the original module's state_dict with a version containing parallel_module's states
  • then we run init_weights, mutating these states

But if init_weights does something like self.buf = _init_buf() instead of doing something like self.buf.copy_(_init_buf()), we fail to capture this update.

This PR interposes the init_weights function differently.

  1. it adds a new deepcopied 'init_weights_model' to AutoParallel, so we can freely mutate its class without affecting the orig model
  2. it mutates the class of init_weights_model, to
  • add property objects (getter + setter) for each parameter fqn in the module tree
  • get rid of nn.Module.__setattr__ so the property setters work
  3. each getter returns the corresponding parameter from the parallel module instead of the orig module
  4. each setter additionally wraps 'value' in a new replicated DTensor and copies it into the existing DTensor in the parallel module (a rough sketch of this interposition follows below)

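As a rough sketch of the interposition described above (assuming a helper like build_property and a parallel_model handle; the exact class-mutation code in this PR may differ):

import torch
from torch.distributed.tensor import DTensor


def build_property(parallel_model, fqn):
    # Getter: reads resolve to the parallel module's (possibly sharded)
    # DTensor parameter instead of the original module's tensor.
    def _getter(self):
        return parallel_model.get_parameter(fqn)

    # Setter: `value` comes from single-gpu-style init code, so treat it as
    # the global value, wrap it as a replicated DTensor on the same mesh,
    # and copy_ it into the existing DTensor (which redistributes if needed).
    def _setter(self, value):
        orig = parallel_model.get_parameter(fqn)
        new = DTensor.from_local(value, device_mesh=orig.device_mesh)
        with torch.no_grad():
            orig.copy_(new)

    return property(_getter, _setter)


def hook_module_for_init(mod, parallel_model, prefix=""):
    # Build one property per parameter fqn directly owned by this
    # (deepcopied) module; buffers would be handled analogously.
    namespace = {
        name: build_property(parallel_model, f"{prefix}{name}")
        for name, _ in mod.named_parameters(recurse=False)
    }
    # Bypass nn.Module.__setattr__ so `self.weight = ...` inside init_weights
    # hits the property setters instead of re-registering a new Parameter.
    namespace["__setattr__"] = object.__setattr__
    mod.__class__ = type(f"Init{mod.__class__.__name__}", (mod.__class__,), namespace)
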
Verification run:
tbm FSDP_eager:torchtitan-64-whc-p3s1bn autop_initweights_eager:torchtitan-64-whc-qthbz6 autop_initweights_eager_rerun:torchtitan-64-whc-d2bddf

@wconstab wconstab requested a review from fmassa July 26, 2025 01:12
@facebook-github-bot facebook-github-bot added the CLA Signed label Jul 26, 2025
@wconstab wconstab requested a review from ezyang July 26, 2025 01:12
@wconstab wconstab force-pushed the whc/init branch 2 times, most recently from 6141ee7 to 76abb61 on July 31, 2025 00:43
@wconstab wconstab changed the title from "[WIP] Fix init_weights handling for param/buffer assignment" to "Fix init_weights handling for param/buffer assignment" Jul 31, 2025
@wconstab
Contributor Author

This does not work yet. @ezyang It is working for my test but not for a real model run. I think I am probably not updating the right copy of the parameters on 'parallel_mod'. Does parallel_mod have both a flat set of parameters and aliases that are stuck on orig-mod fqns?

@ezyang
Contributor

ezyang commented Aug 1, 2025

Here is the relevant code. It doesn't seem to me like what you described could be happening:

        self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
            self.joint_with_descriptors
        )

        # TODO: this probably belongs in the AOTAutograd API
        # TODO: pytree handling
        class AutoParallelModule(torch.nn.Module):
            def forward(self, *args):
                # NB: don't close over the parameters/buffers, as the user may
                # reassign the module!
                # TODO: It's this to just exactly match
                # prepare_aot_module_simplified, this seems like an API gap
                params = [
                    v.to_local()
                    for k, v in
                    # TODO: this is very slow
                    itertools.chain(
                        dict(self.named_parameters(remove_duplicate=False)).items(),
                        dict(self.named_buffers(remove_duplicate=False)).items(),
                    )
                ]
                boxed_args = [*params, *args]
                del params
                # NB: don't do self.parallel_model_fn work around Dynamo bug
                out = parallel_model_fn(boxed_args)
                return out

        self.parallel_model = AutoParallelModule()

        for k, v in sharded_param_dict.items():
            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)

        for k, v in sharded_buffer_dict.items():
            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)

        # Right now we require a convention that the user model provides an init_weights method,
        # although we could snoop for other methods too.
        if hasattr(self.model, "init_weights"):

            def init_weights(*args, **kwargs):
                with stateless._reparametrize_module(
                    self.model, {**sharded_param_dict, **sharded_buffer_dict}
                ):
                    self.model.init_weights(*args, **kwargs)

        else:
            init_weights = None

        # assign an init_weights method onto the output mod.
        # all it does is sneakily run the original user mod's init_weights method,
        # but with our new DTensor sharded params attached to the user module.
        self.parallel_model.init_weights = init_weights

orig_value = parallel_model.get_parameter(fqn)
new_value = DTensor.from_local(
    value, device_mesh=orig_value.device_mesh
)
Contributor

IDK man, this doesn't look right? The value passed in here is going to be a Replicate, and so if the sharding of the parameter doesn't match up you need to do collectives right??

Contributor Author

There is an intentional mismatch in sharding placement here.

new_value: it is created as a replica on purpose, since it is consumed from the value that is assigned in init_weights, which is stipulated to be written in a 'single-gpu' style.

orig_value: it's sharded, or replicated, or whatever. But its copy_ operator will take the replicated DTensor and do the necessary redistribution to match the destination placement.
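
For illustration, the contract being described looks roughly like the following sketch (assumes an initialized process group and a 1-D mesh; on older PyTorch versions these imports live under torch.distributed._tensor):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Destination: a sharded parameter, as the parallel module holds it.
dest = distribute_tensor(torch.zeros(8, 8, device="cuda"), mesh, [Shard(0)])

# Source: the single-gpu-style value assigned in init_weights; with no
# placements given, from_local treats the local tensor as replicated.
src = DTensor.from_local(torch.ones(8, 8, device="cuda"), device_mesh=mesh)

# copy_ performs the Replicate() -> Shard(0) redistribution for us.
with torch.no_grad():
    dest.copy_(src)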

Contributor

Oh, I didn't realize from_local defaults to replica. In fact the docs don't say lol

@wconstab
Contributor Author

wconstab commented Aug 1, 2025

I fixed this. The bug was that I was only creating getter/setter properties for the first parameter in a submodule, because I was using this cls key cache I copied from the FSDP getter/setter hooks without thinking about what it was for.

This is why things worked for my local test, but for any model with more than one param per module the latter params did not get initialized.

Having fixed that, I also cleaned up the code a bit and verified local debug llama works. Here is a new mast job to verify correctness:
https://www.internalfb.com/mlhub/pipelines/runs/mast/torchtitan-64-whc-d7zz2wk

namespace[b_name] = build_property(_getter, _setter)
cls = mod.__class__
param_properties_key = "#".join(sorted(namespace.keys()))
new_cls = hooked_classes.get((cls, param_properties_key), None)
Contributor Author

This was where I went wrong. I actually don't know why the FSDP hook code wants this caching behavior, but I sure don't.
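
Concretely, the fix amounts to building the full property namespace for a module first and only then consulting the class cache, keyed on the original class plus the full set of hooked names; a sketch (helper names here are illustrative):

hooked_classes: dict[tuple[type, str], type] = {}

def hook_class(mod, namespace):
    # namespace maps every hooked param/buffer name on `mod` to a property.
    cls = mod.__class__
    # Keying on the sorted set of property names ensures a cached class built
    # for one parameter set is never reused for a module that needs more.
    key = (cls, "#".join(sorted(namespace.keys())))
    new_cls = hooked_classes.get(key)
    if new_cls is None:
        new_cls = type(f"Hooked{cls.__name__}", (cls,), namespace)
        hooked_classes[key] = new_cls
    mod.__class__ = new_cls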

@wconstab wconstab force-pushed the whc/init branch 2 times, most recently from 0ba10b6 to 7674f89 on August 1, 2025 04:38
        self.model, {**sharded_param_dict, **sharded_buffer_dict}
    ):
        self.model.init_weights(*args, **kwargs)
def init_weights(_self, *args, **kwargs):
Contributor Author

Is this the cleanest way of doing this?

Contributor

Well, I mean, if you didn't need access to _self, you don't have to turn it into a method?

Contributor

I think we can clean this up in a follow-up PR if needed
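
For reference, the two shapes being weighed here, as a sketch (whether the PR actually binds the wrapper as a method is not shown in this excerpt):

import types
import torch

parallel_model = torch.nn.Module()  # stand-in for the AutoParallelModule

# Option 1: a plain function attribute; the wrapper never sees the module.
def init_weights(*args, **kwargs):
    ...

parallel_model.init_weights = init_weights

# Option 2: bind explicitly so the wrapper receives the module as _self.
def init_weights(_self, *args, **kwargs):
    ...

parallel_model.init_weights = types.MethodType(init_weights, parallel_model)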

    param = parallel_model.get_parameter(_fqn)
    return param

def setter(self, value):
Contributor

It would be good to explicitly comment what you told me in CR which is that value is the replicated local value (since the init code is written with single device in mind).

Contributor Author

I think I broke this actually; I had a .copy_ in one of my versions but I lost it. I need to add a test for the sharded setter case and fix the handling here. I'll ping for a re-review when I do that part.

new_value = DTensor.from_local(value, device_mesh=orig_value.device_mesh)
if isinstance(orig_value, torch.nn.Parameter):
    new_value = torch.nn.Parameter(new_value)
_submod_setattr(parallel_model, fqn, new_value)
Contributor

Should we warn once when the user does this? Explicit setter is anti-recommended because you'll allocate everything locally before distributing it out (if it's sharded).

Contributor Author

Yeah, perhaps just in the case where the orig param is sharded we can warn once, but if the orig param is replicated we can be silent.
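
A possible shape for that warn-once, as a sketch (the helper name and placement check are assumptions, not this PR's code; the import path may differ by PyTorch version):

import warnings
from torch.distributed.tensor import Shard

_warned_assignment_fqns: set[str] = set()

def _maybe_warn_on_assignment(fqn, orig_value):
    # Only warn when the destination is actually sharded: in that case the
    # user just materialized the full tensor on every rank before it gets
    # redistributed out.
    is_sharded = any(isinstance(p, Shard) for p in orig_value.placements)
    if is_sharded and fqn not in _warned_assignment_fqns:
        _warned_assignment_fqns.add(fqn)
        warnings.warn(
            f"init_weights assigned a full tensor to sharded param {fqn}; "
            "prefer in-place init (copy_/fill_) to avoid allocating the "
            "global tensor locally."
        )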

def init_weights(self):
    self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
    with torch.no_grad():
        self.linear.bias.fill_(98.6)
Contributor

This isn't really relevant to your PR, but it would be good to test that random fill works correctly too. (If you directly normal_(), IIRC it's not equal to the eager version, so maybe just check something like: did all the RNGs actually produce the same thing?)
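
A hedged sketch of such a check (assumes the model inits from torch's global RNG and that full_tensor() is available to gather the DTensor; the comparison may legitimately fail if the distributed RNG path diverges from eager, which is exactly what the test would surface):

import torch

def check_random_init_matches_eager(parallel_model, eager_model, fqn, seed=1234):
    # Same seed on every rank and on the eager reference.
    torch.manual_seed(seed)
    eager_model.init_weights()
    torch.manual_seed(seed)
    parallel_model.init_weights()

    dist_param = parallel_model.get_parameter(fqn)   # a DTensor
    gathered = dist_param.full_tensor()              # gather to the global view
    reference = eager_model.get_parameter(fqn)
    torch.testing.assert_close(gathered.cpu(), reference.cpu())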

Contributor

@ezyang ezyang left a comment

LGTM

Contributor

@fmassa fmassa left a comment

LGTM, thanks!


init_weights is a method a user module could supply, which initializes
all the parameters and buffers on the module.

Currently, we handle init_weights by reparametrization:
 - we replace the original module's state_dict with a version containing
   parallel_module's states
 - then we run init_weights, mutating these states

But if init_weights does something like `self.buf = _init_buf()` instead
of doing something like `self.buf.copy_(_init_buf())`, we fail to
capture this update.

This PR attempts to find these missing updates and then copy them back
to the parallel_module's states.
1) assuming that if init_weights did an assignment, it would not create
   a DTensor, because init_weights and the orig module are supposed to be
   written in 'single gpu' style.
2) finding any non-DTensors in the updated state_dict and converting
   them to new Replicate() DTensors, following the semantic that the
   new assigned value should represent the global value for the state
3) copy_ into the original state DTensor on the parallel_module, since
   this handles the case of converting Replicate() to Shard() if
   needed.
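
As an illustrative sketch of steps 1-3 for buffers only (helper names are not the actual implementation):

import torch
from torch.distributed.tensor import DTensor

def copy_back_assigned_buffers(orig_model, parallel_model):
    for name, buf in orig_model.named_buffers():
        # 1) anything init_weights assigned is a plain tensor, since the
        #    orig module is written in single-gpu style and never builds
        #    DTensors itself.
        if isinstance(buf, DTensor):
            continue
        # 2) treat the assigned value as the global value and wrap it as a
        #    replicated DTensor on the parallel buffer's mesh.
        dest = parallel_model.get_buffer(name)
        src = DTensor.from_local(buf, device_mesh=dest.device_mesh)
        # 3) copy_ into the existing DTensor; this redistributes
        #    Replicate() -> Shard() if the destination is sharded.
        with torch.no_grad():
            dest.copy_(src)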

TODO:
- verify this fixes the current init correctness problem with
llama
- support params (currently only implemented buffers)
- support nested names (a.b.c), currently only flat names work
- see if there is a better way to detect the assignment (e.g. #1 above)

Make hooked setter work for initializing params/buffers

ghstack-source-id: 1c0670b
Pull Request resolved: #66
@wconstab wconstab merged commit badffa7 into main Aug 6, 2025
6 of 7 checks passed
@wconstab wconstab deleted the whc/init branch August 6, 2025 17:05
wconstab added a commit to pytorch/torchtitan that referenced this pull request Aug 6, 2025
