
Commit 88c4809

add KL divergence backward helion kernel
ghstack-source-id: 0edb72d
Pull Request resolved: #802
1 parent 1716c26 commit 88c4809

File tree

1 file changed: +131 −18 lines

examples/kl_div.py

Lines changed: 131 additions & 18 deletions
@@ -23,7 +23,9 @@
 # -------
 from __future__ import annotations

+import math
 from typing import TYPE_CHECKING
+from typing import Any

 import torch
 from torch import Tensor
@@ -117,6 +119,106 @@ def kl_div_forward(
     return final_loss


+@helion.kernel
+def kl_div_backward(
+    grad_out: Tensor,
+    y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
+    y_true: Tensor,  # target values, shape (BT, V)
+    log_target: hl.constexpr,
+    reduction: hl.constexpr,
+    eps: hl.constexpr,
+    compute_y_true_grad: hl.constexpr,
+) -> tuple[Tensor, Tensor | None]:
+    BT, V = y_pred.shape
+    assert y_true.shape == y_pred.shape, (
+        f"Shape mismatch: {y_true.shape} != {y_pred.shape}"
+    )
+
+    grad_y_pred = torch.empty_like(y_pred)
+    if compute_y_true_grad:
+        grad_y_true = torch.empty_like(y_true)
+    else:
+        grad_y_true = None
+
+    if reduction == "none":
+        grad_out_expanded = grad_out
+    else:
+        grad_out_expanded = grad_out.expand(y_true.shape)
+
+    log_eps = math.log(eps)
+    for tile_bt in hl.tile(BT):
+        for tile_v in hl.tile(V):
+            grad_out_val = grad_out_expanded[tile_bt, tile_v]
+            y_true_val = y_true[tile_bt, tile_v]
+
+            if log_target:
+                y_true_exp = torch.exp(y_true_val)
+
+            if reduction == "batchmean":
+                div = BT
+            elif reduction == "mean":
+                div = BT * V
+            else:  # reduction == "sum" or "none"
+                div = 1.0
+
+            if log_target:
+                grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_exp / div  # type: ignore
+            else:
+                grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_val / div
+
+            if compute_y_true_grad:
+                y_pred_val = y_pred[tile_bt, tile_v]
+                if log_target:
+                    tmp = y_true_exp * (y_true_val - y_pred_val + 1)  # type: ignore
+                else:
+                    lt_eps = log_eps - y_pred_val
+                    gt_eps = torch.log(y_true_val) - y_pred_val + 1
+                    tmp = torch.where(y_true_val < eps, lt_eps, gt_eps)
+
+                grad_y_true[tile_bt, tile_v] = grad_out_val * tmp / div  # type: ignore[index]
+
+    return grad_y_pred, grad_y_true
+
+
+class KLDivFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
+        y_true: Tensor,  # target values, shape (BT, V)
+        log_target: bool,
+        reduction: str,
+        eps: float,
+    ) -> Tensor:
+        """Forward pass for KL divergence."""
+        loss = kl_div_forward(y_pred, y_true, log_target, reduction, eps)
+        ctx.save_for_backward(y_pred, y_true)  # type: ignore[arg-type]
+        ctx.log_target = log_target
+        ctx.reduction = reduction
+        ctx.eps = eps
+        return loss
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor | None, None, None, None]:
+        """Backward pass for KL divergence."""
+        y_pred, y_true = ctx.saved_tensors  # type: ignore[attr-defined]
+
+        grad_y_pred, grad_y_true = kl_div_backward(
+            grad_out,
+            y_pred,
+            y_true,
+            ctx.log_target,
+            ctx.reduction,
+            ctx.eps,
+            y_true.requires_grad,
+        )

+        return grad_y_pred, grad_y_true, None, None, None
+
+
 # %%
 # KL Divergence Loss Module
 # -------------------------
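
For reference, and not part of the commit: the branches in kl_div_backward follow from differentiating the pointwise KL term, with each partial then scaled by grad_out and divided by the reduction divisor d (BT for "batchmean", BT*V for "mean", 1 for "sum"/"none"). A sketch of the math in LaTeX, writing \hat{y} for y_pred. For a probability-space target $t$ = y_true (log_target=False):

$$\ell_{ij} = t_{ij}\,(\log t_{ij} - \hat{y}_{ij}), \qquad \frac{\partial \ell_{ij}}{\partial \hat{y}_{ij}} = -t_{ij}, \qquad \frac{\partial \ell_{ij}}{\partial t_{ij}} = \log t_{ij} - \hat{y}_{ij} + 1;$$

for a log-space target $s$ = y_true (log_target=True):

$$\ell_{ij} = e^{s_{ij}}\,(s_{ij} - \hat{y}_{ij}), \qquad \frac{\partial \ell_{ij}}{\partial \hat{y}_{ij}} = -e^{s_{ij}}, \qquad \frac{\partial \ell_{ij}}{\partial s_{ij}} = e^{s_{ij}}\,(s_{ij} - \hat{y}_{ij} + 1).$$

The eps branch substitutes $\log(\texttt{eps}) - \hat{y}_{ij}$ (without the +1) for the target gradient where $t_{ij} < \texttt{eps}$, guarding against log(0).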
@@ -154,7 +256,7 @@ def forward(self, input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
         Returns:
             KL divergence loss
         """
-        return kl_div_forward(
+        return KLDivFunction.apply(  # type: ignore[no-any-return]
             input_tensor, target_tensor, self.log_target, self.reduction, self.eps
         )
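
Since HelionKLDivLoss.forward now routes through KLDivFunction.apply, a loss produced by the module participates in autograd and reaches the new kl_div_backward kernel on loss.backward(). A minimal usage sketch, not part of the commit (the 32x4096 shape and the eps value are arbitrary, and a CUDA device is assumed):

import torch

loss_fn = HelionKLDivLoss(reduction="batchmean", log_target=False, eps=1e-10)

y_pred = torch.randn(32, 4096, device="cuda", requires_grad=True).log_softmax(dim=-1)
y_true = torch.randn(32, 4096, device="cuda").softmax(dim=-1)

loss = loss_fn(y_pred, y_true)  # forward via kl_div_forward
loss.backward()                 # backward via kl_div_backward

Because y_true does not require grad here, KLDivFunction.backward passes compute_y_true_grad=False and the kernel skips the target gradient entirely.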

@@ -181,16 +283,26 @@ def check_kl_div_kernel(
         log_target: Whether target is in log-space
         eps: Small value for numerical stability
     """
-    # Create test tensors following tritonbench pattern
-    input_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
-        dim=-1
-    )

-    target_tensor = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
-
-    # Test forward pass
+    # Create test tensors following tritonbench pattern
+    def create_inputs() -> tuple[Tensor, Tensor]:
+        input_tensor = torch.randn(
+            B * T, V, requires_grad=True, device="cuda"
+        ).log_softmax(dim=-1)
+        input_tensor.retain_grad()
+
+        target_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda")
+        if log_target:
+            target_tensor = target_tensor.log_softmax(dim=-1)
+        else:
+            target_tensor = target_tensor.softmax(dim=-1)
+        target_tensor.retain_grad()
+
+        return input_tensor, target_tensor
+
+    # Test forward + backward pass
     helion_kl = HelionKLDivLoss(reduction=reduction, log_target=log_target, eps=eps)
-    torch_kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=log_target).to(
+    torch_kl_div = torch.nn.KLDivLoss(reduction=reduction, log_target=log_target).to(
         "cuda"
     )
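
A note on the retain_grad() calls in create_inputs, which is standard PyTorch behavior rather than anything Helion-specific: log_softmax and softmax return non-leaf tensors, so backward() leaves their .grad as None unless it is explicitly retained; retaining it is presumably what lets the harness compare input gradients in the bwd=True run below. A minimal illustration:

import torch

x = torch.randn(4, 8, requires_grad=True)  # leaf tensor: .grad is populated by default
y = x.log_softmax(dim=-1)                  # non-leaf result: .grad stays None by default
y.retain_grad()                            # ask autograd to keep y.grad
y.sum().backward()
assert x.grad is not None
assert y.grad is not None                  # populated only because of retain_grad()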

@@ -200,7 +312,8 @@ def helion_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
     def baseline_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
         return torch_kl_div(input_tensor, target_tensor)

-    run_example(helion_wrapper, baseline_wrapper, (input_tensor, target_tensor))
+    run_example(helion_wrapper, baseline_wrapper, create_inputs())
+    run_example(helion_wrapper, baseline_wrapper, create_inputs(), bwd=True)


 # %%
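
The bwd=True call above exercises the backward path against the eager baseline. Outside the run_example harness, the kernel's input gradient can also be sanity-checked directly against autograd on torch.nn.functional.kl_div, since for reduction="batchmean" the gradient w.r.t. the log-space input is -y_true / BT. A hypothetical standalone check, not part of the commit (assumes a CUDA device and calls the kl_div_backward kernel defined above with a unit upstream gradient, mirroring how KLDivFunction.backward invokes it):

import torch

BT, V = 32, 4096
y_pred = torch.randn(BT, V, device="cuda", requires_grad=True).log_softmax(dim=-1)
y_true = torch.randn(BT, V, device="cuda").softmax(dim=-1)

# Reference gradient from eager PyTorch via autograd.
loss = torch.nn.functional.kl_div(y_pred, y_true, reduction="batchmean", log_target=False)
(grad_ref,) = torch.autograd.grad(loss, y_pred)

# Helion backward kernel, driven with a unit upstream gradient.
grad_y_pred, _ = kl_div_backward(
    torch.ones_like(loss), y_pred, y_true, False, "batchmean", 1e-10, False
)
torch.testing.assert_close(grad_y_pred, grad_ref)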
@@ -240,17 +353,17 @@ def main() -> None:
     print("Testing KL divergence kernel...")
     B = 8
     T = 512
-    reduction = "batchmean"
-    log_target = False
     eps = 1e-10

     # Test with vocabulary sizes from tritonbench (2^12 to 2^17)
-    for V in [2**i for i in range(12, 18)]:
-        print(
-            f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
-        )
-        check_kl_div_kernel(B, T, V, reduction, log_target, eps)
-        print("✓ KL Div passed")
+    for log_target in (True, False):
+        for reduction in ("batchmean", "mean", "sum"):
+            for V in [2**i for i in range(12, 17)]:
+                print(
+                    f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
+                )
+                check_kl_div_kernel(B, T, V, reduction, log_target, eps)
+                print("✓ KL Div passed")


 # %%
