"""Fused linear cross entropy implementation for Helion.

This implementation uses Liger's chunking strategy to reduce memory usage
while staying within Helion's constraints.
"""
6+
7+ from __future__ import annotations
8+
9+ import torch
10+
11+ import helion
12+ from helion ._testing import run_example
13+ import helion .language as hl
14+ from helion .utils import get_gpu_memory_info
15+
# TritonBench configuration - adjust based on available GPU memory.
# NOTE(review): element 0 of get_gpu_memory_info() is presumably the total GPU
# memory in GB - confirm against helion.utils.
# TRITONBENCH_ARGS is only defined on the low-memory path; on larger GPUs this
# module sets no override.
if get_gpu_memory_info()[0] < 16.0:
    # Low memory configuration for GPUs with less than 16GB
    TRITONBENCH_ARGS = {"hidden_size": 2048, "vocab_size": 32000}
20+
21+
# Simple matmul kernel for the linear layer
@helion.kernel(static_shapes=True, dot_precision="ieee")
def linear(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Compute ``input @ weight.T`` with a float32 accumulator.

    Args:
        input: 2-D tensor whose shape is unpacked as ``(n, h)``.
        weight: 2-D tensor whose shape is unpacked as ``(v, h)``.

    Returns:
        A float32 tensor of shape ``(n, v)`` allocated on ``input.device``.

    Raises:
        AssertionError: If the second dimensions of ``input`` and ``weight``
            differ.
    """
    n, h = input.shape
    v, h2 = weight.shape
    assert h == h2, f"Hidden size mismatch: {h} != {h2}"

    # The output is always float32, regardless of the input dtype.
    logits = torch.empty([n, v], dtype=torch.float32, device=input.device)

    for tile_n, tile_v in hl.tile([n, v]):
        acc = hl.zeros([tile_n, tile_v], dtype=torch.float32)
        # Reduce over the hidden dimension one tile at a time.
        for tile_h in hl.tile(h):
            acc = torch.addmm(acc, input[tile_n, tile_h], weight[tile_v, tile_h].T)
        logits[tile_n, tile_v] = acc

    return logits
38+
39+
# Cross entropy loss kernel (based on existing cross_entropy.py)
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the mean cross-entropy loss of ``logits`` against ``labels``.

    For each row the loss is ``logsumexp(row) - row[label]``; the result is
    the mean of those per-row losses.

    Args:
        logits: 2-D tensor whose shape is unpacked as ``(n, v)``. It is
            flattened with ``view(-1)``, so its layout must allow that.
        labels: Per-row target indices. ``labels[i]`` is added to ``i * v``
            to index the flattened logits; values are not validated here
            (presumably expected in ``[0, v)`` - confirm with callers).

    Returns:
        A scalar tensor holding the mean of the ``n`` float32 row losses.
    """
    n, v = logits.shape
    losses = torch.zeros([n], dtype=torch.float32, device=logits.device)

    # Pre-compute base indices: offset of each row's first element in the
    # flattened logits.
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)

    for tile_n in hl.tile(n):
        labels_tile = labels[tile_n]
        base_indices_tile = base_indices[tile_n]

        # Get logits at target indices
        flat_indices = base_indices_tile + labels_tile
        logits_at_target = hl.load(logits_flat, [flat_indices])

        # Load the full rows for this tile
        logits_rows = logits[tile_n, :]

        # Compute log-sum-exp; the row maximum is subtracted before exp and
        # added back afterwards.
        max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
        shifted = logits_rows - max_logits
        exp_shifted = torch.exp(shifted)
        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
        log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))

        # Cross entropy loss
        losses[tile_n] = log_sum_exp - logits_at_target

    return losses.mean()
72+
73+
def calculate_chunk_size(
    batch_size: int,
    hidden_size: int,
    vocab_size: int,
    max_chunk_size: int = 256,
) -> int:
    """Calculate the row-chunk size used to split the batch.

    The batch is divided by ``ceil(vocab_size / hidden_size)``, rounded down
    to a power of two, and then limited by ``batch_size`` and
    ``max_chunk_size``.

    Args:
        batch_size: Number of rows to be processed.
        hidden_size: Hidden dimension; must be positive.
        vocab_size: Vocabulary dimension; must be positive.
        max_chunk_size: Upper bound on the returned chunk size; must be at
            least 1. Defaults to 256.

    Returns:
        The chunk size. It never exceeds ``batch_size``, so it is 0 when
        ``batch_size`` is 0.

    Raises:
        ValueError: If ``hidden_size`` or ``vocab_size`` is not positive, or
            ``max_chunk_size`` is less than 1.
    """
    if hidden_size <= 0 or vocab_size <= 0:
        raise ValueError(
            "hidden_size and vocab_size must be positive, got "
            f"hidden_size={hidden_size}, vocab_size={vocab_size}"
        )
    if max_chunk_size < 1:
        raise ValueError(f"max_chunk_size must be >= 1, got {max_chunk_size}")

    # How many times larger the vocab dimension is than the hidden dimension
    # (ceiling division).
    inc_factor = (vocab_size + hidden_size - 1) // hidden_size
    chunk_size = max(1, batch_size // inc_factor)

    # Round down to a power of 2 for better performance; chunk_size is at
    # least 1 here, so bit_length() is at least 1.
    chunk_size = 1 << (chunk_size.bit_length() - 1)

    # Never exceed the batch itself, and cap the chunk so a single chunk does
    # not grow too large.
    return min(chunk_size, batch_size, max_chunk_size)
93+
94+
# Fused version that uses chunking to reduce memory
def fused_linear_cross_entropy(
    input: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Fused linear + cross entropy using Liger's chunking strategy.

    Args:
        input: 2-D tensor whose shape is unpacked as
            ``(batch_size, hidden_size)``.
        weight: Tensor whose first dimension is taken as the vocab size.
        labels: Target indices, sliced along dim 0 in step with ``input``.

    Returns:
        The mean cross-entropy loss over all rows. When the batch is split,
        each chunk's mean loss is weighted by the chunk's row count before
        dividing by ``batch_size``.
    """
    batch_size, hidden_size = input.shape
    vocab_size = weight.shape[0]

    # Calculate optimal chunk size
    chunk_size = calculate_chunk_size(batch_size, hidden_size, vocab_size)

    # If one chunk covers the whole batch, no chunking is needed
    if chunk_size >= batch_size:
        return cross_entropy_loss(linear(input, weight), labels)

    # Process in chunks to reduce memory usage. chunk_size is positive here
    # because it is strictly smaller than batch_size on this path.
    total_loss = 0.0
    for start_idx in range(0, batch_size, chunk_size):
        # The final chunk may be shorter than chunk_size.
        end_idx = min(start_idx + chunk_size, batch_size)

        # Compute logits for this chunk (only chunk_size x vocab_size memory)
        logits_chunk = linear(input[start_idx:end_idx], weight)

        # Compute loss for this chunk
        chunk_loss = cross_entropy_loss(logits_chunk, labels[start_idx:end_idx])

        # Accumulate weighted by chunk size
        total_loss += chunk_loss * (end_idx - start_idx)

    # Return average loss
    return total_loss / batch_size
137+
138+
def fused_linear_cross_entropy_pytorch(
    input: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """PyTorch reference implementation for fused linear cross entropy.

    Args:
        input: Tensor multiplied by ``weight.T`` to produce the logits.
        weight: Projection weight; its transpose is the right-hand operand.
        labels: Targets passed to ``torch.nn.functional.cross_entropy``.

    Returns:
        The loss returned by ``torch.nn.functional.cross_entropy`` with its
        default arguments.
    """
    # Compute logits
    logits = torch.matmul(input, weight.T)
    # Compute cross entropy
    return torch.nn.functional.cross_entropy(logits, labels)
149+
150+
def main() -> None:
    """Run fused linear cross entropy benchmark with different input sizes.

    Builds random float32 inputs on the "cuda" device and passes the Helion
    implementation and the PyTorch reference to ``run_example`` with
    ``rtol=1e-3`` and ``atol=1e-3``.
    """
    # Test with moderate size
    n, h, v = 128, 512, 1000
    # Local named ``x`` so the ``input`` builtin is not shadowed.
    x = torch.randn(n, h, device="cuda", dtype=torch.float32)
    weight = torch.randn(v, h, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        fused_linear_cross_entropy,
        fused_linear_cross_entropy_pytorch,
        (x, weight, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )
168+
169+
# Script entry point: run the example only when executed directly.
if __name__ == "__main__":
    main()