Add compute cost in optimization problem #1
Conversation
This brings us closer to what we effectively want -- minimizing the runtime (compute + comms) per GPU, instead of minimizing the comms only.
}
for ss, ssi in enumerate(s.strategies):
    compute_cost = estimate_strategy_runtime_cost(node, ssi)
    for argi, xxi in enumerate(ssi.redistribute_cost):
more for my learning, but it would be nice to document the variable names for these loops a bit better.
Oh yeah, definitely! I had started some cleanup a few weeks ago, but this part is still in the same state as when I initially started this implementation, so the naming is definitely bad
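For illustration only, a sketch of what more self-describing names for that loop might look like (the meaning of each index is guessed from the snippet above, not taken from the actual code):

# Hypothetical renaming of the quoted loop; the structure is assumed, not verified.
for strategy_idx, strategy in enumerate(s.strategies):
    compute_cost = estimate_strategy_runtime_cost(node, strategy)
    # assumed: redistribute_cost[arg_idx][j] = cost of redistributing argument
    # arg_idx from its j-th candidate placement into the one this strategy needs
    for arg_idx, arg_redistribute_costs in enumerate(strategy.redistribute_cost):
        ...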
wconstab left a comment
stamp to unblock, not a careful review
if node.op != "call_function":
    return 0
# suppose only matmul-like ops
if not isinstance(node.target, torch._ops.OpOverload):
just thinking of issues that could bite us later - we have had at least one instance of OpOverloadPackets sneaking into the FX graph. Maybe worth an `assert not isinstance(node.target, torch._ops.OpOverloadPacket)` for sanity, so we get a hard crash if that happens and not a bad optimization decision?
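A minimal sketch of that check, placed next to the OpOverload test quoted above (the placement and the error message are assumptions, not the actual PR code):

# Sketch: fail loudly on OpOverloadPacket targets instead of silently
# treating them as zero-cost ops.
assert not isinstance(node.target, torch._ops.OpOverloadPacket), (
    f"expected a resolved OpOverload, got an OpOverloadPacket: {node.target}"
)
if not isinstance(node.target, torch._ops.OpOverload):
    return 0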
args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)
fake_mode = next(arg.fake_mode for arg in args if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor))
none of these comments are review-blocking, but just to call out - this won't work for factory functions. We should make sure that we can rely on there being a global FakeTensorMode active inside auto-parallel.
You can use this helper to get that for free (it will automatically work on all pytrees of tensors and check for a global fake mode):
mode = torch._guards.detect_fake_mode(args)
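For illustration, a sketch of how that helper could replace the next(...) scan in the diff above (the trailing assert is an assumption about how a missing fake mode should be handled):

from torch._guards import detect_fake_mode

# detect_fake_mode scans any pytree of inputs for FakeTensors and also
# checks for a globally active FakeTensorMode (covering factory functions).
fake_mode = detect_fake_mode((args, kwargs))
assert fake_mode is not None, "expected an active FakeTensorMode inside auto-parallel"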
torch.float64: 67,
# NOTE: NVIDIA gives all numbers "with 2:4 sparsity"
# but we want the full GEMM numbers
torch.float32: 989 // 2,
also thinking out loud, will torch.backends.cuda.matmul.allow_tf32 bite us here? (we probably want to know if the user is planning on enabling tf32 at runtime during our estimation, if there are any float32 matmuls in the model)
the inductor version that you pointed out as buggy looks like it has handling for tf32 (link), but maybe we should add logic to error out if the tf32 flag is different at compile time vs runtime
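A rough sketch of such a guard, assuming we snapshot the flag when the costs are estimated and check it again at runtime (the function and variable names here are hypothetical):

import torch

# Sketch: record the TF32 setting used during cost estimation and fail if it
# changes before the model is actually run.
tf32_at_estimation_time = torch.backends.cuda.matmul.allow_tf32

def check_tf32_unchanged():
    if torch.backends.cuda.matmul.allow_tf32 != tf32_at_estimation_time:
        raise RuntimeError(
            "torch.backends.cuda.matmul.allow_tf32 changed between cost "
            "estimation and runtime; fp32 matmul estimates may be invalid"
        )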
Yeah, for fp32 / tf32 matmuls this estimate will be off. This is part of the theme of "how good should our estimates be", and I think we should at least improve those further in inductor
return new_tensor_shape


def estimate_strategy_runtime_cost(node, strategy):
also, more of a directional question - do you think that over time the goal would be to make the runtime estimation here more and more fine-grained/accurate? Or do you envision that the goal is more for this estimation code to be simple and "good enough", and that users will instead rely more heavily on taking the output of auto-parallel and customizing it for their needs?
I guess I don't have a very good sense of "how accurate does it need to be 'good enough'?"
Just to give an example - Looking at the comms estimation, I was reading through the redistribute cost estimation code, and I see some comments about how it makes a few assumptions:
(1) only considering redistribute costs on the same mesh https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py#L329 (I don't actually understand how DeviceMesh is used today well enough to know if "cross-device-mesh" usage is important)
(2) assuming all collectives use ring-based algo https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py#L256 (this feels very related to the NCCL estimator here? pytorch/pytorch#149343)
And then there are harder ones like: how important is it to accurately model the runtime of every op in the graph (taking into account inductor's fusion decisions)
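On point (2), a rough sketch of the bandwidth-only cost that a ring-based assumption corresponds to (the helper and its arguments are illustrative, not the actual DTensor code):

# Ring all-reduce sends roughly 2 * (n - 1) / n of the tensor bytes per rank,
# so a latency-free estimate is just bytes-on-wire divided by bus bandwidth.
def ring_allreduce_time_s(size_bytes: int, num_ranks: int, bus_gb_per_s: float) -> float:
    bytes_on_wire = 2 * size_bytes * (num_ranks - 1) / num_ranks
    return bytes_on_wire / (bus_gb_per_s * 1e9)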
Also (sorry, just dumping more thoughts here....) - I'm sure you've thought about this and I'm just thinking out loud, but right now auto-parallel knows nothing about what the partitioner will do, which is interesting because on some level both are trying to reduce peak memory.
Do you think this will eventually be important to look into more? One option could be something like: if we are willing to iteratively run both, we could run auto-AC, get back a list of nodes that auto-AC wants to recompute, and then tell the solver to 2x the "cost" of every redistribute that we know we will recompute later (although this won't really be fully accurate since changing what the solver outputs will again change what auto-AC sees, so you might need to run this multiple times?)
Those are all good questions, and I think it also links to how important it is for inductor to have good runtime / memory estimation.
I think the runtime estimation being within the same order of magnitude as what we expect is a must, but if it's off by a factor of 2 I believe it should be fine.
The high-level ask here is that even for reorder_for_compute_comms_overlap, we will need to have reasonable estimates for the runtime and the memory (which inductor currently doesn't have), so I think the effort for improving those estimates will benefit more than just autoparallel.
Just to give an example - Looking at the comms estimation, I was reading through the redistribute cost estimation code, and I see some comments about how it makes a few assumptions
Yes, comms estimation is also not ideal and should be improved (maybe by using NCCL's runtime estimation that was recently added to PyTorch? pytorch/pytorch#149343 )
And then there are harder ones like: how important is it to accurately model the runtime of every op in the graph (taking into account inductor's fusion decisions)
I think for the majority of the "big models", just modelling the compute-heavy ops (that use TensorCores) might be good enough for what we need. But I think there is value in having a good runtime estimate that can be used across inductor, e.g., to assess if the generated kernels deviate too much from the idealized runtime (indicating room for improvement)
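As an illustration of what "just modelling the compute-heavy ops" could mean, a matmul-only estimate can be as simple as the sketch below (the peak-TFLOPS value would come from a table like the one quoted earlier; the names and numbers are hypothetical):

# Sketch: time for a single (m, k) @ (k, n) matmul from a flops / peak ratio.
def matmul_time_ms(m: int, n: int, k: int, peak_tflops: float) -> float:
    flops = 2 * m * n * k  # each multiply-accumulate counts as 2 flops
    return flops / (peak_tflops * 1e12) * 1e3

# e.g. an 8192 x 8192 x 8192 matmul at a hypothetical 989 TFLOPS peak
print(matmul_time_ms(8192, 8192, 8192, 989))  # ~1.1 ms (idealized)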
auto-parallel knows nothing about what the partitioner will do
Yes, that is true. When I say that FSDP gets automatically figured out in AutoParallel, this assumes that the partitioner will decide to recompute the relevant all-gather nodes in the backward.
The work we have been doing in the SimpleFSDP workstream is indeed bringing us there, so the two will combine well together.
Formulating everything jointly would ultimately be best indeed, but I'm not sure how feasible that is and I don't currently have ideas on how to do it.
The points you brought up are all good ones; I think we should capture them in a document.
Yes, that is true. When I say that FSDP gets automatically figured out in AutoParallel, this assumes that the partitioner will decide to recompute the relevant all-gather nodes in the backward.
How does this actually work out today? I guess naively, I would imagine that auto-parallel would not assume that anything is recomputed, and so it might end up thinking that FSDP-style sharding and allgathering of weights would not be a useful sharding scheme to return. Are you explicitly telling auto-parallel which collectives are expected to be recomputed somewhere?
Oh, thinking about this more - IIRC, the optimization problem tries to minimize e2e runtime subject to “sum of param sizes < MAX_SIZE”. So is the answer here that the solver will automatically recover a version of FSDP with no recompute of the allgathers?
Oh, thinking about this more - IIRC, the optimization problem tries to minimize e2e runtime subject to “sum of param sizes < MAX_SIZE”. So is the answer here that the solver will automatically recover a version of FSDP with no recompute of the allgathers?
Yes, that's correct. The solver will just act on the joint graph, where there is no notion of recomputation of the all-gather yet. So it will give us the joint graph that you would obtain if you were doing things with SimpleFSDP, but you just need to tag the nodes that need to be recomputed somewhere (which we have in SimpleFSDP).
init_weights is a method a user module could supply, which initializes all the parameters and buffers on the module. Currently, we handle init_weights by reparametrization:
- we replace the original module's state_dict with a version containing parallel_module's states
- then we run init_weights, mutating these states

But if init_weights does something like `self.buf = _init_buf()` instead of something like `self.buf.copy_(_init_buf())`, we fail to capture this update. This PR attempts to find these missing updates and then copy them back to the parallel_module's states:
1) assuming that if init_weights did an assignment, it would not create a DTensor, because init_weights and the orig module are supposed to be written in 'single gpu' style
2) finding any non-DTensors in the updated state_dict and converting them to new Replicate() DTensors, following the semantic that the newly assigned value should represent the global value for the state
3) copy_ into the original state DTensor on the parallel_module, since this handles the case of converting Replicate() to Shard() if needed

TODO:
- verify this fixes the current init correctness problem with llama
- support params (currently only buffers are implemented)
- support nested names (a.b.c); currently only flat names work
- see if there is a better way to detect the assignment (e.g. #1 above)

Make hooked setter work for initializing params/buffers
ghstack-source-id: 1c0670b
Pull Request resolved: #66
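A minimal sketch of steps (2) and (3), assuming the device mesh and the parallel module's state_dict are available (the helper name and its arguments are hypothetical):

from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

# Sketch: push plain-tensor assignments made by init_weights back into the
# parallel module's DTensor states.
def copy_back_assigned_states(updated_state, parallel_state, mesh):
    for name, value in updated_state.items():
        if isinstance(value, DTensor):
            continue  # already captured by the reparametrization
        # The assigned value is treated as the global value for this state (2).
        replicated = distribute_tensor(value, mesh, [Replicate()])
        # copy_ handles converting Replicate() to Shard() if needed (3).
        parallel_state[name].copy_(replicated)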
This PR reformulates the optimization that we are performing in AutoParallel to model the single-gpu runtime after sharding. This is in contrast to modelling only the communication cost.
This change brings us closer to what we effectively want: shard the model in a way that minimizes the time the user sees.
Some noteworthy things to keep in mind:
torch._inductor.utils.get_device_tflops seems to be giving weird TFLOPS results, so I'm not using it
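A stylized, self-contained illustration of the objective change (all names and numbers below are made up; this is not the actual solver code):

# Per-(node, strategy) costs in ms; previously only comms_cost entered the
# objective, now compute_cost is added so we minimize per-GPU runtime.
compute_cost = {("mm_1", 0): 2.0, ("mm_1", 1): 1.0}
comms_cost = {("mm_1", 0): 0.0, ("mm_1", 1): 0.5}

def objective(choice):
    # choice maps each node to the index of its chosen sharding strategy
    return sum(compute_cost[(n, s)] + comms_cost[(n, s)] for n, s in choice.items())

print(objective({"mm_1": 0}), objective({"mm_1": 1}))  # 2.0 1.5 -> strategy 1 wins once compute is included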