@@ -3420,6 +3420,68 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
3420 3420 _launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=1)
3421 3421 return out
3422 3422
3423+ --- assertExpectedJournal(TestExamples.test_softmax_bwd)
3424+ from __future__ import annotations
3425+
3426+ import torch
3427+ import triton
3428+ import triton.language as tl
3429+ from helion.runtime import default_launcher as _default_launcher
3430+
3431+ @triton.jit
3432+ def _helion_softmax_bwd(softmax_output, grad_output, grad_input, grad_input_stride_0, grad_input_stride_1, grad_output_stride_0, grad_output_stride_1, softmax_output_stride_0, softmax_output_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
3433+ pid_0 = tl.program_id(0)
3434+ offset_0 = pid_0 * _BLOCK_SIZE_0
3435+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
3436+ mask_0 = indices_0 < m
3437+ sum_per_row = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
3438+ for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
3439+ indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
3440+ mask_1 = indices_1 < n
3441+ sum_per_row_copy = sum_per_row
3442+ sum_per_row_copy_0 = sum_per_row_copy
3443+ load = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_1[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
3444+ load_1 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_1[None, :] * grad_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
3445+ v_0 = load * load_1
3446+ sum_1 = tl.cast(tl.sum(v_0, 1), tl.float16)
3447+ v_1 = tl.cast(sum_1, tl.float32)
3448+ sum_per_row = sum_per_row_copy_0 + v_1
3449+ for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
3450+ indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
3451+ mask_2 = indices_2 < n
3452+ sum_per_row_copy_1 = sum_per_row
3453+ sum_per_row_copy_1_0 = sum_per_row_copy_1
3454+ load_2 = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_2[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
3455+ load_3 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_2[None, :] * grad_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
3456+ subscript = sum_per_row_copy_1_0[:, None]
3457+ v_3 = tl.cast(load_3, tl.float32)
3458+ v_4 = v_3 - subscript
3459+ v_5 = tl.cast(load_2, tl.float32)
3460+ v_6 = v_5 * v_4
3461+ v_7 = tl.cast(v_6, tl.float16)
3462+ tl.store(grad_input + (indices_0[:, None] * grad_input_stride_0 + indices_2[None, :] * grad_input_stride_1), v_7, mask_0[:, None] & mask_2[None, :])
3463+
3464+ def softmax_bwd(grad_output: torch.Tensor, softmax_output: torch.Tensor, *, _launcher=_default_launcher):
3465+ """
3466+ Helion kernel implementing softmax backward pass.
3467+
3468+ grad_input = softmax_output * (grad_output - sum(softmax_output * grad_output, dim=-1, keepdim=True))
3469+
3470+ Args:
3471+ grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
3472+ softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]
3473+
3474+ Returns:
3475+ torch.Tensor: Gradient with respect to input of shape [m, n]
3476+ """
3477+ m, n = grad_output.size()
3478+ grad_input = torch.empty_like(grad_output)
3479+ _BLOCK_SIZE_0 = 16
3480+ _BLOCK_SIZE_1 = 16
3481+ _BLOCK_SIZE_2 = 16
3482+ _launcher(_helion_softmax_bwd, (triton.cdiv(m, _BLOCK_SIZE_0),), softmax_output, grad_output, grad_input, grad_input.stride(0), grad_input.stride(1), grad_output.stride(0), grad_output.stride(1), softmax_output.stride(0), softmax_output.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
3483+ return grad_input
3484+
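As a sanity check on the gradient formula in the docstring above, the following standalone PyTorch sketch (an editorial aside, not part of the expected journal; the shapes are illustrative assumptions) compares the closed-form softmax backward against autograd:

import torch

# Illustrative shape; the journal's kernel handles any [m, n] input.
x = torch.randn(32, 64, requires_grad=True)
y = torch.softmax(x, dim=-1)
grad_out = torch.randn_like(y)

# Closed-form backward implemented by the generated kernel:
# grad_input = y * (grad_out - sum(y * grad_out, dim=-1, keepdim=True))
row_dot = (y * grad_out).sum(dim=-1, keepdim=True)
grad_in_manual = y * (grad_out - row_dot)

# Autograd reference for comparison.
(grad_in_auto,) = torch.autograd.grad(y, x, grad_out)
torch.testing.assert_close(grad_in_manual, grad_in_auto)
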
3423 3485 --- assertExpectedJournal(TestExamples.test_softmax_decomposed)
3424 3486 from __future__ import annotations
3425 3487