Conversation

@wconstab (Contributor) commented Aug 7, 2025

Work in progress: copied the latest code over from whc/hack_aot and tweaked the way it gets hooked up a bit; haven't tested yet. We likely need to discuss whether we want the AC pass to be popped back off Inductor's passes earlier, or kept until __exit__ from AutoParallel.

This is now rebased on whc/compile (#77) - let's land that first and then focus on this.

I kicked off a job with both compile+AC enabled to verify memory looks good e2e:
https://www.internalfb.com/mlhub/pipelines/runs/mast/torchtitan-64-whc-k3rn2fc

tbm FSDP_eager:torchtitan-64-whc-p3s1bn compile_pr:torchtitan-64-whc-kf1llhnr compile_noac_from_update_post:torchtitan-64-fmassa-r4rvfnf6 ac_pr:torchtitan-64-whc-k3rn2fc compile_ac_from_update_post:torchtitan-64-fmassa-rg4nxx

Hmm, seems like we are still missing some secret sauce relative to the runs from the update post:
this PR: 27GB
the equivalent config from the update post: 21GB

@meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Aug 7, 2025.
@fmassa (Contributor) left a comment

Thanks!

I have some questions / comments. I think it might be good to discuss how we want to expose this (I'd prefer to keep the constructor of AutoParallel reasonably small, and I was already not very happy about passing mp_policy in there).

input_fn,
mesh: DeviceMesh,
mp_policy: Optional[MixedPrecisionPolicy] = None,
enable_ac: bool = True,
Contributor:

I think I wouldn't pass this as a boolean flag in the constructor. We might need to add other AC tags as well (like saving some matmuls, which is what I did for Llama3 8B), so I'd maybe keep this outside of the constructor somehow.

Contributor Author (@wconstab):

So, I thought a bit about this, and I don't have a strong opinion about where the knob goes, but I think we need the knob, at least for debugging purposes. I'd also prefer it not to be something hacky like an env var or global state, because those can be harder to control when submitting mast CLI jobs. It's pretty nice to be able to wire up all the things we care about to the torchtitan CLI.

Would you feel better about making a method configure_ac(enable: bool = True, size: Optional[float] = None, ...) that users can call on the autop object?
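
For concreteness, a minimal sketch of what such a method could look like (illustrative only; the method name, attribute names, and the idea of stashing settings on the object are assumptions, not the API that was actually adopted):

```python
from typing import Optional

class AutoParallel:  # illustrative stub, not the real class
    def configure_ac(self, enable: bool = True, size: Optional[float] = None) -> None:
        """Opt in to the AC pass after construction instead of via a constructor flag."""
        # Stash the settings; the AC joint-graph pass would read them later
        # (attribute names here are assumptions).
        self._ac_enabled = enable
        self._ac_stage_size_in_GiB = size
```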

Comment on lines +46 to +57
# TODO: this needs to be improved, its firing in autoparallel "2D" case where input to AG is wait,
# maybe just 2D FSDP
# ag_node = node.args[0]
# assert is_graph_input(ag_node.args[0]) or (
# ag_node.args[0].op == "call_function"
# and ag_node.args[0].target == torch.ops.prims.convert_element_type.default
# and is_graph_input(ag_node.args[0].args[0])
# ), (
# "Assume all_gather_into_tensor input is either graph input "
# + f"or dtype conversion of graph input, but got {ag_node.args[0]}"
# )
Contributor:

We probably will want to check if we want to fix this, or just always recompute all-gathers in general?

Contributor Author (@wconstab):

Well, would it ever be profitable to recompute an allgather for e.g. context parallelism? And would it ever be better to not recompute an allgather for FSDP?

If we want to make this all about FSDP for now, I think we could add some tagging to the nodes when we do the my_redistribute_tensor for param nodes, and snapshot the fact that it's for FSDP.
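
A minimal sketch of that tagging idea (the helper names and the meta key are hypothetical; only node.meta itself is standard torch.fx API):

```python
def tag_fsdp_collective(node):
    # Hypothetical: called from the param redistribution path (e.g. inside
    # my_redistribute_tensor) to record that this collective exists purely
    # for FSDP resharding of a parameter.
    node.meta["recompute_reason"] = "fsdp_all_gather"

def is_fsdp_collective(node):
    # The AC pass can then check the tag instead of pattern-matching the graph.
    return node.meta.get("recompute_reason") == "fsdp_all_gather"
```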

@fmassa (Contributor) commented Aug 7, 2025

BTW, we don't need to register a joint pass for this -- we can just call this function once we've added the collectives to the joint graph. This should make some things safer with respect to properly tagging at the right moment.

based by periodically saving intermediate nodes, keeping peak memory usage below.

Args:
    joint_graph: The joint graph containing both forward and backward nodes
    stage_size_in_GiB: Target memory size per stage in GiB (-1 to disable staging)
Contributor:

FYI this is missing the option to disable the AutoAC. Before, -1 would early-return; now it just errors?
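
For reference, a hedged sketch of stage-based saving with a -1 early return, as described in the docstring above (function and variable names are assumptions, not the PR's actual code):

```python
import torch

def pick_stage_boundaries(fwd_recomputable_nodes, stage_size_in_GiB):
    # Sketch: walk the forward nodes in order, accumulate activation sizes,
    # and mark a node as "must save" whenever the running total crosses the
    # per-stage budget, then reset the accumulator.
    if stage_size_in_GiB == -1:
        return []  # staging disabled: nothing is force-saved
    budget = stage_size_in_GiB * 1024**3
    acc, must_save = 0, []
    for node in fwd_recomputable_nodes:
        val = node.meta.get("val")
        if not isinstance(val, torch.Tensor):
            continue  # skip nodes without a single tensor output
        acc += val.numel() * val.element_size()
        if acc >= budget:
            must_save.append(node)
            acc = 0
    return must_save
```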

Comment on lines +285 to +301
save_list = {
    torch.ops.aten.mm.default,
    # torch.ops.aten._scaled_dot_product_efficient_attention.default,
    # torch.ops.aten._scaled_dot_product_flash_attention.default,
}
must_save_nodes = []
counter = 0
for node in fwd_recomputable_nodes:
    if node.target in save_list:
        if node.target == torch.ops.aten.mm.default:
            # save every other mm: odd-numbered ones are skipped via `continue`
            if counter % 2 == 0:
                counter += 1
            else:
                counter += 1
                continue
        must_save_nodes.append(node)
mark_nodes_as_must_save(must_save_nodes)
Contributor:

FYI I think this needs to be moved out of this function, as it was just some Llama-specific tagging and we need to add user control over it.

@fmassa (Contributor) commented Aug 8, 2025

The difference in memory is indeed just the bucket sizes for all-gather / reduce-scatter being off. I've prepared a PR fixing it in pytorch/torchtitan#1545

Here is a run showing the behavior after fixing those bucket sizes:

tbm compile_pr:torchtitan-64-whc-kf1llhnr ac_pr:torchtitan-64-whc-k3rn2fc compile_ac_from_update_post:torchtitan-64-fmassa-rg4nxx compile_ac_bucket_update:torchtitan-64-fmassa-pkfg4q

@fmassa (Contributor) left a comment

LGTM, let's merge this PR now and I'll send a follow-up PR with some minor cleanups.

The run which gives comparable results (after the torchtitan PR) is torchtitan-64-fmassa-pkfg4q.

@fmassa merged commit c65facd into main on Aug 8, 2025 (6 checks passed).
@fmassa deleted the whc/act branch on August 8, 2025 at 09:18.
if is_wait_tensor_from_fsdp(node):
    ag_node = node.args[0]
    force_recompute_node(ag_node)  # all_gather
    force_recompute_node(node)  # wait_tensor
Contributor:

Hmm, if I'm following, right now the logic here kind of needs to mirror the logic in is_wait_tensor_from_fsdp - the first function looks for allgathers that are specifically "FSDP" allgathers (the allgather is done directly on a param, or on the result of a downcast if using AMP), and if that condition is true we try to mark the nodes as must-recompute.

It might be cleaner to just "force recompute every node in the chain between the current wait tensor and its param" here, so that we don't have to worry about this code and the code in is_wait_tensor_from_fsdp going out of sync. WDYT?

Contributor:

Yes, I think that would be a good approach to improve this part. Ideally, though, we would need some heuristic such as requiring that all the downstream users of the parameter up until the wait have only a single input node; otherwise we might end up picking up two-operand operations which are not FSDP-like ops.
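
A hedged sketch of that chain walk under the single-input heuristic (is_graph_input and force_recompute_node are taken from the code under review; the walk itself is an illustration, not the merged implementation):

```python
def force_recompute_fsdp_chain(wait_node, is_graph_input, force_recompute_node):
    # Walk backwards from the wait_tensor towards the parameter, collecting the
    # chain, and bail out if any node has more than one input node (then it no
    # longer looks like a pure FSDP resharding chain).
    chain, node = [], wait_node
    while not is_graph_input(node):
        inputs = node.all_input_nodes  # torch.fx.Node property
        if len(inputs) != 1:
            return  # not a simple param -> (cast) -> all_gather -> wait chain
        chain.append(node)
        node = inputs[0]
    for n in chain:
        force_recompute_node(n)
```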

mesh: DeviceMesh,
mp_policy: Optional[MixedPrecisionPolicy] = None,
compile: bool = False,
enable_ac: bool = True,
@bdhirsh (Contributor) commented Aug 18, 2025

I'm a bit confused about the API design. How are we reconciling these 3 things:

(1) The user might want to use compiler-driven AC. I think that this is effectively what this PR tries to support through the enable_ac flag and the custom graph pass, is that right? What is the relationship between the above graph pass and the autoAC code that lives directly in the partitioner - do we plan to reconcile the two in the future?

(2) The user might be manually applying AC in their single-GPU model code (my understanding is that we actually do support this, independently from this PR).

(3) We generally want to ensure that AGs for FSDP are always recomputed. This PR also does this. One thing I'm not clear about, though: don't we generally always want to recompute FSDP allgathers, and if so, why make this part of the flag at all?

Contributor Author (@wconstab):

I think we need to document it a bit better.

The way it works (IIRC) is roughly this (a sketch follows below the list):

  • FSDP allgathers are always recomputed; the flag doesn't change this
  • "auto": use the stage-size heuristic to determine the stage size for the autoAC in this PR
  • int: use the autoAC in this PR, but don't compute the stage size automatically; use this size
  • None: pass through user annotations (useful if they passed in AC hints)

Note: this comment should apply to what's on main, not what's in this PR; the delta is None vs "auto", which I merged together in this PR but @fmassa later unmerged.
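
A minimal sketch of that dispatch (helper and parameter names are hypothetical; the real code on main may differ):

```python
from typing import Optional, Union

def estimate_stage_size_in_GiB(joint_graph) -> float:
    # Hypothetical placeholder for the stage-size heuristic mentioned above.
    raise NotImplementedError

def resolve_ac_stage_size(
    knob: Optional[Union[str, float]], joint_graph
) -> Optional[float]:
    # Hypothetical dispatch over the knob values described in the comment;
    # FSDP all-gathers are force-recomputed regardless of this value.
    if knob is None:
        return None  # pass through user AC annotations untouched
    if knob == "auto":
        return estimate_stage_size_in_GiB(joint_graph)  # stage-size heuristic
    return float(knob)  # explicit stage size in GiB
```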
