# -------
from __future__ import annotations

+import math
from typing import TYPE_CHECKING
+from typing import Any

import torch
from torch import Tensor
@@ -117,6 +119,106 @@ def kl_div_forward(
    return final_loss


+@helion.kernel
+def kl_div_backward(
+    grad_out: Tensor,
+    y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
+    y_true: Tensor,  # target values, shape (BT, V)
+    log_target: hl.constexpr,
+    reduction: hl.constexpr,
+    eps: hl.constexpr,
+    compute_y_true_grad: hl.constexpr,
+) -> tuple[Tensor, Tensor | None]:
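+    # Pointwise KL loss is y_true * (log(y_true) - y_pred), or exp(y_true) * (y_true - y_pred)
+    # when the target is already in log-space, so the gradients computed below are
+    #   d/d(y_pred) = -y_true                   (or -exp(y_true) with log_target)
+    #   d/d(y_true) = log(y_true) - y_pred + 1  (or exp(y_true) * (y_true - y_pred + 1))
+    # each scaled by 1 / div according to the reduction mode.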
+    BT, V = y_pred.shape
+    assert y_true.shape == y_pred.shape, (
+        f"Shape mismatch: {y_true.shape} != {y_pred.shape}"
+    )
+
+    grad_y_pred = torch.empty_like(y_pred)
+    if compute_y_true_grad:
+        grad_y_true = torch.empty_like(y_true)
+    else:
+        grad_y_true = None
+
+    if reduction == "none":
+        grad_out_expanded = grad_out
+    else:
+        grad_out_expanded = grad_out.expand(y_true.shape)
+
+    log_eps = math.log(eps)
+    for tile_bt in hl.tile(BT):
+        for tile_v in hl.tile(V):
+            grad_out_val = grad_out_expanded[tile_bt, tile_v]
+            y_true_val = y_true[tile_bt, tile_v]
+
+            if log_target:
+                y_true_exp = torch.exp(y_true_val)
+
+            if reduction == "batchmean":
+                div = BT
+            elif reduction == "mean":
+                div = BT * V
+            else:  # reduction == "sum" or "none"
+                div = 1.0
+
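+            # d(loss)/d(y_pred): -(target probability) * grad_out / div.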
+            if log_target:
+                grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_exp / div  # type: ignore
+            else:
+                grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_val / div
+
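+            # Optional gradient w.r.t. the target. For probability-space targets,
+            # values below eps use the precomputed log(eps) to avoid log(0).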
+            if compute_y_true_grad:
+                y_pred_val = y_pred[tile_bt, tile_v]
+                if log_target:
+                    tmp = y_true_exp * (y_true_val - y_pred_val + 1)  # type: ignore
+                else:
+                    lt_eps = log_eps - y_pred_val
+                    gt_eps = torch.log(y_true_val) - y_pred_val + 1
+                    tmp = torch.where(y_true_val < eps, lt_eps, gt_eps)
+
+                grad_y_true[tile_bt, tile_v] = grad_out_val * tmp / div  # type: ignore[index]
+
+    return grad_y_pred, grad_y_true
+
+
+class KLDivFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
+        y_true: Tensor,  # target values, shape (BT, V)
+        log_target: bool,
+        reduction: str,
+        eps: float,
+    ) -> Tensor:
+        """Forward pass for KL divergence."""
+        loss = kl_div_forward(y_pred, y_true, log_target, reduction, eps)
+        ctx.save_for_backward(y_pred, y_true)  # type: ignore[arg-type]
+        ctx.log_target = log_target
+        ctx.reduction = reduction
+        ctx.eps = eps
+        return loss
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor | None, None, None, None]:
+        """Backward pass for KL divergence."""
+        y_pred, y_true = ctx.saved_tensors  # type: ignore[attr-defined]
+
+        grad_y_pred, grad_y_true = kl_div_backward(
+            grad_out,
+            y_pred,
+            y_true,
+            ctx.log_target,
+            ctx.reduction,
+            ctx.eps,
+            y_true.requires_grad,
+        )
+
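+        # backward() must return one value per forward() argument; the non-tensor
+        # arguments (log_target, reduction, eps) get None.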
+        return grad_y_pred, grad_y_true, None, None, None
+
+
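+# Illustrative call (hypothetical values): KLDivFunction.apply(y_pred, y_true, False, "batchmean", 1e-10)
+# runs kl_div_forward and registers kl_div_backward for the gradient computation.
+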
120222# %%
121223# KL Divergence Loss Module
122224# -------------------------
@@ -154,7 +256,7 @@ def forward(self, input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        Returns:
            KL divergence loss
        """
-        return kl_div_forward(
+        return KLDivFunction.apply(  # type: ignore[no-any-return]
            input_tensor, target_tensor, self.log_target, self.reduction, self.eps
        )

@@ -181,16 +283,26 @@ def check_kl_div_kernel(
        log_target: Whether target is in log-space
        eps: Small value for numerical stability
    """
-    # Create test tensors following tritonbench pattern
-    input_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
-        dim=-1
-    )

-    target_tensor = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
-
-    # Test forward pass
+    # Create test tensors following tritonbench pattern
+    def create_inputs() -> tuple[Tensor, Tensor]:
+        input_tensor = torch.randn(
+            B * T, V, requires_grad=True, device="cuda"
+        ).log_softmax(dim=-1)
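+        # log_softmax returns a non-leaf tensor; retain_grad() keeps its .grad
+        # populated so the backward comparison can inspect it.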
+        input_tensor.retain_grad()
+
+        target_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda")
+        if log_target:
+            target_tensor = target_tensor.log_softmax(dim=-1)
+        else:
+            target_tensor = target_tensor.softmax(dim=-1)
+        target_tensor.retain_grad()
+
+        return input_tensor, target_tensor
+
+    # Test forward + backward pass
    helion_kl = HelionKLDivLoss(reduction=reduction, log_target=log_target, eps=eps)
-    torch_kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=log_target).to(
+    torch_kl_div = torch.nn.KLDivLoss(reduction=reduction, log_target=log_target).to(
        "cuda"
    )

@@ -200,7 +312,8 @@ def helion_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
    def baseline_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        return torch_kl_div(input_tensor, target_tensor)

-    run_example(helion_wrapper, baseline_wrapper, (input_tensor, target_tensor))
+    run_example(helion_wrapper, baseline_wrapper, create_inputs())
+    run_example(helion_wrapper, baseline_wrapper, create_inputs(), bwd=True)


# %%
@@ -240,17 +353,17 @@ def main() -> None:
    print("Testing KL divergence kernel...")
    B = 8
    T = 512
-    reduction = "batchmean"
-    log_target = False
    eps = 1e-10

    # Test with vocabulary sizes from tritonbench (2^12 to 2^17)
-    for V in [2**i for i in range(12, 18)]:
-        print(
-            f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
-        )
-        check_kl_div_kernel(B, T, V, reduction, log_target, eps)
-        print("✓ KL Div passed")
+    for log_target in (True, False):
+        for reduction in ("batchmean", "mean", "sum"):
+            for V in [2**i for i in range(12, 17)]:
+                print(
+                    f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
+                )
+                check_kl_div_kernel(B, T, V, reduction, log_target, eps)
+                print("✓ KL Div passed")


# %%