@@ -1,5 +1,5 @@
-import itertools
 import os
+
 import torch
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6
@@ -21,6 +21,7 @@
 if TORCH_VERSION_AT_LEAST_2_2:
     from torch._dynamo import is_compiling as dynamo_is_compiling
     from torch._higher_order_ops.out_dtype import out_dtype
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a safe integer matrix multiplication, considering different paths for
@@ -40,7 +41,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
             if input.device.type == "cpu":
                 # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
-                return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float())
+                return out_dtype(
+                    torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()
+                )
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
 
         # error checking for cublas path
@@ -60,9 +63,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
 
         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
-            return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-                input.device.type
-            )
+            return torch.matmul(
+                input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)
+            ).to(input.device.type)
 
         # cublas paths
         if not mat2.is_contiguous():  # silently gives incorrect result without this
@@ -78,8 +81,11 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         except Exception:
             # fallback path, would run on H100 for float8 dtypes
             # Exception on H100 float8 dtype: "addmm_cuda" not implemented for 'Float8_e4m3fn'
-            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+                torch.int32
+            )
 else:
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a fallback integer matrix multiplication for torch versions before 2.2.
@@ -93,7 +99,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         # We can improve on this by writing Triton code that works for older versions of Triton
         # that ship with 2.1 or 2.0.
-        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+            torch.int32
+        )
 
 
 def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
@@ -113,7 +121,9 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     return safe_int_mm(a, b)
 
 
-def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
+def int_scaled_matmul(
+    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
+) -> torch.Tensor:
     """
     Performs scaled integer matrix multiplication.
 
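The hunks above are mechanical reformatting (long calls and signatures wrapped to the line-length limit, plus removal of import itertools); behavior is unchanged. As a quick sanity check, here is a minimal usage sketch of safe_int_mm, assuming the function is importable from torchao.kernel.intmm (the module path is inferred and not named in this diff):

import torch

# Module path assumed; the diff does not name the file.
from torchao.kernel.intmm import safe_int_mm

# int8 operands: safe_int_mm accumulates the product in int32,
# so the result does not overflow the int8 range.
a = torch.randint(-128, 128, (16, 32), dtype=torch.int8)
b = torch.randint(-128, 128, (32, 8), dtype=torch.int8)

c = safe_int_mm(a, b)
assert c.dtype == torch.int32
assert c.shape == (16, 8)

On the eager CPU path this exercises the torch.matmul int32 fallback branch shown above; under torch.compile, the dynamo_is_compiling() branch routes through out_dtype instead.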