Skip to content

Commit 5d0cd02

Browse files
authored
Fixes in helion puzzles (#1104)
1 parent 19b6cf9 commit 5d0cd02

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

docs/helion_puzzles.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ A scalar version of FlashAttention.
433433
scores = q_tile[:, None] * k_tile[None, :]
434434
435435
# Find max for numerical stability
436-
batch_max = torch.max(scores, dim=1)[0]
436+
batch_max = torch.amax(scores, dim=1)
437437
new_max = torch.maximum(max_val, batch_max)
438438
439439
# Scale old accumulations
@@ -468,7 +468,7 @@ A batched 2D convolution.
468468
.. code-block:: python
469469
470470
def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]:
471-
z = torch.zeros(4, 8, 8)
471+
z = torch.zeros(4, 8, 8, device=x.device)
472472
x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
473473
for i in range(8):
474474
for j in range(8):
@@ -495,7 +495,7 @@ A batched 2D convolution.
495495
# Extract the patch
496496
patch = x_padded[tile_batch, i:i+kh, j:j+kw]
497497
# Apply the kernel
498-
out[tile_batch, i, j] = (k[tile_batch] * patch).sum([1, 2])
498+
out[tile_batch, i, j] = (k[tile_batch,:,:] * patch).sum([1, 2])
499499
500500
return out
501501

0 commit comments

Comments
 (0)