Fix init_weights handling for param/buffer assignment #52
Conversation
Force-pushed from 6141ee7 to 76abb61.
This does not work yet. @ezyang It is working for my test but not a real model run. I think I am probably not updating the right copy of the parameters on 'parallel_mod'. Does parallel_mod have both a flat set of parameters and aliases that are stuck on orig-mod fqns?

Here is the relevant code. It doesn't seem to me like what you described could be happening:
autoparallel/init_weights.py
Outdated
orig_value = parallel_model.get_parameter(fqn)
new_value = DTensor.from_local(
    value, device_mesh=orig_value.device_mesh
)
IDK man, this doesn't look right? The value passed in here is going to be a Replicate, and so if the sharding of the parameter doesn't match up you need to do collectives right??
There is an intentional mismatch in sharding placement here.
- new_value: created as a replica on purpose, since it is built from the value assigned in init_weights, which is stipulated to be written in 'single-gpu' style.
- orig_value: may be sharded or replicated. Either way, its copy operator takes the replicated DTensor and does the necessary redistribution to match the destination.
Oh, I didn't realize from_local defaults to replica. In fact the docs don't say lol
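For reference, a minimal sketch of the behavior being discussed (assuming a device mesh built elsewhere and the public torch.distributed.tensor import path; older releases expose the same API under torch.distributed._tensor): omitting placements in from_local is equivalent to replicating on every mesh dimension, and each rank is expected to pass the same full tensor.

import torch
from torch.distributed.tensor import DTensor, Replicate

def wrap_as_replicated(value: torch.Tensor, device_mesh) -> DTensor:
    # No `placements` argument: from_local treats the local tensor as
    # replicated on every dimension of `device_mesh`.
    dt = DTensor.from_local(value, device_mesh=device_mesh)
    assert all(isinstance(p, Replicate) for p in dt.placements)
    return dt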
I fixed this. The bug was that I was only creating getter/setter properties for the first parameter in a submodule, because I was using this cls key cache I copied from the FSDP getter/setter hooks without thinking about what it was for. This is why things worked for my local test, but for any model that had more than one param per module the latter params did not get initialized. Having fixed that, I also cleaned up the code a bit and verified local debug llama works. Here is a new mast job to verify correctness:
autoparallel/init_weights.py
Outdated
namespace[b_name] = build_property(_getter, _setter)
cls = mod.__class__
param_properties_key = "#".join(sorted(namespace.keys()))
new_cls = hooked_classes.get((cls, param_properties_key), None)
This was where I went wrong. I actually don't know why the FSDP hook code wants this behavior, but I sure don't want it here.
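A condensed sketch of the fix being described, following the names in the snippet above (build_property, hooked_classes); the surrounding plumbing is assumed, not the exact code in the PR:

# Cache of dynamically generated subclasses, keyed by BOTH the original
# class and the exact set of hooked attribute names. Keying on the class
# alone (as in the copied FSDP code) reused the first submodule's
# property set for every later submodule of the same class.
hooked_classes: dict[tuple[type, str], type] = {}

def install_property_hooks(mod, namespace: dict) -> None:
    cls = mod.__class__
    param_properties_key = "#".join(sorted(namespace.keys()))
    new_cls = hooked_classes.get((cls, param_properties_key), None)
    if new_cls is None:
        # Build a throwaway subclass whose class dict carries one
        # property per hooked parameter/buffer name.
        new_cls = type(f"Hooked{cls.__name__}", (cls,), namespace)
        hooked_classes[(cls, param_properties_key)] = new_cls
    mod.__class__ = new_cls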
Force-pushed from 0ba10b6 to 7674f89.
    self.model, {**sharded_param_dict, **sharded_buffer_dict}
):
    self.model.init_weights(*args, **kwargs)
def init_weights(_self, *args, **kwargs):
Is this the cleanest way of doing this?
Well, I mean, if you didn't need access to _self, you don't have to turn it into a method?
I think we can clean this up in a follow-up PR if needed
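For context, the two shapes being weighed look roughly like this (a sketch with illustrative names; the real code lives in autoparallel/init_weights.py):

import types

# Option A: a plain closure over the wrapped model; no _self needed.
def make_init_weights(model):
    def init_weights(*args, **kwargs):
        model.init_weights(*args, **kwargs)
    return init_weights

# Option B: a function taking _self, bound onto the wrapper instance,
# useful when the wrapper's own state must be reachable inside the call.
def _init_weights(_self, *args, **kwargs):
    _self.model.init_weights(*args, **kwargs)

# wrapper.init_weights = make_init_weights(wrapper.model)          # Option A
# wrapper.init_weights = types.MethodType(_init_weights, wrapper)  # Option B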
autoparallel/init_weights.py
Outdated
    param = parallel_model.get_parameter(_fqn)
    return param

def setter(self, value):
It would be good to explicitly comment what you told me in CR which is that value is the replicated local value (since the init code is written with single device in mind).
I think I broke this, actually; I had a .copy_ in one of my versions but I lost it. I need to add a test for the sharded setter case and fix the handling here. I'll ping for a re-review when I do that part.
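The piece described as lost would look roughly like the sketch below (a hypothetical sharded-setter path, not the code the PR ships yet): copy into the existing DTensor instead of replacing it, so a sharded destination keeps its placements.

import torch
from torch.distributed.tensor import DTensor

def sharded_setter(parallel_model, fqn: str, value: torch.Tensor) -> None:
    # `value` is the full, single-device-style tensor assigned in init_weights.
    orig_value = parallel_model.get_parameter(fqn)
    new_value = DTensor.from_local(value, device_mesh=orig_value.device_mesh)
    with torch.no_grad():
        # copy_ redistributes the replicated source into the destination's
        # existing placements (Replicate() or Shard(...)), instead of
        # overwriting the parameter with a replicated copy.
        orig_value.copy_(new_value)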
autoparallel/init_weights.py
Outdated
new_value = DTensor.from_local(value, device_mesh=orig_value.device_mesh)
if isinstance(orig_value, torch.nn.Parameter):
    new_value = torch.nn.Parameter(new_value)
_submod_setattr(parallel_model, fqn, new_value)
Should we warn once when the user does this? Explicit setter is anti-recommended because you'll allocate everything locally before distributing it out (if it's sharded).
Yeah, perhaps just in the case where the orig param is sharded we can warn once, but if the orig param is replicated we can be silent.
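A possible shape for that warn-once (a sketch under the assumption that checking placements for Shard is the right condition; the message and helper name are illustrative, not what the PR ships):

import warnings
from torch.distributed.tensor import DTensor, Shard

_warned_sharded_setter = False

def maybe_warn_sharded_assignment(orig_value: DTensor, fqn: str) -> None:
    global _warned_sharded_setter
    is_sharded = any(isinstance(p, Shard) for p in orig_value.placements)
    if is_sharded and not _warned_sharded_setter:
        _warned_sharded_setter = True
        warnings.warn(
            f"init_weights assigned a full tensor to sharded parameter {fqn}; "
            "this materializes the whole tensor locally before resharding. "
            "Prefer in-place ops like .copy_()/.fill_() where possible."
        )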
def init_weights(self):
    self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
    with torch.no_grad():
        self.linear.bias.fill_(98.6)
This isn't really relevant to your PR, but it would be good to test that random fill works correctly too. (If you directly call normal_(), IIRC it's not equal to the eager version, so maybe just check something like whether all the RNGs actually produce the same thing?)
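A hedged sketch of the kind of check being suggested (it assumes you can build the eager and parallel models with the same seed and that parameters iterate in the same order; whether DTensor's RNG matches eager exactly is precisely the open question here):

import torch

def check_random_init_matches(eager_model, parallel_model) -> None:
    # Seed both paths identically before running init_weights.
    torch.manual_seed(0)
    eager_model.init_weights()
    torch.manual_seed(0)
    parallel_model.init_weights()

    for (name, p_eager), p_par in zip(
        eager_model.named_parameters(), parallel_model.parameters()
    ):
        # full_tensor() gathers a sharded DTensor so it can be compared
        # against the single-device reference.
        full = p_par.full_tensor() if hasattr(p_par, "full_tensor") else p_par
        torch.testing.assert_close(full, p_eager, msg=f"mismatch in {name}")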
ezyang left a comment:
LGTM
fmassa left a comment:
LGTM, thanks!
init_weights is a method a user module could supply, which initializes all the parameters and buffers on the module. Currently, we handle init_weights by reparametrization:
- we replace the original module's state_dict with a version containing parallel_module's states
- then we run init_weights, mutating these states

But if init_weights does something like self.buf = _init_buf() instead of something like self.buf.copy_(_init_buf()), we fail to capture this update. This PR attempts to find these missing updates and then copy them back to the parallel_module's states, by:
1) assuming that if init_weights did an assignment, it would not create a DTensor, because init_weights and the orig module are supposed to be written in 'single gpu' style;
2) finding any non-DTensors in the updated state_dict and converting them to new Replicate() DTensors, following the semantic that the newly assigned value represents the global value for the state;
3) copy_-ing into the original state DTensor on the parallel_module, since this handles converting Replicate() to Shard() if needed.

TODO:
- verify this fixes the current init correctness problem with llama
- support params (currently only buffers are implemented)
- support nested names (a.b.c); currently only flat names work
- see if there is a better way to detect the assignment (e.g. #1 above)

Make hooked setter work for initializing params/buffers

ghstack-source-id: 1c0670b
Pull Request resolved: #66
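The copy-back step described above reduces to something like this sketch (buffer-only and flat names, matching the stated TODOs; the helper and its placement in the flow are illustrative):

import torch
from torch.distributed.tensor import DTensor, Replicate

def copy_back_assigned_states(orig_module, parallel_module) -> None:
    for name, value in orig_module.state_dict().items():
        if isinstance(value, DTensor):
            # Still a DTensor: the reparametrized state was either untouched
            # or updated in place, so nothing to copy back.
            continue
        # A plain tensor here means init_weights re-assigned the attribute
        # in single-gpu style; treat it as the global value.
        target = parallel_module.get_buffer(name)
        replicated = DTensor.from_local(
            value, device_mesh=target.device_mesh, placements=[Replicate()]
        )
        with torch.no_grad():
            # copy_ handles Replicate() -> Shard() redistribution if needed.
            target.copy_(replicated)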
init_weights is a method a user module could supply, which initializes all the parameters and buffers on the module.
Currently, we handle init_weights by reparametrization:
- we replace the original module's state_dict with a version containing parallel_module's states
- then we run init_weights, mutating these states

But if init_weights does something like self.buf = _init_buf() instead of something like self.buf.copy_(_init_buf()), we fail to capture this update. This PR interposes the init_weights function differently.
Verification run (tbm):
- FSDP_eager: torchtitan-64-whc-p3s1bn
- autop_initweights_eager: torchtitan-64-whc-qthbz6
- autop_initweights_eager_rerun: torchtitan-64-whc-d2bddf