Skip to content

Commit 14e0247

Browse files
committed
[Titans] Update format
1 parent 12d0cbb commit 14e0247

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

fla/ops/titans/log_impl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def cal_n_log(log_theta, log_eta, seq_len):
1616
log_n[..., j, i] = log_theta[..., j]
1717
else:
1818
log_n[..., j, i] = log_theta[..., j] + torch.sum(
19-
log_eta[..., j + 1 : i + 1], dim=-1
19+
log_eta[..., j + 1: i + 1], dim=-1
2020
)
2121

2222
return log_n
@@ -34,7 +34,7 @@ def cal_f_log(log_beta, seq_len, log_m):
3434
# f[..., t] += torch.exp(log_beta[..., t] - log_beta[..., i] + log_m[..., i])
3535
log_f = torch.zeros_like(log_beta)
3636
for t in range(seq_len):
37-
a_i = log_beta[..., t : t + 1] - log_beta[..., : t + 1] + log_m[..., : t + 1]
37+
a_i = log_beta[..., t: t + 1] - log_beta[..., : t + 1] + log_m[..., : t + 1]
3838
log_f[..., t] = torch.logsumexp(a_i, dim=-1)
3939
f = torch.exp(log_f)
4040

@@ -74,9 +74,9 @@ def cal_G_log(log_beta, log_n, seq_len):
7474
for i in range(seq_len): # row
7575
for j in range(i + 1): # column
7676
terms = (
77-
log_beta[..., i : i + 1]
78-
- log_beta[..., j : i + 1]
79-
+ log_n[..., j : j + 1, j : i + 1].squeeze(-2)
77+
log_beta[..., i: i + 1]
78+
- log_beta[..., j: i + 1]
79+
+ log_n[..., j: j + 1, j: i + 1].squeeze(-2)
8080
)
8181
# use logsumexp to avoid overflow
8282
log_G[..., i, j] = torch.logsumexp(terms, dim=-1)

fla/ops/titans/naive.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,12 @@ def titans_linear(
154154
# Process sequence step by step
155155
for t in range(T):
156156
# Get current step inputs
157-
q_t = q[:, :, t : t + 1, :] # (batch_size, num_heads, 1, dim)
158-
k_t = k[:, :, t : t + 1, :] # (batch_size, num_heads, 1, dim)
159-
v_t = v[:, :, t : t + 1, :] # (batch_size, num_heads, 1, dim)
160-
theta_t = theta[:, :, t : t + 1, :] # (batch_size, num_heads, 1, dim)
161-
alpha_t = alpha[:, :, t : t + 1, :] # (batch_size, num_heads, 1, dim)
162-
eta_t = eta[:, :, t : t + 1, :] # (batch_size, num_heads, 1, dim)
157+
q_t = q[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
158+
k_t = k[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
159+
v_t = v[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
160+
theta_t = theta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
161+
alpha_t = alpha[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
162+
eta_t = eta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
163163

164164
# Compute gradient
165165
km = k_t @ M_prev_nabla # (batch_size, num_heads, 1, dim)

0 commit comments

Comments (0)