@@ -154,12 +154,12 @@ def titans_linear(
154154 # Process sequence step by step
155155 for t in range (T ):
156156 # Get current step inputs
157- q_t = q [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
158- k_t = k [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
159- v_t = v [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
160- theta_t = theta [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
161- alpha_t = alpha [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
162- eta_t = eta [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
157+ q_t = q [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
158+ k_t = k [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
159+ v_t = v [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
160+ theta_t = theta [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
161+ alpha_t = alpha [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
162+ eta_t = eta [:, :, t : t + 1 , :] # (batch_size, num_heads, 1, dim)
163163
164164 # Compute gradient
165165 km = k_t @ M_prev_nabla # (batch_size, num_heads, 1, dim)
0 commit comments