
Commit 2c155fd

Add backward pass for softmax kernel
stack-info: PR: #744, branch: karthickai/stack/4
1 parent: c7fa936

File tree: 4 files changed, +192 -3 lines

benchmarks/run.py

Lines changed: 17 additions & 3 deletions
@@ -159,7 +159,12 @@ class RunResult:
     "softmax": (
         "tritonbench.operators.softmax.operator",
         "examples.softmax",
-        "softmax",
+        "softmax_tritonbench",
+    ),
+    "softmax-bwd": (
+        "tritonbench.operators.softmax.operator",
+        "examples.softmax",
+        "softmax_tritonbench",
     ),
     "jagged_mean": (
         "tritonbench.operators.jagged_mean.operator",
@@ -325,8 +330,17 @@ class RunResult:
         "triton_softmax-accuracy": "triton_accuracy",
         "torch_compile_softmax-speedup": "torch_compile_speedup",
         "torch_compile_softmax-accuracy": "torch_compile_accuracy",
-        "helion_softmax-speedup": "helion_speedup",
-        "helion_softmax-accuracy": "helion_accuracy",
+        "helion_softmax_tritonbench-speedup": "helion_speedup",
+        "helion_softmax_tritonbench-accuracy": "helion_accuracy",
+    },
+    "softmax-bwd": {
+        "naive_softmax": "baseline",
+        "triton_softmax-speedup": "triton_speedup",
+        "triton_softmax-accuracy": "triton_accuracy",
+        "torch_compile_softmax-speedup": "torch_compile_speedup",
+        "torch_compile_softmax-accuracy": "torch_compile_accuracy",
+        "helion_softmax_tritonbench-speedup": "helion_speedup",
+        "helion_softmax_tritonbench-accuracy": "helion_accuracy",
     },
     "rms_norm": {
         "llama_rms": "baseline",

examples/softmax.py

Lines changed: 87 additions & 0 deletions
@@ -11,6 +11,8 @@
 # %%
 from __future__ import annotations

+from typing import Any
+
 import torch

 import helion
@@ -89,6 +91,80 @@ def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
     return out


+@helion.kernel()
+def softmax_bwd(
+    grad_output: torch.Tensor, softmax_output: torch.Tensor
+) -> torch.Tensor:
+    """
+    Helion kernel implementing softmax backward pass.
+
+    dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))
+
+    Args:
+        grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
+        softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]
+
+    Returns:
+        torch.Tensor: Gradient with respect to input of shape [m, n]
+    """
+    m, n = grad_output.size()
+    grad_input = torch.empty_like(grad_output)
+
+    for tile_m in hl.tile(m):
+        sum_per_row = hl.zeros([tile_m], dtype=torch.float32)
+        for tile_n in hl.tile(n):
+            sum_per_row += torch.sum(
+                softmax_output[tile_m, tile_n] * grad_output[tile_m, tile_n], dim=1
+            )
+        for tile_n in hl.tile(n):
+            grad_input[tile_m, tile_n] = softmax_output[tile_m, tile_n] * (
+                grad_output[tile_m, tile_n] - sum_per_row[:, None]
+            )
+
+    return grad_input
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        y = softmax_two_pass(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_output: torch.Tensor,
+    ) -> tuple[torch.Tensor | None]:
+        (softmax_output,) = ctx.saved_tensors
+        grad_x = softmax_bwd(grad_output, softmax_output)
+        return (grad_x,)
+
+
+def softmax_fwd_bwd(
+    x: torch.Tensor,
+) -> torch.Tensor:
+    """Softmax with forward + backward support."""
+    return SoftmaxFunction.apply(x)  # type: ignore[no-any-return]
+
+
+def softmax_tritonbench(tb_op: object, x: torch.Tensor) -> callable[[], torch.Tensor]:
+    """
+    Wrapper for tritonbench that returns softmax with backward support.
+
+    Args:
+        tb_op: TritonBench operator instance
+        x: Input tensor
+
+    Returns:
+        Callable that returns the output tensor
+    """
+    return lambda: softmax_fwd_bwd(x)
+
+
 # %%
 def check(m: int, n: int) -> None:
     """
@@ -105,6 +181,17 @@ def check(m: int, n: int) -> None:
     }
     run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))

+    print("\n\n=== Forward + Backward Pass Test ===")
+    x_grad = torch.randn([m, n], device="cuda", dtype=torch.float16, requires_grad=True)
+    run_example(
+        softmax_fwd_bwd,
+        torch.nn.functional.softmax,
+        (x_grad,),
+        rtol=1e-3,
+        atol=1e-3,
+        bwd=True,
+    )
+

 # %%
 def main() -> None:
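The identity in the kernel's docstring, dy/dx = y * (g - sum(y * g)) with the sum taken per row, is the standard softmax vector-Jacobian product. A quick pure-PyTorch check of that formula against autograd, in float64 to sidestep tolerance concerns (a sketch independent of the Helion code above):

import torch

def softmax_bwd_reference(grad_output: torch.Tensor, softmax_output: torch.Tensor) -> torch.Tensor:
    # Row-wise: dL/dx = y * (g - sum(y * g, dim=-1))
    dot = (softmax_output * grad_output).sum(dim=-1, keepdim=True)
    return softmax_output * (grad_output - dot)

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
g = torch.randn_like(y)
y.backward(g)
torch.testing.assert_close(softmax_bwd_reference(g, y.detach()), x.grad)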

test/test_examples.expected

Lines changed: 62 additions & 0 deletions
@@ -3420,6 +3420,68 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=1)
     return out

+--- assertExpectedJournal(TestExamples.test_softmax_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_softmax_bwd(softmax_output, grad_output, grad_input, grad_input_stride_0, grad_input_stride_1, grad_output_stride_0, grad_output_stride_1, softmax_output_stride_0, softmax_output_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    sum_per_row = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < n
+        sum_per_row_copy = sum_per_row
+        sum_per_row_copy_0 = sum_per_row_copy
+        load = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_1[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        load_1 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_1[None, :] * grad_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        v_0 = load * load_1
+        sum_1 = tl.cast(tl.sum(v_0, 1), tl.float16)
+        v_1 = tl.cast(sum_1, tl.float32)
+        sum_per_row = sum_per_row_copy_0 + v_1
+    for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < n
+        sum_per_row_copy_1 = sum_per_row
+        sum_per_row_copy_1_0 = sum_per_row_copy_1
+        load_2 = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_2[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_3 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_2[None, :] * grad_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        subscript = sum_per_row_copy_1_0[:, None]
+        v_3 = tl.cast(load_3, tl.float32)
+        v_4 = v_3 - subscript
+        v_5 = tl.cast(load_2, tl.float32)
+        v_6 = v_5 * v_4
+        v_7 = tl.cast(v_6, tl.float16)
+        tl.store(grad_input + (indices_0[:, None] * grad_input_stride_0 + indices_2[None, :] * grad_input_stride_1), v_7, mask_0[:, None] & mask_2[None, :])
+
+def softmax_bwd(grad_output: torch.Tensor, softmax_output: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Helion kernel implementing softmax backward pass.
+
+    dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))
+
+    Args:
+        grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
+        softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]
+
+    Returns:
+        torch.Tensor: Gradient with respect to input of shape [m, n]
+    """
+    m, n = grad_output.size()
+    grad_input = torch.empty_like(grad_output)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion_softmax_bwd, (triton.cdiv(m, _BLOCK_SIZE_0),), softmax_output, grad_output, grad_input, grad_input.stride(0), grad_input.stride(1), grad_output.stride(0), grad_output.stride(1), softmax_output.stride(0), softmax_output.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return grad_input
+
 --- assertExpectedJournal(TestExamples.test_softmax_decomposed)
 from __future__ import annotations

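The generated kernel launches one program per row tile and makes two passes over the columns: the first loop accumulates sum(y * g) per row in float32 (the journal shows each tile's partial sum round-tripped through float16 before accumulation, since the inputs are float16), and the second loop rewalks the columns and writes y * (g - sum). A plain-PyTorch sketch of that control flow, with a hypothetical helper name and block size:

import torch

def softmax_bwd_tiled(grad_output: torch.Tensor, softmax_output: torch.Tensor, block_n: int = 16) -> torch.Tensor:
    # Mirrors the generated Triton kernel: pass 1 reduces over column tiles,
    # pass 2 rewalks the columns and writes the gradient.
    m, n = grad_output.shape
    grad_input = torch.empty_like(grad_output)
    sum_per_row = torch.zeros(m, dtype=torch.float32, device=grad_output.device)
    for j in range(0, n, block_n):
        cols = slice(j, min(j + block_n, n))  # plays the role of mask_1
        prod = softmax_output[:, cols] * grad_output[:, cols]
        # the journal casts each tile's sum to fp16 and back before accumulating
        sum_per_row += prod.sum(dim=1).to(torch.float16).to(torch.float32)
    for j in range(0, n, block_n):
        cols = slice(j, min(j + block_n, n))  # plays the role of mask_2
        diff = grad_output[:, cols].to(torch.float32) - sum_per_row[:, None]
        grad_input[:, cols] = (softmax_output[:, cols].to(torch.float32) * diff).to(grad_output.dtype)
    return grad_input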

test/test_examples.py

Lines changed: 26 additions & 0 deletions
@@ -881,6 +881,32 @@ def test_layernorm_bwd_dx(self):
             )
         )

+    def test_softmax_bwd(self):
+        m, n = 2048, 2048
+        x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
+        grad_out = torch.randn([m, n], device=DEVICE, dtype=torch.float16)
+
+        from examples.softmax import softmax_two_pass
+
+        config = helion.Config(block_size=[128, 128], num_warps=4, num_stages=3)
+        configured_kernel = helion.kernel(softmax_two_pass.fn, config=config)
+        y = configured_kernel(x)
+
+        x_torch = x.detach().clone().requires_grad_(True)
+        y_torch = torch.nn.functional.softmax(x_torch, dim=-1)
+        y_torch.backward(grad_out)
+
+        self.assertExpectedJournal(
+            check_example(
+                "softmax",
+                (grad_out, y),
+                x_torch.grad,
+                fn_name="softmax_bwd",
+                rtol=1e-3,
+                atol=1e-3,
+            )
+        )
+
     def test_layernorm_without_bias(self):
         x = -2.3 + 0.5 * torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
         weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
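For experimenting outside the test harness, the same comparison can be written as a standalone script (a sketch: shapes and tolerances are copied from the test, and a CUDA device is assumed for the Helion kernels):

import torch
from examples.softmax import softmax_fwd_bwd

m, n = 2048, 2048
x = torch.randn(m, n, device="cuda", dtype=torch.float16, requires_grad=True)
grad_out = torch.randn(m, n, device="cuda", dtype=torch.float16)

y = softmax_fwd_bwd(x)  # forward via softmax_two_pass, backward via softmax_bwd
y.backward(grad_out)

x_ref = x.detach().clone().requires_grad_(True)
torch.nn.functional.softmax(x_ref, dim=-1).backward(grad_out)
torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-3, atol=1e-3)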
