Commit 3e31a31

Add backward pass for softmax kernel
stack-info: PR: #744, branch: karthickai/stack/4
1 parent: c7fa936

3 files changed: +163 −0 lines changed
examples/softmax.py

Lines changed: 76 additions & 0 deletions
@@ -11,6 +11,8 @@
 # %%
 from __future__ import annotations
 
+from typing import Any
+
 import torch
 
 import helion
@@ -89,6 +91,69 @@ def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel()
+def softmax_bwd(
+    grad_output: torch.Tensor, softmax_output: torch.Tensor
+) -> torch.Tensor:
+    """
+    Helion kernel implementing softmax backward pass.
+
+    dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))
+
+    Args:
+        grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
+        softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]
+
+    Returns:
+        torch.Tensor: Gradient with respect to input of shape [m, n]
+    """
+    m, n = grad_output.size()
+    grad_input = torch.empty_like(grad_output)
+
+    block_size_m = hl.register_block_size(m)
+    block_size_n = hl.register_block_size(n)
+
+    for tile_m in hl.tile(m, block_size=block_size_m):
+        sum_per_row = hl.zeros([tile_m], dtype=torch.float32)
+        for tile_n in hl.tile(n, block_size=block_size_n):
+            sum_per_row += torch.sum(
+                softmax_output[tile_m, tile_n] * grad_output[tile_m, tile_n], dim=1
+            )
+        for tile_n in hl.tile(n, block_size=block_size_n):
+            grad_input[tile_m, tile_n] = softmax_output[tile_m, tile_n] * (
+                grad_output[tile_m, tile_n] - sum_per_row[:, None]
+            )
+
+    return grad_input
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        y = softmax_two_pass(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_output: torch.Tensor,
+    ) -> tuple[torch.Tensor | None]:
+        (softmax_output,) = ctx.saved_tensors
+        grad_x = softmax_bwd(grad_output, softmax_output)
+        return (grad_x,)
+
+
+def softmax_fwd_bwd(
+    x: torch.Tensor,
+) -> torch.Tensor:
+    """Softmax with forward + backward support."""
+    return SoftmaxFunction.apply(x)  # type: ignore[no-any-return]
+
+
 # %%
 def check(m: int, n: int) -> None:
     """
@@ -105,6 +170,17 @@ def check(m: int, n: int) -> None:
     }
     run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))
 
+    print("\n\n=== Forward + Backward Pass Test ===")
+    x_grad = torch.randn([m, n], device="cuda", dtype=torch.float16, requires_grad=True)
+    run_example(
+        softmax_fwd_bwd,
+        torch.nn.functional.softmax,
+        (x_grad,),
+        rtol=1e-3,
+        atol=1e-3,
+        bwd=True,
+    )
+
 
 # %%
 def main() -> None:
test/test_examples.expected

Lines changed: 61 additions & 0 deletions
@@ -3420,6 +3420,67 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=1)
     return out
 
+--- assertExpectedJournal(TestExamples.test_softmax_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_softmax_bwd(softmax_output, grad_output, grad_input, grad_input_stride_0, grad_input_stride_1, grad_output_stride_0, grad_output_stride_1, softmax_output_stride_0, softmax_output_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    sum_per_row = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < n
+        sum_per_row_copy = sum_per_row
+        sum_per_row_copy_0 = sum_per_row_copy
+        load = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_2[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        load_1 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_2[None, :] * grad_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        v_0 = load * load_1
+        sum_1 = tl.cast(tl.sum(v_0, 1), tl.float16)
+        v_1 = tl.cast(sum_1, tl.float32)
+        sum_per_row = sum_per_row_copy_0 + v_1
+    for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_2 = indices_2 < n
+        sum_per_row_copy_1 = sum_per_row
+        sum_per_row_copy_1_0 = sum_per_row_copy_1
+        load_2 = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_2[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_3 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_2[None, :] * grad_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        subscript = sum_per_row_copy_1_0[:, None]
+        v_3 = tl.cast(load_3, tl.float32)
+        v_4 = v_3 - subscript
+        v_5 = tl.cast(load_2, tl.float32)
+        v_6 = v_5 * v_4
+        v_7 = tl.cast(v_6, tl.float16)
+        tl.store(grad_input + (indices_0[:, None] * grad_input_stride_0 + indices_2[None, :] * grad_input_stride_1), v_7, mask_0[:, None] & mask_2[None, :])
+
+def softmax_bwd(grad_output: torch.Tensor, softmax_output: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Helion kernel implementing softmax backward pass.
+
+    dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))
+
+    Args:
+        grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
+        softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]
+
+    Returns:
+        torch.Tensor: Gradient with respect to input of shape [m, n]
+    """
+    m, n = grad_output.size()
+    grad_input = torch.empty_like(grad_output)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_softmax_bwd, (triton.cdiv(m, _BLOCK_SIZE_0),), softmax_output, grad_output, grad_input, grad_input.stride(0), grad_input.stride(1), grad_output.stride(0), grad_output.stride(1), softmax_output.stride(0), softmax_output.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_input
+
 --- assertExpectedJournal(TestExamples.test_softmax_decomposed)
 from __future__ import annotations
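
In the generated kernel above, the two hl.tile loops compile to two tl.range loops over the row, and each tile's row dot product is reduced with tl.sum, cast to float16 (the input dtype) and only then added to the float32 sum_per_row accumulator. The following sketch is plain PyTorch rather than Triton and only mimics the reduction dtypes, to illustrate the precision trade-off between reducing in float16 versus float32:

import torch

torch.manual_seed(0)
y = torch.rand(2048, 2048).half()   # stand-in for softmax_output, float16
g = torch.randn(2048, 2048).half()  # stand-in for grad_output, float16

ref = (y.double() * g.double()).sum(dim=1)        # high-precision reference
fp32_sum = (y.float() * g.float()).sum(dim=1)     # reduce in float32
fp16_sum = (y * g).sum(dim=1).float()             # reduce in float16, then upcast

print("max error, float32 reduction:", (fp32_sum - ref).abs().max().item())
print("max error, float16 reduction:", (fp16_sum - ref).abs().max().item())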

test/test_examples.py

Lines changed: 26 additions & 0 deletions
@@ -881,6 +881,32 @@ def test_layernorm_bwd_dx(self):
             )
         )
 
+    def test_softmax_bwd(self):
+        m, n = 2048, 2048
+        x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
+        grad_out = torch.randn([m, n], device=DEVICE, dtype=torch.float16)
+
+        from examples.softmax import softmax_two_pass
+
+        config = helion.Config(block_size=[128, 128], num_warps=4, num_stages=3)
+        configured_kernel = helion.kernel(softmax_two_pass.fn, config=config)
+        y = configured_kernel(x)
+
+        x_torch = x.detach().clone().requires_grad_(True)
+        y_torch = torch.nn.functional.softmax(x_torch, dim=-1)
+        y_torch.backward(grad_out)
+
+        self.assertExpectedJournal(
+            check_example(
+                "softmax",
+                (grad_out, y),
+                x_torch.grad,
+                fn_name="softmax_bwd",
+                rtol=1e-3,
+                atol=1e-3,
+            )
+        )
+
     def test_layernorm_without_bias(self):
         x = -2.3 + 0.5 * torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
         weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
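
The new test feeds (grad_out, y) to the generated softmax_bwd kernel via check_example and compares the result against x_torch.grad from PyTorch autograd. A minimal sketch of the same end-to-end check outside the test harness, assuming a CUDA device and that examples/softmax.py is importable; it exercises the autograd.Function path added in this commit rather than calling softmax_bwd directly:

import torch
from examples.softmax import softmax_fwd_bwd

x = torch.randn(2048, 2048, device="cuda", dtype=torch.float16, requires_grad=True)
grad_out = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)

y = softmax_fwd_bwd(x)   # forward: softmax_two_pass via SoftmaxFunction
y.backward(grad_out)     # backward: softmax_bwd kernel populates x.grad

x_ref = x.detach().clone().requires_grad_(True)
torch.nn.functional.softmax(x_ref, dim=-1).backward(grad_out)

torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-3, atol=1e-3)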
