@@ -433,7 +433,7 @@ A scalar version of FlashAttention.
     scores = q_tile[:, None] * k_tile[None, :]

     # Find max for numerical stability
-    batch_max = torch.max(scores, dim=1)[0]
+    batch_max = torch.amax(scores, dim=1)
     new_max = torch.maximum(max_val, batch_max)

     # Scale old accumulations
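Context on the change above (my note, not part of the patch): ``torch.max(scores, dim=1)`` returns a ``(values, indices)`` pair, which is why the old line indexed ``[0]``, while ``torch.amax`` returns only the values, so both forms compute the same per-row running maximum used for the online-softmax rescaling. A minimal standalone sketch with made-up tensors rather than the puzzle's variables; the rescaling factor at the end is the standard online-softmax form, not copied from the patch:

.. code-block:: python

    import torch

    scores = torch.randn(4, 8)                 # stand-in for q_tile[:, None] * k_tile[None, :]
    max_val = torch.full((4,), float("-inf"))  # running per-row maximum

    old_way = torch.max(scores, dim=1)[0]      # (values, indices) tuple, keep the values
    new_way = torch.amax(scores, dim=1)        # values only
    assert torch.equal(old_way, new_way)

    # Online-softmax bookkeeping: update the running max and rescale old sums.
    new_max = torch.maximum(max_val, new_way)
    scale = torch.exp(max_val - new_max)       # factor typically applied to previous accumulations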
@@ -468,7 +468,7 @@ A batched 2D convolution.
 .. code-block:: python

     def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]:
-        z = torch.zeros(4, 8, 8)
+        z = torch.zeros(4, 8, 8, device=x.device)
         x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
         for i in range(8):
             for j in range(8):
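A note on why ``device=x.device`` matters (my illustration, not from the patch): a plain ``torch.zeros(4, 8, 8)`` allocates on the default device, so once the spec is exercised with CUDA inputs, writing results from ``x`` into ``z`` would mix CPU and CUDA tensors and raise a device-mismatch error. A small sketch, guarded so it also runs on CPU-only machines:

.. code-block:: python

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4, 8, 8, device=device)

    z = torch.zeros(4, 8, 8, device=x.device)  # follows x onto whatever device it lives on
    z[:, 0, 0] = x[:, 0, 0]                    # safe: both tensors share a device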
@@ -495,7 +495,7 @@ A batched 2D convolution.
             # Extract the patch
             patch = x_padded[tile_batch, i:i+kh, j:j+kw]
             # Apply the kernel
-            out[tile_batch, i, j] = (k[tile_batch] * patch).sum([1, 2])
+            out[tile_batch, i, j] = (k[tile_batch, :, :] * patch).sum([1, 2])

     return out
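For orientation (again my note, not from the patch): the inner update multiplies each batch element's kernel by the matching input patch and sums over the two spatial axes; ``k[tile_batch]`` and ``k[tile_batch, :, :]`` select the same slices, the ``+`` line simply spells the trailing dimensions explicitly. A standalone sketch with hypothetical shapes (a 3x3 kernel per batch element, chosen only for illustration):

.. code-block:: python

    import torch

    tile_batch = torch.arange(4)                  # batch indices handled by this tile
    x_padded = torch.randn(4, 10, 10)
    k = torch.randn(4, 3, 3)                      # one kernel per batch element (assumed shape)
    kh, kw = k.shape[1], k.shape[2]

    i, j = 2, 5                                   # one output position
    patch = x_padded[tile_batch, i:i+kh, j:j+kw]  # shape (4, kh, kw)
    vals = (k[tile_batch, :, :] * patch).sum([1, 2])  # one dot product per batch element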