
Commit ee4f1d9

rebase and rerun EXPECTACCEPT=1

1 parent 6ce91c6
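
EXPECTACCEPT=1 refers to the expect-test workflow behind test/test_examples.expected: when the variable is set, a mismatched recorded journal is rewritten in place instead of failing the comparison, which is why the .expected file below changes alongside the code. A minimal sketch of that pattern (the helper name and file handling here are illustrative, not helion's actual implementation):

```python
import os

def assert_expected_journal(actual: str, expected_path: str) -> None:
    # Illustrative expect-test helper: with EXPECTACCEPT=1 the recorded
    # output is regenerated instead of being compared against.
    if os.environ.get("EXPECTACCEPT") == "1":
        with open(expected_path, "w") as f:
            f.write(actual)
        return
    with open(expected_path) as f:
        assert actual == f.read()
```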

3 files changed: +19 −20 lines


examples/matmul.py

Lines changed: 6 additions & 5 deletions
@@ -9,15 +9,16 @@
 # %%
 from __future__ import annotations
 
-from typing import Any, TYPE_CHECKING
-
-import helion
-import helion.language as hl
+from typing import TYPE_CHECKING
+from typing import Any
 
 import torch
-from helion._testing import run_example
 from torch import Tensor
 
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
 if TYPE_CHECKING:
     from collections.abc import Callable
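
The regrouping above (one import per line, sorted within stdlib, third-party, and first-party sections) matches isort's force-single-line style. A quick way to check a snippet against that style, assuming the project formats imports with isort (the formatter configuration is an assumption, not taken from this commit):

```python
import isort

source = (
    "from typing import Any, TYPE_CHECKING\n"
    "import torch\n"
    "import helion\n"
)
# isort.code() returns the reformatted source; force_single_line
# splits "from typing import Any, TYPE_CHECKING" into two imports
# before sorting and grouping the sections.
print(isort.code(source, force_single_line=True))
```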

test/test_examples.expected

Lines changed: 2 additions & 2 deletions
@@ -146,7 +146,7 @@ def addmm_bwd(grad_out: Tensor, bias: Tensor, mat1: Tensor, mat2: Tensor, alpha:
     _BLOCK_SIZE_5 = 16
     _BLOCK_SIZE_6 = 16
     _BLOCK_SIZE_7 = 16
-    _launcher(_helion_addmm_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1) + triton.cdiv(m, _BLOCK_SIZE_2) * triton.cdiv(k, _BLOCK_SIZE_3) + triton.cdiv(k, _BLOCK_SIZE_5) * triton.cdiv(n, _BLOCK_SIZE_6),), grad_out, grad_input, mat2, grad_mat1, mat1, grad_mat2, grad_input.stride(0), grad_input.stride(1), grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, n, beta, k, alpha, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, _BLOCK_SIZE_7, num_warps=4, num_stages=3)
+    _launcher(_helion_addmm_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1) + triton.cdiv(m, _BLOCK_SIZE_2) * triton.cdiv(k, _BLOCK_SIZE_3) + triton.cdiv(k, _BLOCK_SIZE_5) * triton.cdiv(n, _BLOCK_SIZE_6),), grad_out, grad_input, mat2, grad_mat1, mat1, grad_mat2, grad_input.stride(0), grad_input.stride(1), grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, n, beta, k, alpha, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, _BLOCK_SIZE_7, num_warps=4, num_stages=2)
     return (grad_input, grad_mat1, grad_mat2)
 
 --- assertExpectedJournal(TestExamples.test_attention_block_pointer)
@@ -3050,7 +3050,7 @@ def matmul_bwd(grad_out: Tensor, mat1: Tensor, mat2: Tensor, *, _launcher=_defau
     _BLOCK_SIZE_3 = 16
     _BLOCK_SIZE_4 = 16
     _BLOCK_SIZE_5 = 16
-    _launcher(_helion_matmul_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(k, _BLOCK_SIZE_1) + triton.cdiv(k, _BLOCK_SIZE_3) * triton.cdiv(n, _BLOCK_SIZE_4),), grad_out, mat2, grad_mat1, mat1, grad_mat2, grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, k, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3)
+    _launcher(_helion_matmul_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(k, _BLOCK_SIZE_1) + triton.cdiv(k, _BLOCK_SIZE_3) * triton.cdiv(n, _BLOCK_SIZE_4),), grad_out, mat2, grad_mat1, mat1, grad_mat2, grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, k, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=2)
     return (grad_mat1, grad_mat2)
 
 --- assertExpectedJournal(TestExamples.test_matmul_layernorm_dynamic_shapes)
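Both hunks change only the final launch argument, num_stages, from 3 to 2. In Triton, num_warps and num_stages are meta-parameters accepted at kernel launch; num_stages sets the software-pipelining depth (and with it the shared-memory budget) of the compiled kernel. A standalone sketch of such a launch, using an illustrative kernel rather than helion's generated one:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, _BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * _BLOCK_SIZE + tl.arange(0, _BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

n = 4096
x = torch.randn(n, device="cuda")
y = torch.randn(n, device="cuda")
out = torch.empty_like(x)
_BLOCK_SIZE = 256
# num_warps/num_stages are passed at launch time, exactly as in the
# generated _launcher calls above; fewer stages means a shallower
# pipeline and a smaller shared-memory footprint.
_add_kernel[(triton.cdiv(n, _BLOCK_SIZE),)](x, y, out, n, _BLOCK_SIZE, num_warps=4, num_stages=2)
```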

test/test_examples.py

Lines changed: 11 additions & 13 deletions
@@ -2,21 +2,19 @@
 
 import unittest
 
-import helion
+from packaging import version
 import torch
-from helion._testing import (
-    check_example,
-    DEVICE,
-    EXAMPLES_DIR,
-    import_path,
-    RefEagerTestBase,
-    skipIfRefEager,
-    skipIfRocm,
-    skipIfXPU,
-    TestCase,
-)
 
-from packaging import version
+import helion
+from helion._testing import DEVICE
+from helion._testing import EXAMPLES_DIR
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import check_example
+from helion._testing import import_path
+from helion._testing import skipIfRefEager
+from helion._testing import skipIfRocm
+from helion._testing import skipIfXPU
 
 torch.backends.cuda.matmul.fp32_precision = "tf32"
 torch.backends.cudnn.conv.fp32_precision = "tf32"
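
The reshuffled imports also move packaging's version module to the top of the file. In test suites it is typically used to gate tests on the installed torch release; a hypothetical example of that pattern (the 2.4 threshold is invented for illustration, not taken from this commit):

```python
from packaging import version
import torch

# version.parse() gives a PEP 440-aware comparison, unlike naive
# string comparison of torch.__version__ (which would mis-order
# "2.10" before "2.4").
TORCH_AT_LEAST_2_4 = version.parse(torch.__version__) >= version.parse("2.4")
```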
