Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions benchmarks/ops/benchmark_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,15 @@ def benchmark(T, provider):
requires_grad = True
B, H, D, M = 16, 4, 128, 64

q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
if provider.startswith('flash'):
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)

if provider.startswith('gla'):
g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype))
g = F.logsigmoid(torch.randn(B, T, H, D, device=device, dtype=dtype))
g = g.clamp_min(-5).requires_grad_(requires_grad)
if provider.startswith('abc'):
s = torch.randn(B, H, T, M, device=device, requires_grad=requires_grad, dtype=dtype)
s = torch.randn(B, T, H, M, device=device, requires_grad=requires_grad, dtype=dtype)

do = torch.ones_like(v, dtype=dtype)

Expand Down
7 changes: 5 additions & 2 deletions benchmarks/ops/benchmark_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,15 @@ def benchmark(T, provider):
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
else:
elif provider in ('torch', 'torch_bwd', 'parallel_chunk_bwd', 'parallel_chunk'):
q = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
else:
q = torch.randn(B, T, H, 16, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, 16, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
do = torch.ones_like(v, dtype=dtype)

quantiles = [0.5, 0.2, 0.8]
results = 0, 0, 0
if provider == 'torch':
Expand Down
20 changes: 10 additions & 10 deletions benchmarks/ops/benchmark_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from benchmark import benchmark_backward, benchmark_combined, benchmark_forward
from torch.nn import functional as F

from fla.ops.delta_rule import chunk_delta_rule, fused_chunk_delta_rule
from fla.ops.delta_rule import chunk_delta_rule
from fla.utils import device


Expand Down Expand Up @@ -35,7 +35,7 @@ def time_bwd(func, *args, **kwargs):
dropout_p = 0.0


methods = (["chunk_delta_rule", "fused_chunk_delta_rule"])
methods = (["chunk_delta_rule"])
time_f = {}
time_b = {}
time_f_b = {}
Expand All @@ -47,10 +47,10 @@ def time_bwd(func, *args, **kwargs):
for B, seqlen in bs_seqlen_vals:
config = (causal, headdim, B, seqlen)
H = dim // headdim
q = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
beta = torch.rand(B, H, seqlen, device=device, dtype=dtype).sigmoid().requires_grad_(True)
q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, seqlen, H, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
beta = torch.rand(B, seqlen, H, device=device, dtype=dtype).sigmoid().requires_grad_(True)
o1, _ = chunk_delta_rule(q, k, v, beta)
o1.sum().backward(retain_graph=True)
f_b = time_fwd_bwd(
Expand All @@ -61,10 +61,10 @@ def time_bwd(func, *args, **kwargs):
# q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
# k = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
# v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
f_b = time_fwd_bwd(
fused_chunk_delta_rule, q, k, v, beta, verbose=False
)
time_f_b[config, "fused_chunk_delta_rule"] = f_b
# f_b = time_fwd_bwd(
# fused_chunk_delta_rule, q, k, v, beta, verbose=False
# )
# time_f_b[config, "fused_chunk_delta_rule"] = f_b

print(f"### causal={causal}, headdim={headdim}, B={B}, seqlen={seqlen} ###")
for method in methods:
Expand Down
28 changes: 12 additions & 16 deletions benchmarks/ops/benchmark_fla.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,19 @@ def benchmark(T, provider):
requires_grad = True
B, H, D = 16, 8, 128

if provider == 'flash':
if "based" in provider:
q = torch.randn(B, T, H, 16, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, 16, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
elif "gla" in provider:
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
elif "based" in provider:
q = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
elif "gla" in provider:
q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
g = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
g = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
else:
q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)

do = torch.rand_like(v, dtype=dtype)

Expand All @@ -68,13 +64,13 @@ def benchmark(T, provider):
if provider == 'flash':
results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v).backward(do), quantiles=quantiles)
elif provider == 'retention_parallel':
results = triton.testing.do_bench(lambda: parallel_retention(q, k, v).backward(do), quantiles=quantiles)
results = triton.testing.do_bench(lambda: parallel_retention(q, k, v)[0].backward(do), quantiles=quantiles)
elif provider == 'retention_fused_chunk':
results = triton.testing.do_bench(lambda: fused_chunk_retention(q, k, v).backward(do), quantiles=quantiles)
results = triton.testing.do_bench(lambda: fused_chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
elif provider == 'based_parallel':
results = triton.testing.do_bench(lambda: parallel_based(q, k, v).backward(do), quantiles=quantiles)
elif provider == 'gla_fused_chunk':
results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g).backward(do), quantiles=quantiles)
results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)

return results

Expand Down
17 changes: 12 additions & 5 deletions benchmarks/ops/benchmark_gla.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@
def benchmark(T, provider):
from fla.utils import device
dtype = torch.bfloat16
# dtype = torch.float32
requires_grad = True
B, H, D = 16, 8, 128

q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)).clamp_min(-5).requires_grad_(requires_grad)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
if provider in ("fused_chunk_gla", "fused_chunk_gla_bwd"):
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
g = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
else:
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
g = F.logsigmoid(torch.randn(B, T, H, D, device=device, dtype=dtype)).clamp_min(-5).requires_grad_(requires_grad)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)

do = torch.ones_like(q, dtype=dtype)

Expand All @@ -67,7 +74,7 @@ def benchmark(T, provider):
elif provider == 'chunk_retention_bwd':
results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
elif provider == 'recurrent_gla_bwd':
results = triton.testing.do_bench(lambda: fused_recurrent_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
results = triton.testing.do_bench(lambda: fused_recurrent_gla(q, k, v, gk=g)[0].backward(do), quantiles=quantiles)
elif provider == 'fused_chunk_gla_bwd':
results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
elif provider == 'chunk_gla_bwd':
Expand Down
15 changes: 6 additions & 9 deletions benchmarks/ops/benchmark_gsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,16 @@ def benchmark(T, provider):
requires_grad = True
B, H, D, M = 16, 4, 128, 64

q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
if provider.startswith('gsa'):
f = F.logsigmoid(torch.randn(B, H, T, M, device=device, dtype=dtype))
f = F.logsigmoid(torch.randn(B, T, H, M, device=device, dtype=dtype))
s = (1 - f.exp()).to(f.dtype)
if provider.startswith('gla'):
g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype))
g = F.logsigmoid(torch.randn(B, T, H, D, device=device, dtype=dtype))
g = g.clamp_min(-5).requires_grad_(requires_grad)
if provider.startswith('flash'):
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)

do = torch.ones_like(v, dtype=dtype)

quantiles = [0.5, 0.2, 0.8]
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/ops/benchmark_nsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ def benchmark(T, provider):
results = 0, 0, 0
if provider == 'nsa':
results = triton.testing.do_bench(
lambda: parallel_nsa(q, k, v, indices, block_size),
lambda: parallel_nsa(q, k, v, block_indices=indices, block_size=block_size),
quantiles=quantiles
)
elif provider == 'nsa_bwd':
results = triton.testing.do_bench(
lambda: parallel_nsa(q, k, v, indices, block_size).backward(do),
lambda: parallel_nsa(q, k, v, block_indices=indices, block_size=block_size).backward(do),
quantiles=quantiles
)
elif provider == 'flash':
Expand Down
11 changes: 3 additions & 8 deletions benchmarks/ops/benchmark_retention.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,9 @@ def benchmark(T, provider):
B, H, D = 4, 8, 256
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

if provider == 'flash' or provider == 'flash_bwd':
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
else:
q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
do = torch.ones_like(q, dtype=dtype)

quantiles = [0.5, 0.2, 0.8]
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/ops/benchmark_titans.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def time_bwd(func, *args, **kwargs):
v = torch.randn(
B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype
)
w = torch.randn(H, headdim, device=device, requires_grad=True, dtype=dtype)
b = torch.randn(H, headdim, device=device, requires_grad=True, dtype=dtype)
w = torch.randn(seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
b = torch.randn(seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
theta = torch.rand(
B, H, seqlen, 1, dtype=dtype, device=device, requires_grad=True
)
Expand Down
30 changes: 15 additions & 15 deletions benchmarks/ops/benchmark_ttt.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,31 @@ def time_bwd(func, *args, **kwargs):
for B, seqlen in bs_seqlen_vals:
config = (causal, headdim, B, seqlen)
H = dim // headdim
q = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
k = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
v = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
g = torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype).sigmoid().requires_grad_(True) / 16
q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
k = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
g = torch.randn(B, seqlen, H, headdim, device=device, dtype=dtype).sigmoid().requires_grad_(True) / 16
o1, _ = chunk_gla(q, k, v, g)
o1.sum().backward(retain_graph=True)
f_b = time_fwd_bwd(
chunk_gla, q, k, v, g, verbose=False
)
time_f_b[config, "chunk_gla"] = f_b

q = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
beta = torch.rand(B, H, seqlen, device=device, dtype=dtype).sigmoid().requires_grad_(True)
q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, seqlen, H, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
beta = torch.rand(B, seqlen, H, device=device, dtype=dtype).sigmoid().requires_grad_(True)
o2, _ = chunk_delta_rule(q, k, v, beta)
o2.sum().backward(retain_graph=True)
f_b = time_fwd_bwd(
chunk_delta_rule, q, k, v, beta, verbose=False
)
time_f_b[config, "chunk_delta_rule"] = f_b

q = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, seqlen, H, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
w = torch.randn(H, headdim, device=device, requires_grad=True, dtype=dtype)
b = torch.randn(H, headdim, device=device, requires_grad=True, dtype=dtype)
eta = torch.rand(B, H, seqlen, 1, device=device, requires_grad=True, dtype=dtype) * 5e-3
Expand All @@ -89,12 +89,12 @@ def time_bwd(func, *args, **kwargs):
)
time_f_b[config, "chunk_ttt_linear"] = f_b

q = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
k = F.normalize(torch.randn(B, seqlen, H, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
w = torch.randn(H, headdim, device=device, requires_grad=True, dtype=dtype)
b = torch.randn(H, headdim, device=device, requires_grad=True, dtype=dtype)
eta = torch.rand(B, H, seqlen, 1, device=device, requires_grad=True, dtype=dtype) * 5e-3
eta = torch.rand(B, seqlen, H, 1, device=device, requires_grad=True, dtype=dtype) * 5e-3
o4, _, _ = fused_chunk_ttt_linear(q, k, v, w, b, eta, chunk_size=16)
o4.sum().backward(retain_graph=True)
f_b = time_fwd_bwd(
Expand Down
Loading