@@ -62,6 +62,128 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
     _launcher(_reduce_kernel_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    x_part = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
+    v_0 = x_part.to(tl.float32)
+    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = 64
+    v_2 = var_mean_extra / v_1.to(tl.float32)
+    v_3 = x_part.to(tl.float32)
+    v_4 = v_3 - v_2
+    v_5 = v_4 * v_4
+    var_mean_extra_2 = tl.reshape(tl.sum(v_5, 1), [_BLOCK_SIZE_0, 1])
+    v_6 = 64
+    v_7 = var_mean_extra_2 / v_6.to(tl.float32)
+    v_8 = v_7.to(tl.float16)
+    v_9 = v_2.to(tl.float16)
+    v_10 = x_part - v_9
+    v_11 = v_8.to(tl.float32)
+    v_12 = v_11 + eps
+    v_13 = libdevice.rsqrt(v_12)
+    v_14 = v_10.to(tl.float32)
+    v_15 = v_14 * v_13
+    load_1 = tl.load(weight + indices_1 * 1, None)
+    v_16 = load_1.to(tl.float32)
+    v_17 = v_16[None, :]
+    v_18 = v_15 * v_17
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_19 = load_2.to(tl.float32)
+    v_20 = v_19[None, :]
+    v_21 = v_18 + v_20
+    v_22 = v_21.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_22, None)
+
+def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 64
+    _launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    var_mean_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        x_part = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_0 = x_part.to(tl.float32)
+        v_1 = var_mean_extra_acc + v_0
+        var_mean_extra_acc = v_1
+    var_mean_extra = tl.reshape(tl.sum(var_mean_extra_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_2 = 64
+    v_3 = var_mean_extra / v_2.to(tl.float32)
+    var_mean_extra_2_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        v_3_copy = v_3
+        x_part_1 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_4 = x_part_1.to(tl.float32)
+        v_5 = v_4 - v_3_copy
+        v_6 = v_5 * v_5
+        v_7 = var_mean_extra_2_acc + v_6
+        var_mean_extra_2_acc = v_7
+    var_mean_extra_2 = tl.reshape(tl.sum(var_mean_extra_2_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_8 = 64
+    v_9 = var_mean_extra_2 / v_8.to(tl.float32)
+    v_10 = v_9.to(tl.float16)
+    v_11 = v_3.to(tl.float16)
+    v_12 = v_10.to(tl.float32)
+    v_13 = v_12 + eps
+    v_14 = libdevice.rsqrt(v_13)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        v_11_copy = v_11
+        v_14_copy = v_14
+        x_part_2 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_15 = x_part_2 - v_11_copy
+        v_16 = v_15.to(tl.float32)
+        v_17 = v_16 * v_14_copy
+        load_1 = tl.load(weight + rindex_1 * 1, None)
+        v_18 = load_1.to(tl.float32)
+        v_19 = v_18[None, :]
+        v_20 = v_17 * v_19
+        load_2 = tl.load(bias + rindex_1 * 1, None)
+        v_21 = load_2.to(tl.float32)
+        v_22 = v_21[None, :]
+        v_23 = v_20 + v_22
+        v_24 = v_23.to(tl.float16)
+        tl.store(out + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), v_24, None)
+
+def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _REDUCTION_BLOCK_1 = 8
+    _launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestReductions.test_mean)
 def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
     # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
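
For orientation, the two test_fp16_var_mean entries above are alternative lowerings of the same Helion kernel: the first is a persistent-reduction config (_RDIM_SIZE_1 = 64 covers the whole row at once), the second tiles the row with _REDUCTION_BLOCK_1 = 8 and walks it in three passes (row sum, squared deviations, normalize-and-store). The sketch below is a reconstruction from the generated code, not a copy of the test source; the decorator arguments, the static_shapes setting, and the exact cast placement are assumptions.

import torch
import helion
import helion.language as hl

# Reconstructed sketch (assumption: the real test source may differ in details).
@helion.kernel(static_shapes=True)
def layer_norm_fwd_repro(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-05
) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    for tile_m in hl.tile(m):
        x_part = x[tile_m, :]
        # var_mean on an fp16 tile: the reduction runs in fp32 and the results
        # come back as fp16, matching the .to(tl.float16) round-trip in the journals.
        var, mean = torch.var_mean(x_part, dim=-1, keepdim=True, correction=0)
        normalized = (x_part - mean) * torch.rsqrt(var.to(torch.float32) + eps)
        out[tile_m, :] = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)
    return out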
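
Both generated variants follow the same dtype discipline: reduce in fp32, round the per-row var/mean back to fp16, re-upcast before rsqrt, and apply weight/bias in fp32 before the final fp16 store. The plain-PyTorch sketch below mirrors that sequence for a 32x64 fp16 input; the function name and the commented tolerance check are illustrative only and assume a CUDA device.

import torch

def fp16_layer_norm_reference(x, weight, bias, eps=1e-05):
    # Mirror the generated kernels: fp32 reduction, fp16 round-trip of var/mean,
    # re-upcast before rsqrt, fp32 weight/bias application, fp16 output.
    x32 = x.to(torch.float32)
    mean = x32.mean(dim=-1, keepdim=True)
    var = (x32 - mean).square().mean(dim=-1, keepdim=True)  # biased (correction=0)
    var16, mean16 = var.to(torch.float16), mean.to(torch.float16)
    inv_std = torch.rsqrt(var16.to(torch.float32) + eps)
    y = (x - mean16).to(torch.float32) * inv_std
    return (y * weight.to(torch.float32) + bias.to(torch.float32)).to(torch.float16)

# Illustrative check against the generated host wrapper (assumes a CUDA device
# and that layer_norm_fwd_repro from the journal is importable):
#   x = torch.randn(32, 64, dtype=torch.float16, device="cuda")
#   w = torch.randn(64, dtype=torch.float16, device="cuda")
#   b = torch.randn(64, dtype=torch.float16, device="cuda")
#   torch.testing.assert_close(layer_norm_fwd_repro(x, w, b),
#                              fp16_layer_norm_reference(x, w, b), rtol=1e-3, atol=1e-3)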