
Commit f7f20e9

Lint fixes for kernel folder (#1487)
1 parent: c59bce5

5 files changed: +43 additions, -40 deletions


ruff.toml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ include = [
     "torchao/profiler/**/*.py",
     "torchao/testing/**/*.py",
     "torchao/_models/**/*.py",
+    "torchao/kernel/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
     "torchao/ops.py",

torchao/kernel/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from torchao.kernel.intmm import int_scaled_matmul
-from torchao.kernel.intmm import safe_int_mm
+from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm

 __all__ = [
     "safe_int_mm",

torchao/kernel/autotuner.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 import logging
 import os
 import pathlib
-import pickle

 import torch
 import triton
@@ -173,7 +172,7 @@ def wrapped_fn():
     # Run it once and skip if it crashes or is 100x slower
     try:
         time = do_bench_basic(wrapped_fn, 1)
-    except RuntimeError as e:
+    except RuntimeError:
         time = None
     except triton.runtime.OutOfResources:
         time = None
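
As an aside, a minimal sketch of the try/except pattern this hunk touches, using a hypothetical helper name (bench_or_skip is not repo code): the candidate is run once and recorded as None if it raises, and the lint fix drops the unused "as e" binding.

    def bench_or_skip(fn):
        # Run the candidate once; treat a crash as "no timing available".
        try:
            return fn()
        except RuntimeError:  # no "as e" -- the exception object is never used
            return None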

torchao/kernel/intmm.py

Lines changed: 18 additions & 8 deletions
@@ -1,5 +1,5 @@
-import itertools
 import os
+
 import torch

 from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6
@@ -21,6 +21,7 @@
 if TORCH_VERSION_AT_LEAST_2_2:
     from torch._dynamo import is_compiling as dynamo_is_compiling
     from torch._higher_order_ops.out_dtype import out_dtype
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a safe integer matrix multiplication, considering different paths for
@@ -40,7 +41,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
             if input.device.type == "cpu":
                 # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
-                return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float())
+                return out_dtype(
+                    torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()
+                )
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

         # error checking for cublas path
@@ -60,9 +63,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:

         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
-            return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-                input.device.type
-            )
+            return torch.matmul(
+                input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)
+            ).to(input.device.type)

         # cublas paths
         if not mat2.is_contiguous():  # silently gives incorrect result without this
@@ -78,8 +81,11 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         except Exception:
             # fallback path, would run on H100 for float8 dtypes
             # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
-            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+                torch.int32
+            )
 else:
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a fallback integer matrix multiplication for torch versions before 2.2.
@@ -93,7 +99,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         # We can improve on this by writing Triton code that works for older versions of Triton
         # that ship with 2.1 or 2.0.
-        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+            torch.int32
+        )


 def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
@@ -113,7 +121,9 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     return safe_int_mm(a, b)


-def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
+def int_scaled_matmul(
+    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
+) -> torch.Tensor:
     """
     Performs scaled integer matrix multiplication.

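For reference, a hypothetical usage sketch of the two helpers reformatted above; the shapes and the (M, 1) per-row scale layout are assumptions for illustration, not taken from this diff.

    import torch

    from torchao.kernel import int_scaled_matmul, safe_int_mm

    a = torch.randint(-128, 127, (16, 32), dtype=torch.int8)  # (M, K) int8 operand
    b = torch.randint(-128, 127, (32, 8), dtype=torch.int8)   # (K, N) int8 operand

    c_int32 = safe_int_mm(a, b)                 # a @ b accumulated in int32
    scales = torch.rand(16, 1)                  # assumed per-row scale, shape (M, 1)
    c_scaled = int_scaled_matmul(a, b, scales)  # scaled product, roughly (a @ b) * scales
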
torchao/kernel/intmm_triton.py

Lines changed: 22 additions & 28 deletions
@@ -1,42 +1,36 @@
 import itertools
-import os

 import torch
-
 import triton
 import triton.language as tl

 from torchao.kernel.autotuner import get_best_config_fn
 from torchao.utils import TORCH_VERSION_AFTER_2_5

 # TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option
-int8_mm_kernel_configs = (
-    sum(
+int8_mm_kernel_configs = sum(
+    [
+        # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
         [
-            # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
-            [
-                (i, j, k, 1, 1),
-                (i, j, k, 1, 2),
-                (i, j, k, 2, 2),
-                (i, j, k, 1, 4),
-                (i, j, k, 2, 4),
-                (i, j, k, 3, 4),
-                (i, j, k, 4, 4),
-                (i, j, k, 1, 8),
-                (i, j, k, 2, 8),
-                (i, j, k, 3, 8),
-                (i, j, k, 4, 8),
-                (i, j, k, 5, 8),
-                (i, j, k, 6, 8),
-                (i, j, k, 7, 8),
-                (i, j, k, 8, 8),
-            ]
-            for (i, j, k) in itertools.product(
-                [32, 64, 128, 256], repeat=3
-            )
-        ],
-        []
-    )
+            (i, j, k, 1, 1),
+            (i, j, k, 1, 2),
+            (i, j, k, 2, 2),
+            (i, j, k, 1, 4),
+            (i, j, k, 2, 4),
+            (i, j, k, 3, 4),
+            (i, j, k, 4, 4),
+            (i, j, k, 1, 8),
+            (i, j, k, 2, 8),
+            (i, j, k, 3, 8),
+            (i, j, k, 4, 8),
+            (i, j, k, 5, 8),
+            (i, j, k, 6, 8),
+            (i, j, k, 7, 8),
+            (i, j, k, 8, 8),
+        ]
+        for (i, j, k) in itertools.product([32, 64, 128, 256], repeat=3)
+    ],
+    [],
 )

 if TORCH_VERSION_AFTER_2_5: