
Conversation

@fmassa (Contributor) commented Jun 12, 2025

This PR reformulates the optimization that we are performing in AutoParallel to model the single-gpu runtime after sharding. This is in contrast to modelling only the communication cost.

This change brings us closer to what we effectively want: shard the model in a way that minimizes the time the user sees.

Some noteworthy things to keep in mind:

  • torch._inductor.utils.get_device_tflops seems to be giving weird TFlops results, so I'm not using it
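
For orientation, the per-node cost the solver now weighs looks roughly like the sketch below; `estimate_strategy_runtime_cost` is the function added in this PR, while the wrapper function name and loop structure here are illustrative, not the actual solver code.

```python
# Minimal sketch of the reformulated objective: each candidate sharding
# strategy is scored by estimated single-GPU compute time plus the cost of
# redistributing its inputs, instead of by communication cost alone.
def strategy_cost(node, strategy, chosen_input_strategies):
    compute = estimate_strategy_runtime_cost(node, strategy)
    comms = sum(
        strategy.redistribute_cost[arg_idx][input_choice]
        for arg_idx, input_choice in enumerate(chosen_input_strategies)
    )
    return compute + comms  # the per-GPU runtime the solver minimizes
```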

fmassa added 2 commits June 12, 2025 09:45
This brings us closer to what we effectively want -- minimize the runtime (compute + comms) per GPU, instead of minimizing the comms only
@fmassa fmassa requested review from bdhirsh and wconstab June 12, 2025 10:59
@facebook-github-bot added the CLA Signed label Jun 12, 2025
}
for ss, ssi in enumerate(s.strategies):
    compute_cost = estimate_strategy_runtime_cost(node, ssi)
    for argi, xxi in enumerate(ssi.redistribute_cost):
Contributor

more for my learning, but it would be nice to document the variable names for these loops a bit better.

Contributor Author

Oh yeah, definitely! I had started some cleanup a few weeks ago, but this part is still in the same state as when I initially started this implementation, so the naming is definitely bad
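
For readers of this thread, a hedged reading of the loop excerpt above with the index meanings spelled out (inferred from context; the renamed variables are illustrative, not a later cleanup):

```python
# op_strategy corresponds to `s` in the diff; each entry of .strategies is one
# candidate sharding strategy for this node.
for strategy_idx, strategy in enumerate(op_strategy.strategies):
    compute_cost = estimate_strategy_runtime_cost(node, strategy)
    # redistribute_cost[arg_idx][k] is the comm cost of converting input
    # `arg_idx` from its k-th candidate placement to what `strategy` expects.
    for arg_idx, per_input_costs in enumerate(strategy.redistribute_cost):
        ...
```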

@wconstab (Contributor) left a comment

stamp to unblock, not a careful review

@fmassa fmassa merged commit 0e9b716 into main Jun 13, 2025
1 check passed
@fmassa fmassa deleted the fmassa/compute_model branch June 13, 2025 09:01
if node.op != "call_function":
    return 0
# suppose only matmul-like ops
if not isinstance(node.target, torch._ops.OpOverload):
Contributor

just thinking of issues that could bite us later - we have had at least one instance of OpOverloadPackets sneaking into the FX graph. Maybe worth an `assert not isinstance(node.target, torch._ops.OpOverloadPacket)` for sanity, so we get a hard crash if that happens and not a bad optimization decision?
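
A minimal sketch of the suggested guard (the helper name and assert message are illustrative):

```python
import torch

def _assert_op_overload(node):
    # Fail loudly if an OpOverloadPacket sneaks into the FX graph, instead of
    # silently treating the node as "not matmul-like" and costing it at 0.
    assert not isinstance(node.target, torch._ops.OpOverloadPacket), (
        f"expected an OpOverload target, got OpOverloadPacket: {node.target}"
    )
```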


args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)
fake_mode = next(arg.fake_mode for arg in args if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor))
Contributor

None of these comments are review-blocking, but just to call out - this won't work for factory functions. We should make sure that we can rely on there being a global FakeTensorMode active inside auto-parallel.

You can use this helper to get that for free (it will automatically work on all pytrees of tensors and check for a global fake mode):

mode = torch._guards.detect_fake_mode(args)
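
A hedged sketch of how that helper could replace the `next(...)` scan above (the wrapper and its assert message are illustrative):

```python
import torch

def _find_fake_mode(args, kwargs):
    # detect_fake_mode walks arbitrary pytrees of tensors and also falls back
    # to a globally active FakeTensorMode, which covers factory functions that
    # take no tensor inputs.
    mode = torch._guards.detect_fake_mode((args, kwargs))
    assert mode is not None, "expected an active or detectable FakeTensorMode"
    return mode
```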

torch.float64: 67,
# NOTE: NVIDIA gives all numbers "with 2:4 sparsity"
# but we want the full GEMM numbers
torch.float32: 989 // 2,
@bdhirsh (Contributor) commented Jun 16, 2025

also thinking out loud, will torch.backends.cuda.matmul.allow_tf32 bite us here? (we probably want to know if the user is planning on enabling tf32 at runtime during our estimation, if there are any float32 matmuls in the model)

The inductor version that you pointed out as buggy looks like it has handling for tf32 (link), but maybe we should add logic to error out if the tf32 flag is different at compile time vs runtime.

Contributor Author

Yeah, for fp32 / tf32 matmuls this estimate will be off. This is part of the theme of "how good should our estimates be", and I think we should at least improve those further in inductor
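
One hedged way to make the estimate aware of the flag at estimation time (the `"tf32"` table entry is an illustrative assumption, not something this PR defines):

```python
import torch

def effective_fp32_tflops(tflops_by_dtype):
    # If TF32 matmuls are enabled when we estimate, fp32 GEMMs run on tensor
    # cores and the dense fp32 number badly underestimates throughput.
    if torch.backends.cuda.matmul.allow_tf32 and "tf32" in tflops_by_dtype:
        return tflops_by_dtype["tf32"]
    return tflops_by_dtype[torch.float32]
```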

return new_tensor_shape


def estimate_strategy_runtime_cost(node, strategy):
@bdhirsh (Contributor) commented Jun 16, 2025

also, more of a directional question - do you think that over time the goal would be to make the runtime estimation here more and more fine-grained/accurate? Or do you envision that the goal is more for this estimation code to be simple and "good enough", and that users instead rely more heavily on taking the output of auto-parallel and customizing it for their needs?

I guess I don't have a very good sense of how accurate it needs to be to count as "good enough".

@bdhirsh (Contributor) commented Jun 16, 2025

Just to give an example - Looking at the comms estimation, I was reading through the redistribute cost estimation code, and I see some comments about how it makes a few assumptions:

(1) only considering redistribute costs on the same mesh https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py#L329 (I don't actually understand how DeviceMesh is used today well enough to know if "cross-device-mesh" usage is important)

(2) assuming all collectives use ring-based algo https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py#L256 (this feels very related to the NCCL estimator here? pytorch/pytorch#149343)

And then there are harder ones like: how important is it to accurately model the runtime of every op in the graph (taking into account inductor's fusion decisions)
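
For context on (2), the ring assumption amounts to a standard alpha-beta cost model, roughly as sketched below (the bandwidth and latency values are illustrative placeholders, not measured numbers):

```python
def ring_allgather_seconds(output_bytes, world_size,
                           link_bandwidth_bytes_per_s=150e9, latency_s=10e-6):
    # A ring all-gather takes (world_size - 1) steps, each moving 1/world_size
    # of the output tensor over the slowest link.
    steps = world_size - 1
    bytes_per_step = output_bytes / world_size
    return steps * (latency_s + bytes_per_step / link_bandwidth_bytes_per_s)
```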

Contributor

Also (sorry, just dumping more thoughts here...) - I'm sure you've thought about this and I'm just thinking out loud, but right now auto-parallel knows nothing about what the partitioner will do, which is interesting because on some level both are trying to reduce peak memory.

Do you think this will eventually be important to look into more? One option could be something like: if we are willing to iteratively run both, we could run auto-AC, get back a list of nodes that auto-AC wants to recompute, and then tell the solver to 2x the "cost" of every redistribute that we know we will recompute later (although this won't really be fully accurate since changing what the solver outputs will again change what auto-AC sees, so you might need to run this multiple times?)

Contributor Author

Those are all good questions, and I think it also links to how important it is for inductor to have good runtime / memory estimation.

I think that the runtime estimation being of the same order of magnitude as what we expect is a must, but if it's off by a factor of 2 I believe it should be fine.

The high-level ask here is that even for reorder_for_compute_comms_overlap, we will need to have reasonable estimates for the runtime and the memory (which inductor currently doesn't have), so I think the effort for improving those estimates will benefit more than just autoparallel.

> Just to give an example - Looking at the comms estimation, I was reading through the redistribute cost estimation code, and I see some comments about how it makes a few assumptions

Yes, comms estimation is also not ideal and should be improved (maybe by using NCCL's runtime estimation that was recently added to PyTorch? pytorch/pytorch#149343 )

> And then there are harder ones like: how important is it to accurately model the runtime of every op in the graph (taking into account inductor's fusion decisions)

I think for the majority of the "big models", just modelling the compute-heavy ops (that use TensorCores) might be good enough for what we need. But I think there is value in having a good runtime estimate that can be used across inductor, e.g., to assess if the generated kernels deviate too much from the idealized runtime (indicating room for improvement)

> auto-parallel knows nothing about what the partitioner will do

Yes, that is true, and when I say that we have FSDP that gets automatically figured out in AutoParallel, this is assuming that the partitioner will decide to recompute the relevant all-gather nodes in the backward.

The work we have been doing in the SimpleFSDP workstream is indeed bringing us there, so the two will combine well together.

Formulating everything jointly would ultimately be best indeed, but I'm not sure how feasible that is and I don't currently have ideas on how to do it.

I think the points you brought up are pretty good; we should capture those in a document.
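
As a point of reference for the "model only the compute-heavy ops" idea above, a coarse GEMM estimate can be as simple as the sketch below (the peak-TFLOPs argument and the efficiency factor are illustrative assumptions):

```python
def gemm_time_seconds(m, n, k, peak_tflops, efficiency=0.7):
    # 2*M*N*K flops for a matmul, scaled by an assumed achievable fraction of
    # peak tensor-core throughput. Good enough for order-of-magnitude ranking
    # of sharding strategies, not for precise runtime prediction.
    flops = 2 * m * n * k
    return flops / (peak_tflops * 1e12 * efficiency)
```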

Contributor

> Yes, that is true, and when I say that we have FSDP that gets automatically figured out in AutoParallel, this is assuming that the partitioner will decide to recompute the relevant all-gather nodes in the backward.

How does this actually work out today? I guess naively, I would imagine that auto-parallel would not assume that anything is recomputed, and so it might end up thinking that FSDP-style sharding and allgathering of weights would not be a useful sharding scheme to return. Are you explicitly telling auto-parallel which collectives are expected to be recomputed somewhere?

Contributor

Oh, thinking about this more - IIRC, the optimization problem tries to minimize e2e runtime subject to “sum of param sizes < MAX_SIZE”. So is the answer here that the solver will automatically recover a version of FSDP with no recompute of the allgathers?

Contributor Author

> Oh, thinking about this more - IIRC, the optimization problem tries to minimize e2e runtime subject to “sum of param sizes < MAX_SIZE”. So is the answer here that the solver will automatically recover a version of FSDP with no recompute of the allgathers?

Yes, that's correct. The solver will just act on the joint graph, where there is no notion of recomputation of the all-gather yet. So it will give us the joint graph that you would obtain if you were doing things with SimpleFSDP, but you just need to tag the nodes that need to be recomputed somewhere (which we have in SimpleFSDP).

@fmassa (Contributor Author) commented Jun 17, 2025

@bdhirsh I've addressed your comments in #9

wconstab added a commit that referenced this pull request Jul 26, 2025
init_weights is a method a user module could supply, which initializes
all the parameters and buffers on the module.

Currently, we handle init_weights by reparametrization:
 - we replace the original module's state_dict with a version containing
   parallel_module's states
 - then we run init_weights, mutating these states

But if init_weights does something like `self.buf = _init_buf()` instead
of doing something like `self.buf.copy_(_init_buf())`, we fail to
capture this update.

This PR attempts to find these missing updates and then copy them back
to the parallel_module's states.
1) assuming that if init_weights did an assignment, it would not create
   a DTensor, because init_weights and orig module are supposed to be
   written in 'single gpu' style.
2) finding any non-DTensors in the updated state_dict and converting
   them to new Replicate() DTensors, following the semantic that the
   new assigned value should represent the global value for the state
3) copy_ into the original state DTensor on the parallel_module, since
   this handles the case of converting Replicate() to Shard() if
   needed.

TODO:
- verify this fixes the current init correctness problem with llama
- support params (currently only buffers are implemented)
- support nested names (a.b.c), currently only flat names work
- see if there is a better way to detect the assignment (e.g. #1 above)
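
A hedged sketch of steps (2) and (3) from the commit message above (the helper name, the buffer-only scope, and the use of `get_buffer` are illustrative, not this PR's code):

```python
import torch
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

def copy_back_assigned_buffers(parallel_module, updated_state, device_mesh):
    for name, value in updated_state.items():
        if isinstance(value, DTensor):
            continue  # per (1): plain tensors indicate an assignment in init_weights
        # per (2): treat the newly assigned plain tensor as the global value
        replicated = distribute_tensor(value, device_mesh, placements=[Replicate()])
        with torch.no_grad():
            # per (3): copy_ handles Replicate() -> Shard() conversion if the
            # original buffer on parallel_module is sharded
            parallel_module.get_buffer(name).copy_(replicated)
```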
wconstab added a commit that referenced this pull request Jul 26, 2025
wconstab added a commit that referenced this pull request Jul 31, 2025
wconstab added a commit that referenced this pull request Aug 1, 2025
Make hooked setter work for initializing params/buffers

wconstab added a commit that referenced this pull request Aug 1, 2025
Pull Request resolved: #66
wconstab added a commit that referenced this pull request Aug 1, 2025
wconstab added a commit that referenced this pull request Aug 1, 2025
wconstab added a commit that referenced this pull request Aug 1, 2025
wconstab added a commit that referenced this pull request Aug 3, 2025