Skip to content

Commit c7b8e13

Browse files
Revert D82355346
Differential Revision: D84169566. Pull Request resolved: #3132.
1 parent a52a64a commit c7b8e13

File tree

3 files changed

+4
-27
lines changed

3 files changed

+4
-27
lines changed

test/quantization/test_da8w4_cpu.py

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
import unittest
99

1010
import torch
11-
from torch._dynamo.utils import counters
1211
from torch.testing._internal import common_utils
1312
from torch.testing._internal.common_utils import (
1413
TestCase,
@@ -121,6 +120,7 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
121120
@common_utils.parametrize("x_dim", [2, 3])
122121
@common_utils.parametrize("bias", [True, False])
123122
def test_8da4w_concat_linear_cpu(self, x_dim, bias):
123+
self.skipTest("Disabled for now")
124124
N, K = 64, 128
125125

126126
class Mod(torch.nn.Module):
@@ -163,15 +163,6 @@ def forward(self, x):
163163
# ensure the expected op occurs only once in the code after fusion
164164
# The trailing "(" is to avoid matching the op in the comment
165165
assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
166-
167-
# Ensure that when concat linear is enabled, fxgraph cache works
168-
# without being bypassed (fxgraph_cache_bypass = 0), indicating that
169-
# DA8W4ConcatLinearCPUPass properly implements the CustomGraphPass
170-
# interface and uuid() function, allowing fxgraph to be saved and hit
171-
# on subsequent runs (fxgraph_cache_hit > 0).
172-
fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
173-
assert fx_cache_bypass_count == 0
174-
175166
with torch._inductor.config.patch(
176167
{"freezing": True, "cpp.enable_concat_linear": False}
177168
):
@@ -181,10 +172,6 @@ def forward(self, x):
181172
)
182173
assert torch.allclose(y, y_ref)
183174

184-
# Ensure that the fxgraph cache is also not bypassed when concat linear is disabled
185-
fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
186-
assert fx_cache_bypass_count == 0
187-
188175

189176
common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
190177

torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -314,6 +314,6 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
314314

315315

316316
# Register the concat linear fusion pass
317-
from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
317+
# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
318318

319-
register_da8w4_concat_linear_cpu_pass()
319+
# register_da8w4_concat_linear_cpu_pass()

torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,15 +7,6 @@
77
import operator
88

99
import torch
10-
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
11-
12-
13-
class DA8W4ConcatLinearCPUPass(CustomGraphPass):
14-
def __call__(self, graph: torch.fx.Graph):
15-
_concat_linear_dq8w4_cpu(graph)
16-
17-
def uuid(self):
18-
return get_hash_for_files((__file__,))
1910

2011

2112
# Inductor FX passes for concat linear for DA8W4
@@ -222,5 +213,4 @@ def ...
222213
def register_da8w4_concat_linear_cpu_pass():
223214
from torch._inductor import config as inductor_config
224215

225-
da8w4_concat_linear_cpu_pass = DA8W4ConcatLinearCPUPass()
226-
inductor_config.post_grad_custom_post_pass = da8w4_concat_linear_cpu_pass
216+
inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu

0 commit comments

Comments (0)