diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index 8ecc7df49a..c1db39225b 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -64,6 +64,29 @@ def forward(ctx, i, w):
     @staticmethod
     def backward(ctx, grad_output):
         (w,) = ctx.saved_tensors
+        """
+        A[m,k] @ B[k,n] -> O[m,n]
+        grad_o[m,n] @ B.t()[n,k] -> grad_a[m,k]
+        looks right..
+        getting
+[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 67, in backward
+[rank4]:[rank4]: grad_input = grad_output.mm(w.t())
+[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 88, in split_mm
+[rank4]:[rank4]: return MmPassThrough.apply(i1, w1)
+[rank4]:[rank4]: File "/data/users/whc/pytorch/torch/autograd/function.py", line 583, in apply
+[rank4]:[rank4]: return super().apply(*args, **kwargs)  # type: ignore[misc]
+[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 74, in forward
+[rank4]:[rank4]: return torch.mm(x, y)
+[rank4]:[rank4]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x2816 and 2048x2816)
+
+[rank4]:[rank4]: RuntimeError:
+[rank4]:[rank4]: Failed to run stage backward:
+[rank4]:[rank4]: Stage output: ('Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)',)
+[rank4]:[rank4]: Output gradient: ('Tensor(torch.Size([1, 4096, 2048]), grad=False, dtype=torch.bfloat16)',)
+[rank4]:[rank4]: Input: ['Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)']
+[rank4]:[rank4]:
+        """
+        logger.error(f"MmSeparateInputGrad backward: {grad_output.shape=}, {w.t().shape=}")
         grad_input = grad_output.mm(w.t())
         return grad_input, None
 
@@ -75,13 +98,14 @@ def forward(ctx, x, y):
 
     @staticmethod
     def backward(ctx, gO):
+        # TODO(whc) - Claude first wrote it this way and later tried returning None, None; I'm not sure which is correct
        return gO, gO
 
 def split_mm(i, w):
-    print("split mul")
    # Apply the pass-through node.
 # y is passed to this node so that it can be
    # saved for backward, but detach because we don't want to actually build
    # this edge of the graph
+    logger.error(f"split_mm forward: {i.shape=}, {w.shape=}")
     w1 = MmSeparateWeightGrad.apply(i.detach(), w)
     i1 = MmSeparateInputGrad.apply(i, w.detach())
     return MmPassThrough.apply(i1, w1)
@@ -138,7 +162,6 @@ def backward(ctx, gO):
         return gO, gO, gO, None, None
 
 def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
-    print("split addmm")
     mat2_1 = AddmmSeparateMat2Grad.apply(mat1.detach(), mat2, alpha)
     mat1_1 = AddmmSeparateMat1Grad.apply(mat1, mat2.detach(), alpha)
     bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
@@ -197,7 +220,6 @@ def backward(ctx, gO):
         return gO, None, gO, None
 
 def split_rms_norm(input, normalized_shape, weight=None, eps=None):
-    print("split rms_norm")
     weight_1 = RmsNormSeparateWeightGrad.apply(
         input.detach(), normalized_shape, weight, eps
     )
@@ -255,7 +277,6 @@ def backward(ctx, gO):
         return gO, gO, None, None, None
 
 def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
-    print("split grouped_mm")
     mat2_1 = GroupedMmSeparateMat2Grad.apply(
         input.detach(), mat2, offs, bias, out_dtype
     )
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index cb1fb30fe3..d8a38f80ae 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -61,7 +61,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]"
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
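
For reference (not part of the patch): the gradient formulas asserted in the docstring added above can be checked in isolation. The standalone snippet below compares the manual grad_a and grad_b against plain autograd on a toy matmul; the variable names and toy sizes m, k, n are arbitrary.

import torch

m, k, n = 4, 3, 5
a = torch.randn(m, k, requires_grad=True)
b = torch.randn(k, n, requires_grad=True)

out = torch.mm(a, b)          # O[m,n]
grad_o = torch.randn(m, n)    # upstream gradient dO[m,n]
out.backward(grad_o)

# Manual gradients, matching the formulas in the docstring above.
grad_a = grad_o.mm(b.t())     # [m,n] @ [n,k] -> [m,k]
grad_b = a.t().mm(grad_o)     # [k,m] @ [m,n] -> [k,n]

assert grad_a.shape == (m, k) and grad_b.shape == (k, n)
assert torch.allclose(a.grad, grad_a)
assert torch.allclose(b.grad, grad_b)

If the formulas hold, the failure likely comes from elsewhere: the traceback shows grad_output.mm(w.t()) inside MmSeparateInputGrad.backward re-entering split_mm and MmPassThrough.forward, which suggests the mm interception is also being applied to the matmuls issued from the custom backward itself. The activation_checkpoint mode flip from "selective" to "none" in the toml change looks like an attempt to rule out selective AC recomputing these ops during the stage backward while debugging.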