
Conversation

@wconstab (Contributor) commented Aug 7, 2025

TODO:

  • test this
  • should we do something upstream to provide a nicer API for handling all of these updates in one shot?

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Aug 7, 2025
Comment on lines +225 to +236
local_args = []
for placeholder, arg in zip(gm.graph.nodes, args):
assert placeholder.meta["val"].shape == arg.shape
local_arg = arg.to_local()
placeholder.meta["val"] = local_arg
# requires_grad is missing from val and local_arg, take it
# from original tensor_meta
requires_grad = placeholder.meta["tensor_meta"].requires_grad
placeholder.meta["tensor_meta"] = _extract_tensor_metadata(
local_arg.clone().requires_grad_(requires_grad)
)
local_args.append(local_arg)
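The loop above zips gm.graph.nodes directly against args, which relies on torch.fx's invariant that placeholder nodes come first in graph order. A more defensive variant would filter for placeholders explicitly and check the count. The sketch below models that pattern without a torch dependency; the `Node` class and `pair_placeholders` helper are made up for illustration and are not part of the PR:

```python
# Dependency-free sketch: models zipping only "placeholder" nodes against
# the runtime args, instead of relying on placeholders preceding every
# other node in graph iteration order.
class Node:
    def __init__(self, op, meta=None):
        self.op = op
        self.meta = meta or {}

def pair_placeholders(nodes, args):
    # Filter explicitly rather than trusting graph ordering.
    placeholders = [n for n in nodes if n.op == "placeholder"]
    assert len(placeholders) == len(args), "arg count must match placeholder count"
    return list(zip(placeholders, args))

nodes = [Node("placeholder"), Node("placeholder"), Node("call_function")]
pairs = pair_placeholders(nodes, ["a", "b"])
print(len(pairs))  # → 2
```

The explicit filter also fails loudly if the caller passes the wrong number of args, instead of silently truncating via zip.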
Contributor


I think this is wrong; we would need to do this after we construct the parameters, otherwise we end up sharding the model twice.

Contributor Author


You are correct. I confirmed that after running make_fx on the sharding interpreter, the new parallel_gm0 already has its metas updated. I am not sure why I (thought I) had to do this update in hack_aot.
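For context, the point being confirmed is that retracing with the local tensors rebuilds placeholder metas from the inputs actually seen during tracing, so patching them beforehand is redundant. A dependency-free sketch of that behavior (the `retrace` helper and the tiny classes are made up to model what make_fx does, not the PR's actual code):

```python
# Stand-in model of "retracing updates the metas": tracing with concrete
# example inputs records each input's shape on its placeholder, so there is
# no need to patch meta by hand afterwards.
class Placeholder:
    def __init__(self, shape):
        self.meta = {"val_shape": shape}  # analogous to meta["val"].shape

def retrace(example_inputs):
    # make_fx-like behavior, modeled: placeholders are rebuilt from the
    # inputs actually used during tracing.
    return [Placeholder(x.shape) for x in example_inputs]

class FakeLocalTensor:
    def __init__(self, shape):
        self.shape = shape

local_args = [FakeLocalTensor((4, 8)), FakeLocalTensor((4,))]
graph = retrace(local_args)
print([p.meta["val_shape"] for p in graph])  # → [(4, 8), (4,)]
```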

@wconstab (Contributor, Author) commented Aug 7, 2025

Abandoned because I had to move these changes into the compile PR (#77) so they could land together.

@wconstab closed this Aug 7, 2025