@@ -5952,7 +5952,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher

 @triton.jit
-def _helion_swiglu(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_swiglu_fwd(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
     # src[swiglu.py:N]: for tile_idx in hl.tile(total_elements):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
@@ -5972,7 +5972,7 @@ def _helion_swiglu(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
     # src[swiglu.py:N]: out_flat[tile_idx] = result
     tl.store(out_flat + indices_0 * 1, v_4, None)

-def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
+def swiglu_fwd(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     """
     Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.

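As context for the rename above: the generated kernel computes SwiGLU elementwise, SiLU(a) * b with SiLU(x) = x * sigmoid(x). A minimal eager-mode sketch of the same computation, in plain PyTorch (swiglu_ref is a hypothetical name, not part of the generated journal):

    import torch

    def swiglu_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Compute in float32 for accuracy, as the kernel does, then cast back.
        a32 = a.to(torch.float32)
        return (a32 * torch.sigmoid(a32) * b.to(torch.float32)).to(a.dtype)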
@@ -6006,10 +6006,86 @@ def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     # src[swiglu.py:N]: # Load input values and convert to float32 for computation
     # src[swiglu.py:N]: a_vals = a_flat[tile_idx].to(torch.float32)
     # src[swiglu.py:N-N]: ...
-    _launcher(_helion_swiglu, (triton.cdiv(1048576, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_swiglu_fwd, (triton.cdiv(1048576, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     # src[swiglu.py:N]: return out
     return out

+--- assertExpectedJournal(TestExamples.test_swiglu_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_swiglu_bwd(x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0: tl.constexpr):
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
+    load = tl.load(x1_flat + indices_0 * 1, None)
+    v_0 = tl.cast(load, tl.float32)
+    # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
+    load_1 = tl.load(gout_flat + indices_0 * 1, None)
+    v_1 = tl.cast(load_1, tl.float32)
+    # src[swiglu.py:N]: dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
+    v_2 = tl.sigmoid(tl.cast(v_0, tl.float32))
+    v_3 = v_0 * v_2
+    v_4 = v_3 * v_1
+    # src[swiglu.py:N]: dx2_flat[tile] = dx2_vals.to(x2.dtype)
+    v_5 = tl.cast(v_4, tl.bfloat16)
+    tl.store(dx2_flat + indices_0 * 1, v_5, None)
+    # src[swiglu.py:N]: x2_vals = x2_flat[tile].to(torch.float32)
+    load_2 = tl.load(x2_flat + indices_0 * 1, None)
+    v_6 = tl.cast(load_2, tl.float32)
+    # src[swiglu.py:N]: x1_exp = torch.exp(x1_vals)
+    v_7 = libdevice.exp(v_0)
+    # src[swiglu.py:N]: x1_exp_plus1 = x1_exp + 1
+    v_8 = 1.0
+    v_9 = v_7 + v_8
+    # src[swiglu.py:N]: dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
+    v_10 = v_7 / v_9
+    v_11 = v_0 * v_7
+    v_12 = v_11 / v_9
+    v_13 = v_12 / v_9
+    v_14 = v_10 + v_13
+    # src[swiglu.py:N]: dx1_vals = gout_vals * x2_vals * dextra
+    v_15 = v_1 * v_6
+    v_16 = v_15 * v_14
+    # src[swiglu.py:N]: dx1_flat[tile] = dx1_vals.to(x1.dtype)
+    v_17 = tl.cast(v_16, tl.bfloat16)
+    tl.store(dx1_flat + indices_0 * 1, v_17, None)
+
+def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launcher):
+    """
+    Implement the backward formula for swiglu.
+    """
+    # src[swiglu.py:N]: dx1 = torch.empty_like(x1)
+    dx1 = torch.empty_like(x1)
+    # src[swiglu.py:N]: dx2 = torch.empty_like(x2)
+    dx2 = torch.empty_like(x2)
+    # src[swiglu.py:N]: gout_flat = gout.view(-1)
+    gout_flat = gout.view(-1)
+    # src[swiglu.py:N]: x1_flat = x1.view(-1)
+    x1_flat = x1.view(-1)
+    # src[swiglu.py:N]: x2_flat = x2.view(-1)
+    x2_flat = x2.view(-1)
+    # src[swiglu.py:N]: dx1_flat = dx1.view(-1)
+    dx1_flat = dx1.view(-1)
+    # src[swiglu.py:N]: dx2_flat = dx2.view(-1)
+    dx2_flat = dx2.view(-1)
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    _BLOCK_SIZE_0 = 1024
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
+    # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
+    # src[swiglu.py:N-N]: ...
+    _launcher(_helion_swiglu_bwd, (triton.cdiv(1024, _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
+    # src[swiglu.py:N]: return dx1, dx2
+    return (dx1, dx2)
+
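The dextra term in the kernel above follows from differentiating SiLU(x1) = x1 * sigmoid(x1) with sigmoid(x1) = exp(x1) / (1 + exp(x1)), giving d/dx1 SiLU(x1) = exp(x1)/(1+exp(x1)) + x1*exp(x1)/(1+exp(x1))^2. A minimal eager-mode sketch of the same backward in plain PyTorch, useful for checking the kernel against autograd (swiglu_bwd_ref is a hypothetical name, not part of the generated journal):

    import torch

    def swiglu_bwd_ref(gout: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor):
        # Mirror the kernel: all math in float32, results cast back to input dtypes.
        gf, x1f, x2f = (t.to(torch.float32) for t in (gout, x1, x2))
        dx2 = (x1f * torch.sigmoid(x1f) * gf).to(x2.dtype)    # d(out)/d(x2) = SiLU(x1)
        e = torch.exp(x1f)
        dextra = e / (e + 1) + x1f * e / (e + 1) / (e + 1)    # d(SiLU)/d(x1)
        dx1 = (gf * x2f * dextra).to(x1.dtype)
        return dx1, dx2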
 --- assertExpectedJournal(TestExamples.test_template_via_closure0)
 from __future__ import annotations
