8 changes: 6 additions & 2 deletions torchtitan/components/ft/diloco/utils.py
@@ -7,6 +7,7 @@
 import torch.nn as nn
 from torchtitan.components.ft.config import FaultTolerance as FTConfig
 from torchtitan.distributed.pipeline_parallel import generate_llm_fqn_per_model_part
+from torchtitan.tools.logging import logger


 def module_split(
@@ -98,7 +99,9 @@ def _build_fragment_from_modules(
             fragment_idx,
             module_names,
         )
-        print(f"building fragment_idx {fragment_idx} " f"with modules {module_names}")
+        logger.info(
+            f"building fragment_idx {fragment_idx} " f"with modules {module_names}"
+        )
         model_fragments.append(model_fragment)

     return model_fragments
@@ -118,13 +121,14 @@ def fragment_llm(

     if module_fqns_per_model_fragment == []:
         if ft_config.num_fragments == 1:
+            logger.info("Created 1 model fragments")
             return [model]

         module_fqns_per_model_fragment = generate_llm_fqn_per_model_part(
             ft_config.num_fragments, n_layers, input_weight, output_weight
         )

     model_fragments = module_split(model, module_fqns_per_model_fragment)
-    print(f"Created {len(model_fragments)} model fragments")
+    logger.info(f"Created {len(model_fragments)} model fragments")

     return model_fragments
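
The change in this file is purely a logging swap: progress messages about fragment construction now go through torchtitan's shared logger instead of print. Below is a minimal, self-contained sketch of that pattern; the logger name, the basicConfig setup, and the build_fragments helper are illustrative stand-ins, not torchtitan's actual implementation.

import logging

logger = logging.getLogger("torchtitan")
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")


def build_fragments(module_names_per_fragment):
    fragments = []
    for fragment_idx, module_names in enumerate(module_names_per_fragment):
        # Previously a bare print(); a logger call respects the job's log level
        # and formatting and can be filtered or redirected by the logging setup.
        logger.info("building fragment_idx %d with modules %s", fragment_idx, module_names)
        fragments.append(module_names)  # placeholder for the real fragment module
    logger.info("Created %d model fragments", len(fragments))
    return fragments


build_fragments([["tok_embeddings", "layers.0"], ["layers.1", "output"]])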
4 changes: 4 additions & 0 deletions torchtitan/components/ft/manager.py
@@ -16,6 +16,7 @@
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.distributed_c10d import ReduceOp
 from torchtitan.components.ft.config import FaultTolerance as FTConfig
+from torchtitan.tools.logging import logger

 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft
@@ -125,6 +126,9 @@ def maybe_semi_sync_training(
     assert (
         ft_manager._manager is not None
     ), "FTManager must be enabled to use semi-sync training."
+    logger.info(
+        f"using fragment function to split model: {fragment_fn is not None}"
+    )
     if semi_sync_method.lower() == "diloco":
         if fragment_fn:
             model_parts = fragment_fn(model, ft_config, n_layers)
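
The new log line records whether a fragment function was supplied. From the call site above, any fragment_fn must accept (model, ft_config, n_layers) and return a list of module fragments. A hypothetical compatible callable is sketched below purely for illustration; its no-op splitting strategy is not torchtitan's fragment_llm.

from typing import List

import torch.nn as nn


def identity_fragment_fn(model: nn.Module, ft_config, n_layers: int) -> List[nn.Module]:
    # Illustrative stand-in: treat the whole model as a single fragment,
    # matching the fragment_fn(model, ft_config, n_layers) call shown above.
    return [model]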