[Deltaformer] kernel improvement; if-else optimization; change w to fp32; add 1e-9 to avoid nan #603
base: main
```diff
@@ -167,7 +167,7 @@ def parallel_deltaformer_fwd_kernel(
     )
     q = tl.load(q_blk_ptr, boundary_check=(0,))
 
-    for kv_i in range(0, T, BLOCK_T):
+    for kv_i in range(0, T-C, BLOCK_T):
         k_blk_ptr = tl.make_block_ptr(
             base=k_ptr + pid_h * D,
             shape=(D, T),
@@ -179,10 +179,6 @@ def parallel_deltaformer_fwd_kernel(
         k = tl.load(k_blk_ptr, boundary_check=(1,))
         qk = tl.dot(q, k) * qk_scale
 
-        if kv_i >= T - C:
-            mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
-            qk = tl.where(mask, -1e6, qk)
-
         rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
         qk -= rowmax_i[:, None]
         p = tl.math.exp2(qk)
@@ -193,17 +189,41 @@ def parallel_deltaformer_fwd_kernel(
         acc = acc * alpha[:, None]
         rowmax = rowmax_i
 
-        if kv_i < T - C:
-            u_blk_ptr = tl.make_block_ptr(
-                base=u_ptr + pid_h * D,
-                shape=(T, D),
-                strides=(H * D, 1),
-                offsets=(kv_i, 0),
-                block_shape=(BLOCK_T, D),
-                order=(1, 0),
-            )
-            u = tl.load(u_blk_ptr, boundary_check=(0,))
-            acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)
+        u_blk_ptr = tl.make_block_ptr(
+            base=u_ptr + pid_h * D,
+            shape=(T, D),
+            strides=(H * D, 1),
+            offsets=(kv_i, 0),
+            block_shape=(BLOCK_T, D),
+            order=(1, 0),
+        )
+        u = tl.load(u_blk_ptr, boundary_check=(0,))
+        acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)
+
+    for kv_i in range(T-C, T, BLOCK_T):
+        k_blk_ptr = tl.make_block_ptr(
+            base=k_ptr + pid_h * D,
+            shape=(D, T),
+            strides=(1, H * D),
+            offsets=(0, kv_i),
+            block_shape=(D, BLOCK_T),
+            order=(0, 1),
+        )
+        k = tl.load(k_blk_ptr, boundary_check=(1,))
+        qk = tl.dot(q, k) * qk_scale
+
+        mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
+        qk = tl.where(mask, -1e6, qk)
+
+        rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
+        qk -= rowmax_i[:, None]
+        p = tl.math.exp2(qk)
+
+        rowsum_i = tl.sum(p, axis=1)
+        alpha = tl.math.exp2(rowmax - rowmax_i)
+        rowsum = rowsum * alpha + rowsum_i
+        acc = acc * alpha[:, None]
+        rowmax = rowmax_i
```
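The control-flow change above is a loop-splitting (index-set splitting) optimization: instead of testing `kv_i >= T - C` on every iteration of a single loop, the iteration range is split so that the unmasked prefix blocks and the masked tail blocks each get a branch-free body. A minimal Python sketch of the transformation, assuming — as the kernel's chunking presumably guarantees — that `T - C` is a multiple of `BLOCK_T` (the names here are illustrative, not taken from the kernel):

```python
# Before: one loop, the branch is evaluated on every iteration.
def scan_with_branch(T, C, BLOCK_T, process, process_masked):
    for kv_i in range(0, T, BLOCK_T):
        if kv_i >= T - C:
            process_masked(kv_i)  # tail blocks that need the causal mask
        else:
            process(kv_i)         # fully visible prefix blocks

# After: two loops, no branch in either body.
def scan_split(T, C, BLOCK_T, process, process_masked):
    for kv_i in range(0, T - C, BLOCK_T):
        process(kv_i)             # fast path, no mask check at all
    for kv_i in range(T - C, T, BLOCK_T):
        process_masked(kv_i)      # tail blocks always apply the mask
```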
Comment on lines +203 to +226

This new loop, part of the loop-splitting optimization, introduces a significant amount of code duplication from the preceding loop (lines 170-201). Specifically, the logic for creating the `k` block pointer, loading `k`, computing `qk`, and updating the running softmax statistics (`rowmax`, `rowsum`, `acc`) is repeated almost verbatim; only the unconditional mask application differs. Consider extracting the shared logic into a helper to improve maintainability.
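One way to act on this suggestion is to factor the shared tile-scoring and online-softmax body into a `@triton.jit` helper gated by a `constexpr` mask flag, so each of the two loops becomes a single call. The sketch below is hypothetical: the helper name, its signature, and the choice of `constexpr` parameters are assumptions, not code from this repository.

```python
import triton
import triton.language as tl


@triton.jit
def _attend_kv_block(
    q, k_ptr, pid_h, kv_i, qk_scale,
    rowmax, rowsum, acc,
    rowid_block, colid_block, T, C,
    H, D: tl.constexpr, BLOCK_T: tl.constexpr,
    APPLY_MASK: tl.constexpr,
):
    # Shared body of both kv loops: load one (D, BLOCK_T) tile of K,
    # score it against the query block, and fold the result into the
    # running online-softmax state (rowmax, rowsum, acc).
    k_blk_ptr = tl.make_block_ptr(
        base=k_ptr + pid_h * D,
        shape=(D, T),
        strides=(1, H * D),
        offsets=(0, kv_i),
        block_shape=(D, BLOCK_T),
        order=(0, 1),
    )
    k = tl.load(k_blk_ptr, boundary_check=(1,))
    qk = tl.dot(q, k) * qk_scale
    if APPLY_MASK:  # compile-time branch: only the tail loop pays for masking
        mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
        qk = tl.where(mask, -1e6, qk)
    rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
    p = tl.math.exp2(qk - rowmax_i[:, None])
    alpha = tl.math.exp2(rowmax - rowmax_i)
    rowsum = rowsum * alpha + tl.sum(p, axis=1)
    acc = acc * alpha[:, None]
    return p, rowmax_i, rowsum, acc
```

The first loop would call this with `APPLY_MASK=False` and keep its `u` accumulation; the second loop would pass `APPLY_MASK=True`.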
```diff
 
     lse = rowmax + tl.math.log2(rowsum)
     lse_block_ptr = lse_ptr + pid_h + rowid_block * H
```
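Both loops maintain `rowmax`, `rowsum`, and the rescaling factor `alpha`, which is the standard online (streaming) softmax recurrence written in base 2; the `lse = rowmax + tl.math.log2(rowsum)` line above then recovers each row's log-sum-exp. A self-contained NumPy sketch of the same recurrence, for reference only:

```python
import numpy as np

def online_softmax_lse(scores, block=4):
    """Streaming base-2 log-sum-exp over column blocks, as in the kernel."""
    rows = scores.shape[0]
    rowmax = np.full(rows, -np.inf)
    rowsum = np.zeros(rows)
    for j in range(0, scores.shape[1], block):
        qk = scores[:, j:j + block]
        rowmax_i = np.maximum(rowmax, qk.max(axis=1))  # new running max
        p = np.exp2(qk - rowmax_i[:, None])            # block, rescaled
        alpha = np.exp2(rowmax - rowmax_i)             # correction factor
        rowsum = rowsum * alpha + p.sum(axis=1)        # rescale old sum
        rowmax = rowmax_i
    return rowmax + np.log2(rowsum)                    # per-row lse

scores = np.random.randn(3, 16)
# Matches the direct computation up to floating-point error:
assert np.allclose(online_softmax_lse(scores),
                   np.log2(np.exp2(scores).sum(axis=1)))
```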
```diff
@@ -218,7 +238,7 @@ def parallel_deltaformer_fwd_kernel(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    acc = acc / rowsum[:, None]
+    acc = acc / (rowsum[:, None] + 1e-9)
```
Good addition of epsilon for numerical stability.

Adding `1e-9` to the denominator prevents a division by zero (and the resulting NaN values) when `rowsum` is zero. Verify that similar normalization divisions elsewhere in the codebase (e.g., in backward passes or other kernels) also include epsilon guards for consistency:

```bash
# Search for other normalization divisions that might need epsilon
rg -n "/ *(rowsum|sum|norm)" --type=py -C 2 -g "*.py"
```

Add epsilon to the analogous division at line 280. That division,

```python
p = tl.math.exp2(qk) / rowsum[:, None]
```

lacks the same guard and should receive the identical fix:

```diff
-            p = tl.math.exp2(qk) / rowsum[:, None]
+            p = tl.math.exp2(qk) / (rowsum[:, None] + 1e-9)
```
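For intuition, a minimal NumPy reproduction of the failure mode the epsilon guards against: a row whose accumulator and softmax denominator are both zero (a fully padded row is one plausible way this arises), which turns the normalization into a 0/0:

```python
import numpy as np

acc = np.zeros((1, 4))                   # accumulated weighted values
rowsum = np.array([0.0])                 # softmax denominator for a dead row

print(acc / rowsum[:, None])             # [[nan nan nan nan]]  (0/0)
print(acc / (rowsum[:, None] + 1e-9))    # [[0. 0. 0. 0.]]
```

With the epsilon, such rows normalize to zeros instead of NaNs that would poison every downstream matmul.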
```diff
 
     beta_ptr = tl.make_block_ptr(
         base=beta_ptr + pid_h,
```
```diff
@@ -861,7 +881,7 @@ def _forward_impl(
             betai = beta_full[b, i:i + Ci, :]
 
             w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai)
-            w = w * betai.unsqueeze(-1)
+            w = w * betai.unsqueeze(-1).to(torch.float32)
```
Casting `betai` to `float32` before the multiply keeps `w` in fp32, in line with this PR's goal.

Ensure consistent dtype casting across all code paths. The explicit `.to(torch.float32)` is applied only here; the equivalent multiplication at line 944 is still:

```python
w = w * betai.unsqueeze(-1)  # line 944, missing .to(torch.float32)
```

This inconsistency could lead to different numeric behavior or dtype mismatches between the two code paths. Apply this diff to maintain consistency:

```diff
-            w = w * betai.unsqueeze(-1)
+            w = w * betai.unsqueeze(-1).to(torch.float32)
```
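The behavioral difference between the two paths comes down to PyTorch's type-promotion rule: the product of two bf16 tensors stays bf16, while casting either operand to fp32 promotes the result. A quick check with stand-in tensors (not the PR's code):

```python
import torch

w = torch.randn(2, 3, dtype=torch.bfloat16)    # stand-in for the kernel output
beta = torch.rand(2, 1, dtype=torch.bfloat16)  # stand-in for betai

print((w * beta).dtype)                        # torch.bfloat16
print((w * beta.to(torch.float32)).dtype)      # torch.float32 (promoted)
```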
```diff
 
             if need_aux:
                 wpad = torch.zeros(C, H, C, device=ko.device, dtype=ko.dtype)
                 wpad[:Ci, :, :Ci].copy_(w)
```
Remove trailing whitespace.

The pre-commit hook detected trailing whitespace on this line, which should be removed to pass the linting check.