Commit ad3a2a9

Add backward pass for softmax kernel (#744)
1 parent 46f6aa6 commit ad3a2a9

File tree

4 files changed: +192 −3 lines changed

    benchmarks/run.py
    examples/softmax.py
    test/test_examples.expected
    test/test_examples.py

benchmarks/run.py

Lines changed: 17 additions & 3 deletions

@@ -176,7 +176,12 @@ class RunResult:
     "softmax": (
         "tritonbench.operators.softmax.operator",
         "examples.softmax",
-        "softmax",
+        "softmax_tritonbench",
+    ),
+    "softmax-bwd": (
+        "tritonbench.operators.softmax.operator",
+        "examples.softmax",
+        "softmax_tritonbench",
     ),
     "jagged_mean": (
         "tritonbench.operators.jagged_mean.operator",
@@ -378,8 +383,17 @@ class RunResult:
         "triton_softmax-accuracy": "triton_accuracy",
         "torch_compile_softmax-speedup": "torch_compile_speedup",
         "torch_compile_softmax-accuracy": "torch_compile_accuracy",
-        "helion_softmax-speedup": "helion_speedup",
-        "helion_softmax-accuracy": "helion_accuracy",
+        "helion_softmax_tritonbench-speedup": "helion_speedup",
+        "helion_softmax_tritonbench-accuracy": "helion_accuracy",
+    },
+    "softmax-bwd": {
+        "naive_softmax": "baseline",
+        "triton_softmax-speedup": "triton_speedup",
+        "triton_softmax-accuracy": "triton_accuracy",
+        "torch_compile_softmax-speedup": "torch_compile_speedup",
+        "torch_compile_softmax-accuracy": "torch_compile_accuracy",
+        "helion_softmax_tritonbench-speedup": "helion_speedup",
+        "helion_softmax_tritonbench-accuracy": "helion_accuracy",
     },
     "rms_norm": {
         "llama_rms": "baseline",

examples/softmax.py

Lines changed: 87 additions & 0 deletions

@@ -11,6 +11,9 @@
 # %%
 from __future__ import annotations
 
+from typing import Any
+from typing import Callable
+
 import torch
 
 import helion
@@ -90,6 +93,79 @@ def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel()
+def softmax_bwd(
+    grad_output: torch.Tensor, softmax_output: torch.Tensor
+) -> torch.Tensor:
+    """
+    Helion kernel implementing softmax backward pass.
+
+    dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))
+
+    Args:
+        grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
+        softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]
+
+    Returns:
+        torch.Tensor: Gradient with respect to input of shape [m, n]
+    """
+    m, n = grad_output.size()
+    grad_input = torch.empty_like(grad_output)
+
+    for tile_m in hl.tile(m):
+        sum_per_row = hl.zeros([tile_m], dtype=torch.float32)
+        for tile_n in hl.tile(n):
+            sum_per_row += torch.sum(
+                softmax_output[tile_m, tile_n] * grad_output[tile_m, tile_n], dim=1
+            )
+        for tile_n in hl.tile(n):
+            grad_input[tile_m, tile_n] = softmax_output[tile_m, tile_n] * (
+                grad_output[tile_m, tile_n] - sum_per_row[:, None]
+            )
+
+    return grad_input
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        y = softmax_two_pass(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_output: torch.Tensor,
+    ) -> tuple[torch.Tensor | None]:
+        (softmax_output,) = ctx.saved_tensors
+        grad_x = softmax_bwd(grad_output, softmax_output)
+        return (grad_x,)
+
+
+def softmax_fwd_bwd(
+    x: torch.Tensor,
+) -> torch.Tensor:
+    """Softmax with forward + backward support."""
+    return SoftmaxFunction.apply(x)  # type: ignore[no-any-return]
+
+
+def softmax_tritonbench(tb_op: object, x: torch.Tensor) -> Callable[[], torch.Tensor]:
+    """
+    Wrapper for tritonbench that returns softmax with backward support.
+    Args:
+        tb_op: TritonBench operator instance
+        x: Input tensor
+
+    Returns:
+        Callable that returns the output tensor
+    """
+    return lambda: softmax_fwd_bwd(x)
+
+
 # %%
 def check(m: int, n: int) -> None:
     """
@@ -106,6 +182,17 @@ def check(m: int, n: int) -> None:
     }
     run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))
 
+    print("\n\n=== Forward + Backward Pass Test ===")
+    x_grad = torch.randn([m, n], device="cuda", dtype=torch.float16, requires_grad=True)
+    run_example(
+        softmax_fwd_bwd,
+        torch.nn.functional.softmax,
+        (x_grad,),
+        rtol=1e-3,
+        atol=1e-3,
+        bwd=True,
+    )
+
 
 # %%
 def main() -> None:
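
The identity implemented by softmax_bwd is the standard softmax vector-Jacobian product; for background (not part of the patch), with y = softmax(x) and upstream gradient g, the Jacobian is

    \partial y_i / \partial x_j = y_i (\delta_{ij} - y_j)

and contracting it with g gives

    \frac{\partial L}{\partial x_j} = \sum_i g_i \, y_i (\delta_{ij} - y_j) = y_j \Big( g_j - \sum_i y_i g_i \Big),

which is exactly softmax_output * (grad_output - sum(softmax_output * grad_output)).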

test/test_examples.expected

Lines changed: 62 additions & 0 deletions

@@ -5576,6 +5576,68 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[softmax.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestExamples.test_softmax_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_softmax_bwd(softmax_output, grad_output, grad_input, grad_input_stride_0, grad_input_stride_1, grad_output_stride_0, grad_output_stride_1, softmax_output_stride_0, softmax_output_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    sum_per_row = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < n
+        sum_per_row_copy = sum_per_row
+        sum_per_row_copy_0 = sum_per_row_copy
+        load = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_1[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        load_1 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_1[None, :] * grad_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        v_0 = load * load_1
+        sum_1 = tl.cast(tl.sum(v_0, 1), tl.float16)
+        v_1 = tl.cast(sum_1, tl.float32)
+        sum_per_row = sum_per_row_copy_0 + v_1
+    for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < n
+        sum_per_row_copy_1 = sum_per_row
+        sum_per_row_copy_1_0 = sum_per_row_copy_1
+        load_2 = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_2[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_3 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_2[None, :] * grad_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        subscript = sum_per_row_copy_1_0[:, None]
+        v_3 = tl.cast(load_3, tl.float32)
+        v_4 = v_3 - subscript
+        v_5 = tl.cast(load_2, tl.float32)
+        v_6 = v_5 * v_4
+        v_7 = tl.cast(v_6, tl.float16)
+        tl.store(grad_input + (indices_0[:, None] * grad_input_stride_0 + indices_2[None, :] * grad_input_stride_1), v_7, mask_0[:, None] & mask_2[None, :])
+
+def softmax_bwd(grad_output: torch.Tensor, softmax_output: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Helion kernel implementing softmax backward pass.
+
+    dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))
+
+    Args:
+        grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
+        softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]
+
+    Returns:
+        torch.Tensor: Gradient with respect to input of shape [m, n]
+    """
+    m, n = grad_output.size()
+    grad_input = torch.empty_like(grad_output)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion_softmax_bwd, (triton.cdiv(m, _BLOCK_SIZE_0),), softmax_output, grad_output, grad_input, grad_input.stride(0), grad_input.stride(1), grad_output.stride(0), grad_output.stride(1), softmax_output.stride(0), softmax_output.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return grad_input
+
 --- assertExpectedJournal(TestExamples.test_softmax_decomposed)
 from __future__ import annotations

test/test_examples.py

Lines changed: 26 additions & 0 deletions

@@ -973,6 +973,32 @@ def test_layernorm_bwd(self):
             if idx == 0:
                 self.assertExpectedJournal(journal)
 
+    def test_softmax_bwd(self):
+        m, n = 2048, 2048
+        x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
+        grad_out = torch.randn([m, n], device=DEVICE, dtype=torch.float16)
+
+        from examples.softmax import softmax_two_pass
+
+        config = helion.Config(block_size=[128, 128], num_warps=4, num_stages=3)
+        configured_kernel = helion.kernel(softmax_two_pass.fn, config=config)
+        y = configured_kernel(x)
+
+        x_torch = x.detach().clone().requires_grad_(True)
+        y_torch = torch.nn.functional.softmax(x_torch, dim=-1)
+        y_torch.backward(grad_out)
+
+        self.assertExpectedJournal(
+            check_example(
+                "softmax",
+                (grad_out, y),
+                x_torch.grad,
+                fn_name="softmax_bwd",
+                rtol=1e-3,
+                atol=1e-3,
+            )
+        )
+
     def test_layernorm_without_bias(self):
         x = -2.3 + 0.5 * torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
         weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
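
The closed-form gradient the test exercises can also be checked standalone against torch.autograd; a minimal sketch (CPU, float64 for tight tolerances; illustration only, not part of the test suite):

import torch

# Verify y * (g - sum(y * g)) against autograd's softmax backward.
m, n = 64, 128
x = torch.randn(m, n, dtype=torch.float64, requires_grad=True)
g = torch.randn(m, n, dtype=torch.float64)

y = torch.softmax(x, dim=-1)
y.backward(g)

manual = y.detach() * (g - (y.detach() * g).sum(dim=-1, keepdim=True))
torch.testing.assert_close(manual, x.grad)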
