Skip to content

Commit 157be88

Browse files
authored
Fix GRPO loss example unit tests (#1079)
1 parent dee9f57 commit 157be88

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

examples/grpo_loss.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ def torch_grpo_loss(
135135

136136

137137
@helion.kernel(
138-
ignore_warnings=[helion.exc.TensorOperationInWrapper], autotune_effort="quick"
138+
ignore_warnings=[helion.exc.TensorOperationInWrapper],
139139
)
140140
def grpo_loss_forward(
141141
logits: torch.Tensor, # [B, L+1, V] input logits
@@ -227,7 +227,7 @@ def grpo_loss_forward(
227227

228228

229229
@helion.kernel(
230-
ignore_warnings=[helion.exc.TensorOperationInWrapper], autotune_effort="quick"
230+
ignore_warnings=[helion.exc.TensorOperationInWrapper],
231231
)
232232
def grpo_loss_backward(
233233
grad_output: torch.Tensor, # [B, L] gradient from downstream

test/test_examples.expected

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2218,17 +2218,18 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22182218
pid_1 = tl.program_id(0) // num_blocks_0
22192219
offset_0 = pid_0 * _BLOCK_SIZE_0
22202220
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
2221+
mask_0 = indices_0 < 2
22212222
offset_1 = pid_1 * _BLOCK_SIZE_1
22222223
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
22232224
# src[grpo_loss.py:N]: completion_id = completion_ids[tile_b, tile_l]
2224-
completion_id = tl.load(completion_ids + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2225+
completion_id = tl.load(completion_ids + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0)
22252226
# src[grpo_loss.py:N]: log_sum_exp = lse[tile_b, tile_l]
2226-
log_sum_exp = tl.load(lse + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2227+
log_sum_exp = tl.load(lse + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0)
22272228
# src[grpo_loss.py:N]: logp = selected_logits[tile_b, tile_l] - log_sum_exp
2228-
load_2 = tl.load(selected_logits + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2229+
load_2 = tl.load(selected_logits + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0)
22292230
v_0 = load_2 - log_sum_exp
22302231
# src[grpo_loss.py:N]: old_logp_val = old_logp[tile_b, tile_l]
2231-
old_logp_val = tl.load(old_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2232+
old_logp_val = tl.load(old_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0)
22322233
# src[grpo_loss.py:N]: coef_1 = torch.exp(logp - old_logp_val)
22332234
v_1 = v_0 - old_logp_val
22342235
v_2 = libdevice.exp(v_1)
@@ -2238,7 +2239,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22382239
v_3 = triton_helpers.maximum(v_2, sub_2)
22392240
v_4 = triton_helpers.minimum(v_3, add)
22402241
# src[grpo_loss.py:N]: advantage = advantages[tile_b]
2241-
advantage = tl.load(advantages + indices_0 * 1, None)
2242+
advantage = tl.load(advantages + indices_0 * 1, mask_0, other=0)
22422243
# src[grpo_loss.py:N]: per_token_loss1 = coef_1 * advantage[:, None]
22432244
subscript = advantage[:, None]
22442245
v_5 = v_2 * subscript
@@ -2263,7 +2264,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22632264
v_0_copy_0 = v_0_copy
22642265
v_10_copy_0 = v_10_copy
22652266
# src[grpo_loss.py:N]: ref_logp_val = ref_logp[tile_b, tile_l]
2266-
ref_logp_val = tl.load(ref_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2267+
ref_logp_val = tl.load(ref_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0)
22672268
# src[grpo_loss.py:N]: dlogp += beta * (1 - torch.exp(ref_logp_val - logp))
22682269
v_11 = ref_logp_val - v_0_copy_0
22692270
v_12 = libdevice.exp(v_11)
@@ -2272,11 +2273,11 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22722273
v_15 = v_14 * beta
22732274
v_10 = v_10_copy_0 + v_15
22742275
# src[grpo_loss.py:N]: dlogp = dlogp * grad_output[tile_b, tile_l] / temperature
2275-
load_5 = tl.load(grad_output + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2276+
load_5 = tl.load(grad_output + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0)
22762277
v_17 = v_10 * load_5
22772278
v_18 = v_17 / temperature
22782279
# src[grpo_loss.py:N]: mask_val = completion_mask[tile_b, tile_l]
2279-
mask_val = tl.load(completion_mask + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2280+
mask_val = tl.load(completion_mask + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0)
22802281
# src[grpo_loss.py:N]: dlogp *= mask_val
22812282
v_19 = v_18 * mask_val
22822283
# src[grpo_loss.py:N]: for tile_v in hl.tile(V):
@@ -2292,7 +2293,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22922293
completion_id_copy_0 = completion_id_copy
22932294
v_19_copy_0 = v_19_copy
22942295
# src[grpo_loss.py:N]: logits_fwd[tile_b, tile_l, tile_v].to(torch.float32) / temperature
2295-
load = tl.load(logits_fwd + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), None)
2296+
load = tl.load(logits_fwd + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), mask_0[:, None, None], other=0)
22962297
v_20 = tl.cast(load, tl.float32)
22972298
v_21 = v_20 / temperature
22982299
# src[grpo_loss.py:N]: probs = torch.exp(logits_tile - log_sum_exp[:, :, None])
@@ -2320,7 +2321,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
23202321
v_31 = tl.where(v_25, v_28, v_30)
23212322
# src[grpo_loss.py:N]: grad_logits[tile_b, tile_l, tile_v] = grad_logits_tile
23222323
v_32 = tl.cast(v_31, tl.bfloat16)
2323-
tl.store(grad_logits + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), v_32, None)
2324+
tl.store(grad_logits + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), v_32, mask_0[:, None, None])
23242325

23252326
def grpo_loss_backward(grad_output: torch.Tensor, logits: torch.Tensor, selected_logits: torch.Tensor, completion_ids: torch.Tensor, old_logp: torch.Tensor | None, ref_logp: torch.Tensor | None, advantages: torch.Tensor, completion_mask: torch.Tensor | None, lse: torch.Tensor, temperature: float, beta: float, eps_low: float, eps_high: float, *, _launcher=_default_launcher):
23262327
"""
@@ -2351,7 +2352,7 @@ def grpo_loss_backward(grad_output: torch.Tensor, logits: torch.Tensor, selected
23512352
# src[grpo_loss.py:N]: grad_logits = torch.zeros_like(logits)
23522353
grad_logits = torch.zeros_like(logits)
23532354
# src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]):
2354-
_BLOCK_SIZE_0 = 2
2355+
_BLOCK_SIZE_0 = 4
23552356
_BLOCK_SIZE_1 = 16
23562357
# src[grpo_loss.py:N]: for tile_v in hl.tile(V):
23572358
# src[grpo_loss.py:N]: logits_tile = (

test/test_examples.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1723,6 +1723,7 @@ def test_grpo_loss_fwd(self):
17231723
fn_name="grpo_loss_forward",
17241724
rtol=1e-2,
17251725
atol=1e-1,
1726+
block_sizes=[4, 16, 16],
17261727
)
17271728
)
17281729

@@ -1748,11 +1749,13 @@ def test_grpo_loss_bwd(self):
17481749
from examples.grpo_loss import extract_selected_logits_pytorch
17491750
from examples.grpo_loss import grpo_loss_forward
17501751

1752+
from helion._testing import code_and_output
1753+
17511754
selected_logits = extract_selected_logits_pytorch(
17521755
logits[:, :-1, :], completion_ids, temperature
17531756
)
17541757

1755-
_, _, _, lse = grpo_loss_forward(
1758+
forward_args = (
17561759
logits,
17571760
selected_logits,
17581761
old_logp,
@@ -1765,6 +1768,12 @@ def test_grpo_loss_bwd(self):
17651768
eps_high,
17661769
)
17671770

1771+
_, (_, _, _, lse) = code_and_output(
1772+
grpo_loss_forward,
1773+
forward_args,
1774+
block_sizes=[4, 16, 16],
1775+
)
1776+
17681777
grad_output = torch.randn(B, L, device=DEVICE, dtype=torch.float32)
17691778

17701779
logits_torch = logits.detach().clone().float().requires_grad_(True)
@@ -1809,6 +1818,7 @@ def test_grpo_loss_bwd(self):
18091818
fn_name="grpo_loss_backward",
18101819
rtol=1e-2,
18111820
atol=1e-1,
1821+
block_sizes=[4, 16, 16],
18121822
)
18131823
)
18141824

0 commit comments

Comments
 (0)