Commit c0c273c

pass deterministic.fill_uninitialized_memory to HF model
1 parent 9be95da

File tree

3 files changed: +6 additions, -2 deletions

torchtitan/distributed/utils.py

Lines changed: 0 additions & 2 deletions

@@ -111,8 +111,6 @@ def set_determinism(
         )
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
-        # Otherwise, Hugging Face modeling registers a buffer for RoPE (inv_freq) and this will by default be initialized to NaN
-        torch.utils.deterministic.fill_uninitialized_memory = False
         # env var for deterministic CuBLAS
         # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
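For context: when torch.use_deterministic_algorithms(True) is active, torch.utils.deterministic.fill_uninitialized_memory (default True) makes PyTorch fill "uninitialized" tensor memory (e.g. from torch.empty) with NaN for floating-point dtypes. The line removed here had disabled that globally for every run; the commit instead moves the override into the HF backend, gated on the model's deterministic flag. A minimal sketch of the flag's effect (printed values are illustrative):

import torch

# Deterministic mode turns on the fill-uninitialized-memory behavior.
torch.use_deterministic_algorithms(True)

# Default: torch.empty is filled with NaN so stale memory is easy to spot.
torch.utils.deterministic.fill_uninitialized_memory = True
print(torch.empty(3))  # tensor([nan, nan, nan])

# With the flag off, the memory is left truly uninitialized (arbitrary values).
torch.utils.deterministic.fill_uninitialized_memory = False
print(torch.empty(3))  # arbitrary values, not guaranteed to be NaN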

torchtitan/experiments/transformers_backend/model/args.py

Lines changed: 2 additions & 0 deletions

@@ -171,6 +171,8 @@ def update_from_config(self, job_config: JobConfig):

         self.max_seq_len = job_config.training.seq_len

+        self.deterministic = job_config.debug.deterministic
+
         # Configure HF-specific settings to match TorchTitan settings
         # TODO: false ?
         self.attention_bias = False

torchtitan/experiments/transformers_backend/model/model.py

Lines changed: 4 additions & 0 deletions

@@ -50,6 +50,10 @@ class HFTransformerModel(nn.Module):
     def __init__(self, model_args: HFTransformerModelArgs):
         super().__init__()

+        # NOTE(3outeille): This prevents Hugging Face modeling from initializing RoPE (inv_freq) buffers to NaN. Useful when loading from a seed checkpoint.
+        if hasattr(model_args, 'deterministic') and model_args.deterministic:
+            torch.utils.deterministic.fill_uninitialized_memory = False
+
         # Try to import the model class dynamically from the transformers library if not found in globals
         model_class_name = model_args.architectures[0]
         model_cls = globals().get(model_class_name, None)
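The new guard only touches the global flag when the run was configured with debug.deterministic; the hasattr check keeps older args objects without the attribute working. A minimal sketch of that gating in isolation, using SimpleNamespace as a hypothetical stand-in for HFTransformerModelArgs after update_from_config (not the real class):

import torch
from types import SimpleNamespace

def maybe_disable_nan_fill(model_args) -> None:
    # Mirrors the check added in HFTransformerModel.__init__: only override the
    # global flag when the run was configured as deterministic.
    if hasattr(model_args, 'deterministic') and model_args.deterministic:
        torch.utils.deterministic.fill_uninitialized_memory = False

# Hypothetical stand-in for args produced from debug.deterministic = true.
maybe_disable_nan_fill(SimpleNamespace(deterministic=True))
print(torch.utils.deterministic.fill_uninitialized_memory)  # False

# Args without the attribute (e.g. older configs) leave the flag untouched.
maybe_disable_nan_fill(SimpleNamespace())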
