some compile-related improvements #443

Merged (5 commits) on Aug 3, 2024
test_runner.py: 11 additions, 0 deletions
@@ -168,6 +168,17 @@ def build_test_list():
"1D compile",
"1d_compile",
),
OverrideDefinitions(
[
[
"--training.compile",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
],
"1D compile with selective op AC",
"1d_compile_sac_op",
),
OverrideDefinitions(
[
[
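The new "1d_compile_sac_op" entry exercises torch.compile together with selective per-op activation checkpointing in a 1D run. Purely as an illustration of what the overrides amount to (the launcher script name and the subprocess call below are assumptions, not taken from test_runner.py), the test boils down to a training invocation along these lines:

# Hypothetical sketch only; test_runner.py assembles the real command itself.
import subprocess

base_cmd = ["./run_llama_train.sh"]  # assumed launcher script name
overrides = [
    "--training.compile",
    "--activation_checkpoint.mode", "selective",
    "--activation_checkpoint.selective_ac_option", "op",
]
subprocess.run(base_cmd + overrides, check=True)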
torchtitan/parallelisms/parallelize_llama.py: 8 additions, 5 deletions
@@ -441,12 +441,15 @@ def apply_ac(model: nn.Module, ac_config: JobConfig):

def apply_compile(model: nn.Module):
"""Apply torch.compile to each transformer block."""

# the following flag can be used to accelerate per-block compilation
# TODO(bdhirsh): turning it off because it's currently not working with 2D
# TODO(anijain): remove it after it's enabled in pytorch by default
# torch._dynamo.config.inline_inbuilt_nn_modules = True

for layer_id, transformer_block in model.layers.named_children():
# TODO: dynamic shapes have some issues, so we turn them off for now.
# TODO: inline inbuilt nn modules does not work yet; enable it to accelerate
# compile time.
# torch._dynamo.config.inline_inbuilt_nn_modules = True
transformer_block = torch.compile(transformer_block, dynamic=False)
# turn on per-transformer block compile after AC wrapping and before FSDP
transformer_block = torch.compile(transformer_block, fullgraph=True)
Collaborator:
To check my understanding, is it true that fullgraph=True does not change the actual compiled function, only that it errors if it cannot acquire a full graph for this function?

Reply:
yep that's right - if there's a graph break inside the transformer block, fullgraph=True will force a compile-time error instead of allowing the graph break to run
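A minimal, self-contained sketch of that behavior (not part of this PR; the toy function and the forced graph break are purely illustrative):

import torch

def toy_block(x):
    y = x * 2
    torch._dynamo.graph_break()  # deliberately introduce a graph break
    return y + 1

x = torch.randn(4)

# Default: the graph break is tolerated; Dynamo splits the function into
# multiple graphs and stitches them together with eager Python in between.
torch.compile(toy_block)(x)

# fullgraph=True: the same graph break becomes a hard error on the first call.
try:
    torch.compile(toy_block, fullgraph=True)(x)
except Exception as err:
    print(f"fullgraph=True surfaced the graph break: {type(err).__name__}")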

model.layers.register_module(layer_id, transformer_block)

logger.info("Compiled each TransformerBlock with torch.compile")
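On the commented-out torch._dynamo.config.inline_inbuilt_nn_modules flag: the PR deliberately leaves it off because of the 2D-parallelism issue noted in the TODO. If it is later turned on, a minimal sketch of the intended usage would be (illustrative only; assumes a PyTorch build that exposes this flag):

import torch
import torch.nn as nn

# Ask Dynamo to inline built-in nn.Module methods during tracing, which is
# meant to shorten per-block compile times (left disabled in this PR).
torch._dynamo.config.inline_inbuilt_nn_modules = True

block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
compiled_block = torch.compile(block, fullgraph=True)
out = compiled_block(torch.randn(2, 8))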