Skip to content

Commit a1fdd7e

Browse files
authored
use logger in ft (#1539)
Summary: - print statements were not being printed, but the same messages do show up when emitted via the logger - also added some logging to validate that the model is being split for DiLoCo
1 parent 36ec547 commit a1fdd7e

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

torchtitan/components/ft/diloco/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.nn as nn
88
from torchtitan.components.ft.config import FaultTolerance as FTConfig
99
from torchtitan.distributed.pipeline_parallel import generate_llm_fqn_per_model_part
10+
from torchtitan.tools.logging import logger
1011

1112

1213
def module_split(
@@ -98,7 +99,9 @@ def _build_fragment_from_modules(
9899
fragment_idx,
99100
module_names,
100101
)
101-
print(f"building fragment_idx {fragment_idx} " f"with modules {module_names}")
102+
logger.info(
103+
f"building fragment_idx {fragment_idx} " f"with modules {module_names}"
104+
)
102105
model_fragments.append(model_fragment)
103106

104107
return model_fragments
@@ -118,13 +121,14 @@ def fragment_llm(
118121

119122
if module_fqns_per_model_fragment == []:
120123
if ft_config.num_fragments == 1:
124+
logger.info("Created 1 model fragments")
121125
return [model]
122126

123127
module_fqns_per_model_fragment = generate_llm_fqn_per_model_part(
124128
ft_config.num_fragments, n_layers, input_weight, output_weight
125129
)
126130

127131
model_fragments = module_split(model, module_fqns_per_model_fragment)
128-
print(f"Created {len(model_fragments)} model fragments")
132+
logger.info(f"Created {len(model_fragments)} model fragments")
129133

130134
return model_fragments

torchtitan/components/ft/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
1717
from torch.distributed.distributed_c10d import ReduceOp
1818
from torchtitan.components.ft.config import FaultTolerance as FTConfig
19+
from torchtitan.tools.logging import logger
1920

2021
if importlib.util.find_spec("torchft") is not None:
2122
import torchft as ft
@@ -125,6 +126,9 @@ def maybe_semi_sync_training(
125126
assert (
126127
ft_manager._manager is not None
127128
), "FTManager must be enabled to use semi-sync training."
129+
logger.info(
130+
f"using fragment function to split model: {fragment_fn is not None}"
131+
)
128132
if semi_sync_method.lower() == "diloco":
129133
if fragment_fn:
130134
model_parts = fragment_fn(model, ft_config, n_layers)

0 commit comments

Comments
 (0)