Skip to content

Commit 05bd290

Browse files
committed
[WIP] fused_linear_cross_entropy
stack-info: PR: #302, branch: yf225/stack/19
1 parent 727d52a commit 05bd290

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed

benchmarks/run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@
7979
"examples.cross_entropy",
8080
"cross_entropy",
8181
),
82+
"fused_linear_cross_entropy": (
83+
"tritonbench.operators.fused_linear_cross_entropy.operator",
84+
"examples.fused_linear_cross_entropy",
85+
"fused_linear_cross_entropy",
86+
),
8287
}
8388

8489

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""Fused linear cross entropy implementation for Helion.
2+
3+
This implementation uses Liger's chunking strategy to reduce memory usage
4+
while staying within Helion's constraints.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import torch
10+
11+
import helion
12+
from helion._testing import run_example
13+
import helion.language as hl
14+
from helion.utils import get_gpu_memory_info
15+
16+
# TritonBench configuration: when the first value reported by
# get_gpu_memory_info() is below 16.0 (GPUs with less than 16GB), use a
# smaller problem size. Otherwise this module does not define
# TRITONBENCH_ARGS at all.
if get_gpu_memory_info()[0] < 16.0:
    TRITONBENCH_ARGS = {
        "hidden_size": 2048,
        "vocab_size": 32000,
    }
20+
21+
22+
# Simple matmul kernel for the linear layer
23+
@helion.kernel(static_shapes=True, dot_precision="ieee")
def linear(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Compute ``input @ weight.T`` into a float32 logits matrix.

    Args:
        input: 2-D tensor whose shape is unpacked as ``(n, h)``.
        weight: 2-D tensor whose shape is unpacked as ``(v, h2)``; ``h2``
            must equal ``h``.

    Returns:
        A float32 tensor of shape ``(n, v)`` allocated on ``input``'s device.
    """
    n, h = input.shape
    v, h2 = weight.shape
    assert h == h2, f"Hidden size mismatch: {h} != {h2}"

    out = torch.empty([n, v], dtype=torch.float32, device=input.device)

    # Tile the output; for each output tile, reduce over the hidden dimension
    # into a float32 accumulator before writing the tile back.
    for rows, cols in hl.tile([n, v]):
        partial = hl.zeros([rows, cols], dtype=torch.float32)
        for inner in hl.tile(h):
            partial = torch.addmm(partial, input[rows, inner], weight[cols, inner].T)
        out[rows, cols] = partial

    return out
38+
39+
40+
# Cross entropy loss kernel (based on existing cross_entropy.py)
41+
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the mean cross-entropy loss of ``logits`` against ``labels``.

    Args:
        logits: 2-D tensor whose shape is unpacked as ``(n, v)``.
        labels: Per-row target indices; each is added to its row's offset to
            index the flattened logits (presumably class ids in ``[0, v)`` —
            not validated here).

    Returns:
        The mean of the ``n`` per-row losses.
    """
    num_rows, vocab = logits.shape
    per_row_loss = torch.zeros([num_rows], dtype=torch.float32, device=logits.device)

    # Offset of each row's first element inside the flattened logits.
    row_offsets = torch.arange(num_rows, device=logits.device) * vocab
    flat_logits = logits.view(-1)

    for rows in hl.tile(num_rows):
        tile_labels = labels[rows]
        tile_offsets = row_offsets[rows]

        # Pick out the logit of the target class for every row in the tile.
        target_logits = hl.load(flat_logits, [tile_offsets + tile_labels])

        # Entire rows of this tile, needed for the reduction over the vocab.
        row_logits = logits[rows, :]

        # log-sum-exp with the row maximum subtracted before exponentiation.
        row_max = torch.amax(row_logits, dim=-1, keepdim=True)
        exp_sum = torch.sum(torch.exp(row_logits - row_max), dim=-1, keepdim=True)
        log_sum_exp = row_max.squeeze(-1) + torch.log(exp_sum.squeeze(-1))

        # loss = logsumexp(row) - row[target]
        per_row_loss[rows] = log_sum_exp - target_logits

    return per_row_loss.mean()
72+
73+
74+
def calculate_chunk_size(batch_size: int, hidden_size: int, vocab_size: int) -> int:
    """Pick the number of rows to process per chunk, loosely following Liger.

    The batch is divided by ``ceil(vocab_size / hidden_size)``, rounded down
    to a power of two, and then clamped so it never exceeds ``batch_size``
    or 256.

    Args:
        batch_size: Number of rows in the input.
        hidden_size: Size of the hidden dimension; must be non-zero.
        vocab_size: Size of the vocabulary dimension.

    Returns:
        The chunk size in rows. For a positive ``batch_size`` this is a power
        of two in ``[1, min(batch_size, 256)]``.

    Raises:
        ZeroDivisionError: If ``hidden_size`` is zero, or if the computed
            increase factor is zero (e.g. ``vocab_size`` is zero).
    """
    # Ceiling division: how many times larger the vocab is than the hidden dim.
    inc_factor = (vocab_size + hidden_size - 1) // hidden_size
    chunk_size = max(1, batch_size // inc_factor)

    # Round down to a power of two. chunk_size is at least 1 at this point, so
    # bit_length() is at least 1 and no separate zero branch is needed.
    chunk_size = 1 << (chunk_size.bit_length() - 1)

    # Never exceed the batch itself, and cap chunks at 256 rows so that no
    # single chunk becomes too large.
    return min(chunk_size, batch_size, 256)
93+
94+
95+
# Fused version that uses chunking to reduce memory
96+
def fused_linear_cross_entropy(
    input: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Compute linear + mean cross-entropy, chunking over the batch dimension.

    The chunk size comes from ``calculate_chunk_size``. When it covers the
    whole batch, the two kernels run once on the full input. Otherwise logits
    are computed one chunk of rows at a time, and each chunk's mean loss is
    weighted by its row count before the final division by the batch size.

    Args:
        input: 2-D tensor whose shape is unpacked as ``(batch_size, hidden_size)``.
        weight: Tensor whose first dimension is taken as the vocabulary size.
        labels: Target indices, sliced along dim 0 in step with ``input``.

    Returns:
        The mean loss over all rows of ``input``.
    """
    batch_size, hidden_size = input.shape
    vocab_size = weight.shape[0]
    chunk_size = calculate_chunk_size(batch_size, hidden_size, vocab_size)

    # A single chunk spans the whole batch: skip the chunking machinery.
    if chunk_size >= batch_size:
        return cross_entropy_loss(linear(input, weight), labels)

    weighted_loss_sum = 0.0
    for start in range(0, batch_size, chunk_size):
        # The final chunk may be shorter than chunk_size.
        stop = min(start + chunk_size, batch_size)

        # Only this chunk's logits exist at a time.
        chunk_logits = linear(input[start:stop], weight)
        chunk_mean_loss = cross_entropy_loss(chunk_logits, labels[start:stop])

        # Undo the per-chunk mean so every row carries equal weight overall.
        weighted_loss_sum += chunk_mean_loss * (stop - start)

    return weighted_loss_sum / batch_size
137+
138+
139+
def fused_linear_cross_entropy_pytorch(
    input: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Reference implementation built from stock PyTorch ops.

    Args:
        input: Tensor multiplied on the left of ``weight.T``.
        weight: Tensor whose transpose projects ``input`` to logits.
        labels: Targets passed to ``torch.nn.functional.cross_entropy``.

    Returns:
        The result of ``cross_entropy`` (default arguments) on the logits.
    """
    return torch.nn.functional.cross_entropy(
        torch.matmul(input, weight.T),
        labels,
    )
149+
150+
151+
def main() -> None:
    """Check the Helion implementation against the PyTorch reference.

    Builds one moderate-size random problem on the CUDA device and hands both
    implementations to ``run_example`` with 1e-3 relative/absolute tolerances.
    """
    num_rows, hidden, vocab = 128, 512, 1000
    device = "cuda"

    # Tuple elements are created in the order input, weight, labels.
    example_args = (
        torch.randn(num_rows, hidden, device=device, dtype=torch.float32),
        torch.randn(vocab, hidden, device=device, dtype=torch.float32),
        torch.randint(0, vocab, (num_rows,), device=device, dtype=torch.long),
    )

    run_example(
        fused_linear_cross_entropy,
        fused_linear_cross_entropy_pytorch,
        example_args,
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )
168+
169+
170+
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)