@@ -2218,17 +2218,18 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22182218 pid_1 = tl.program_id(0) // num_blocks_0
22192219 offset_0 = pid_0 * _BLOCK_SIZE_0
22202220 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
2221+ mask_0 = indices_0 < 2
22212222 offset_1 = pid_1 * _BLOCK_SIZE_1
22222223 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
22232224 # src[grpo_loss.py:N]: completion_id = completion_ids[tile_b, tile_l]
2224- completion_id = tl.load(completion_ids + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2225+ completion_id = tl.load(completion_ids + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0 )
22252226 # src[grpo_loss.py:N]: log_sum_exp = lse[tile_b, tile_l]
2226- log_sum_exp = tl.load(lse + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2227+ log_sum_exp = tl.load(lse + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0 )
22272228 # src[grpo_loss.py:N]: logp = selected_logits[tile_b, tile_l] - log_sum_exp
2228- load_2 = tl.load(selected_logits + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2229+ load_2 = tl.load(selected_logits + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0 )
22292230 v_0 = load_2 - log_sum_exp
22302231 # src[grpo_loss.py:N]: old_logp_val = old_logp[tile_b, tile_l]
2231- old_logp_val = tl.load(old_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2232+ old_logp_val = tl.load(old_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0 )
22322233 # src[grpo_loss.py:N]: coef_1 = torch.exp(logp - old_logp_val)
22332234 v_1 = v_0 - old_logp_val
22342235 v_2 = libdevice.exp(v_1)
@@ -2238,7 +2239,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22382239 v_3 = triton_helpers.maximum(v_2, sub_2)
22392240 v_4 = triton_helpers.minimum(v_3, add)
22402241 # src[grpo_loss.py:N]: advantage = advantages[tile_b]
2241- advantage = tl.load(advantages + indices_0 * 1, None )
2242+ advantage = tl.load(advantages + indices_0 * 1, mask_0, other=0 )
22422243 # src[grpo_loss.py:N]: per_token_loss1 = coef_1 * advantage[:, None]
22432244 subscript = advantage[:, None]
22442245 v_5 = v_2 * subscript
@@ -2263,7 +2264,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22632264 v_0_copy_0 = v_0_copy
22642265 v_10_copy_0 = v_10_copy
22652266 # src[grpo_loss.py:N]: ref_logp_val = ref_logp[tile_b, tile_l]
2266- ref_logp_val = tl.load(ref_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2267+ ref_logp_val = tl.load(ref_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0 )
22672268 # src[grpo_loss.py:N]: dlogp += beta * (1 - torch.exp(ref_logp_val - logp))
22682269 v_11 = ref_logp_val - v_0_copy_0
22692270 v_12 = libdevice.exp(v_11)
@@ -2272,11 +2273,11 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22722273 v_15 = v_14 * beta
22732274 v_10 = v_10_copy_0 + v_15
22742275 # src[grpo_loss.py:N]: dlogp = dlogp * grad_output[tile_b, tile_l] / temperature
2275- load_5 = tl.load(grad_output + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2276+ load_5 = tl.load(grad_output + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0 )
22762277 v_17 = v_10 * load_5
22772278 v_18 = v_17 / temperature
22782279 # src[grpo_loss.py:N]: mask_val = completion_mask[tile_b, tile_l]
2279- mask_val = tl.load(completion_mask + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
2280+ mask_val = tl.load(completion_mask + (indices_0[:, None] * 64 + indices_1[None, :] * 1), mask_0[:, None], other=0 )
22802281 # src[grpo_loss.py:N]: dlogp *= mask_val
22812282 v_19 = v_18 * mask_val
22822283 # src[grpo_loss.py:N]: for tile_v in hl.tile(V):
@@ -2292,7 +2293,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
22922293 completion_id_copy_0 = completion_id_copy
22932294 v_19_copy_0 = v_19_copy
22942295 # src[grpo_loss.py:N]: logits_fwd[tile_b, tile_l, tile_v].to(torch.float32) / temperature
2295- load = tl.load(logits_fwd + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), None)
2296+ load = tl.load(logits_fwd + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), mask_0[:, None, None], other=0 )
22962297 v_20 = tl.cast(load, tl.float32)
22972298 v_21 = v_20 / temperature
22982299 # src[grpo_loss.py:N]: probs = torch.exp(logits_tile - log_sum_exp[:, :, None])
@@ -2320,7 +2321,7 @@ def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, a
23202321 v_31 = tl.where(v_25, v_28, v_30)
23212322 # src[grpo_loss.py:N]: grad_logits[tile_b, tile_l, tile_v] = grad_logits_tile
23222323 v_32 = tl.cast(v_31, tl.bfloat16)
2323- tl.store(grad_logits + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), v_32, None)
2324+ tl.store(grad_logits + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), v_32, mask_0[:, None, None] )
23242325
23252326def grpo_loss_backward(grad_output: torch.Tensor, logits: torch.Tensor, selected_logits: torch.Tensor, completion_ids: torch.Tensor, old_logp: torch.Tensor | None, ref_logp: torch.Tensor | None, advantages: torch.Tensor, completion_mask: torch.Tensor | None, lse: torch.Tensor, temperature: float, beta: float, eps_low: float, eps_high: float, *, _launcher=_default_launcher):
23262327 """
@@ -2351,7 +2352,7 @@ def grpo_loss_backward(grad_output: torch.Tensor, logits: torch.Tensor, selected
23512352 # src[grpo_loss.py:N]: grad_logits = torch.zeros_like(logits)
23522353 grad_logits = torch.zeros_like(logits)
23532354 # src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]):
2354- _BLOCK_SIZE_0 = 2
2355+ _BLOCK_SIZE_0 = 4
23552356 _BLOCK_SIZE_1 = 16
23562357 # src[grpo_loss.py:N]: for tile_v in hl.tile(V):
23572358 # src[grpo_loss.py:N]: logits_tile = (
0 commit comments