Skip to content

Commit 909d09c

Browse files
authored
Implement verify_correctness pytorch#179 (pytorch#252)
* Add WrapperBackend to enable verifying the correctness of backends; set config.verify_correctness to True to enable it. * Move testing.same() to utils.py
1 parent 83d4cab commit 909d09c

File tree

6 files changed

+322
-58
lines changed

6 files changed

+322
-58
lines changed

tests/test_verify_correctness.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
#!/usr/bin/env pytest
2+
import importlib
3+
import operator
4+
import unittest
5+
6+
import torch
7+
8+
import torchdynamo
9+
import torchdynamo.config as config
10+
from torchdynamo.optimizations import backends
11+
from torchdynamo.optimizations.inference import fixed_strategy1
12+
from torchdynamo.optimizations.inference import offline_autotuner
13+
from torchdynamo.testing import same
14+
15+
16+
def has_onnxruntime():
    """Return True when the ``onnxruntime`` package can be imported."""
    try:
        importlib.import_module("onnxruntime")
    except ImportError:
        # Package is missing (or its import failed with ImportError).
        return False
    else:
        return True
22+
23+
24+
def has_ipex():
    """Return True when ``intel_extension_for_pytorch`` can be imported."""
    try:
        importlib.import_module("intel_extension_for_pytorch")
    except ImportError:
        # Package is missing (or its import failed with ImportError).
        return False
    else:
        return True
30+
31+
32+
class Seq(torch.nn.Module):
    """Small test model: Linear(10, 10) -> ReLU -> Linear(10, 10) -> Sigmoid."""

    def __init__(self):
        super().__init__()
        # Modules are instantiated in forward order, so parameter
        # initialisation draws from the RNG in layer order.
        stack = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        ]
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential layer stack to ``x``."""
        return self.layers(x)
44+
45+
46+
class Conv_Bn_Relu(torch.nn.Module):
    """Test model: bias-free Conv2d, then BatchNorm2d (eps=0.001), then ReLU.

    Extra keyword arguments are forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Run conv -> batch norm -> ReLU on ``x``."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.relu(normalized)
55+
56+
57+
def toy_example(a, b):
    """Return ``a / (|a| + 1)`` multiplied by ``b``, negating ``b`` first if its sum is negative.

    The data-dependent branch on ``b.sum()`` and the explicit ``*`` operators
    are kept as-is because the tests in this file rewrite traced
    multiplications.
    """
    denominator = torch.abs(a) + 1
    x = a / denominator
    if b.sum() < 0:
        b = b * -1
    return x * b
62+
63+
64+
def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite every ``operator.mul`` call in ``gm`` into ``operator.add``.

    The graph module is modified in place, linted, recompiled and returned.
    The tests use this as a deliberately incorrect "optimization".
    """
    # Only ``call_function`` nodes carry the called function in ``target``.
    mul_calls = [
        node
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target == operator.mul
    ]
    for node in mul_calls:
        node.target = operator.add

    # Check that the graph is still well-formed before regenerating code.
    gm.graph.lint()

    gm.recompile()
    return gm
79+
80+
81+
config.verify_correctness = True
82+
83+
84+
class TestVerifyCorrectness(torchdynamo.testing.TestCase):
    """Backend tests that run with ``config.verify_correctness`` enabled.

    The module sets ``config.verify_correctness = True`` at import time; the
    one test that needs it disabled restores the previous value itself.
    """

    def test_example_inputs(self):
        """The compiler receives example inputs that reproduce the eager result."""

        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            # Record what the captured graph computes on the example inputs.
            nonlocal r1
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        with torchdynamo.optimize_assert(compiler_fn):
            r3 = fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        self.assertTrue(same(r1, r2))
        self.assertTrue(same(r1, r3))

    def test_fixed_strategy1(self):
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        with torchdynamo.optimize(fixed_strategy1):
            r2 = s(i)
        self.assertTrue(same(r1, r2))

    def test_nnc(self):
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        with torchdynamo.optimize("nnc"):
            r2 = s(i)
        self.assertTrue(same(r1, r2))

    def test_incorrect_verify_true(self):
        """
        The bad optimization returns a graph that is not functionally
        equal to the original graph. When config.verify_correctness=True,
        the outputs are checked and execution falls back to the original
        graph, so the results still match.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        r1 = toy_example(i1, i2)
        with torchdynamo.optimize(incorrect_compile_fn):
            r2 = toy_example(i1, i2)
        self.assertTrue(same(r1, r2))

    def test_incorrect_verify_false(self):
        """
        The bad optimization returns a graph that is not functionally
        equal to the original graph. When config.verify_correctness=False,
        the wrong outputs are returned.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        # Flip the global flag only for the duration of this test and always
        # put it back, so a failing assertion cannot leak the disabled state
        # into the other tests.
        previous = config.verify_correctness
        config.verify_correctness = False
        try:
            r1 = toy_example(i1, i2)
            with torchdynamo.optimize(incorrect_compile_fn):
                r2 = toy_example(i1, i2)
            self.assertFalse(same(r1, r2))
        finally:
            config.verify_correctness = previous

    @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
    def test_export(self):
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        with torchdynamo.optimize_assert(offline_autotuner):
            r2 = s(i)
        self.assertTrue(same(r1, r2))

    @unittest.skipIf(not has_ipex(), "requires ipex")
    def test_ipex_fp32(self):
        model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
        model = model.to(memory_format=torch.channels_last)
        model = model.eval()
        inputs = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
        r1 = model(inputs)
        with torchdynamo.optimize(backends.ipex_fp32), torch.no_grad():
            r2 = model(inputs)
        self.assertTrue(same(r1, r2))
        self.assertEqual(r2.dtype, torch.float32)

    @unittest.skipIf(not has_ipex(), "requires ipex")
    def test_ipex_bf16(self):
        model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
        model = model.to(memory_format=torch.channels_last)
        model = model.eval()
        inputs = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
        r1 = model(inputs)
        with torchdynamo.optimize(
            backends.ipex_bf16
        ), torch.no_grad(), torch.cpu.amp.autocast():
            r2 = model(inputs)
        # bf16 has reduced precision, so compare in fp32 with a loose tolerance.
        self.assertTrue(same(r1, r2.float(), tol=0.1))
        self.assertEqual(r2.dtype, torch.bfloat16)

torchdynamo/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
# print out lots of stuff
88
debug = False
99

10+
# verify the correctness of optimized backend
11+
verify_correctness = False
12+
1013
# an unreasonable amount of debug printouts
1114
trace = False
1215

torchdynamo/convert_frame.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .bytecode_analysis import remove_pointless_jumps
1818
from .bytecode_transformation import is_generator
1919
from .bytecode_transformation import transform_code_object
20+
from .eval_frame import WrapperBackend
2021
from .eval_frame import skip_code
2122
from .exc import InternalTorchDynamoError
2223
from .exc import TorchRuntimeError
@@ -59,7 +60,7 @@ def fx_forward_from_src_skip_result(*args, **kwargs):
5960
return result
6061

6162

62-
def wrap_compiler_fn(compiler_fn):
63+
def _wrap_compiler_fn(compiler_fn):
6364
"""Expand backend strings to functions"""
6465
if compiler_fn == "inductor":
6566
from torchinductor.compile_fx import compile_fx
@@ -73,6 +74,19 @@ def wrap_compiler_fn(compiler_fn):
7374
return compiler_fn
7475

7576

77+
def wrap_compiler_fn(compiler_fn):
    """Resolve ``compiler_fn`` and wrap it in WrapperBackend if config.verify_correctness is True"""
    backend_fn = _wrap_compiler_fn(compiler_fn)

    # Without verification the resolved backend is used directly.
    if not config.verify_correctness:
        return backend_fn

    return WrapperBackend(backend_fn)
88+
89+
7690
def wrap_convert_context(fn):
7791
"""
7892
Context manager to:

torchdynamo/eval_frame.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
import contextlib
2+
import copy
23
import functools
34
import logging
45
import threading
56

7+
import torch
8+
9+
from torchdynamo.utils import checkpoint_params
10+
from torchdynamo.utils import clone_inputs
11+
612
from . import config
713
from . import convert_frame
814
from . import skipfiles
915
from .mutation_guard import install_generation_tagging_new
16+
from .utils import same
17+
18+
log = logging.getLogger(__name__)
1019

1120
try:
1221
from . import _eval_frame
@@ -124,6 +133,47 @@ def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
124133
)
125134

126135

136+
class WrapperBackend:
    """Wrap a backend compiler and check its output against the original graph.

    Calling the wrapper runs ``backend`` on a deep copy of the graph module.
    When ``config.verify_correctness`` is set, the candidate returned by the
    backend and the original ``gm.forward`` are both run on clones of the
    example inputs; the candidate is only used if ``same`` reports matching
    results, otherwise the original ``gm.forward`` is returned.
    """

    def __init__(self, backend=None):
        # Callable invoked as ``backend(graph_module, example_inputs)``; it
        # returns a callable, or None to decline.
        self.backend = backend

    @property
    def example_inputs(self):
        """Fresh clones of the example inputs captured by the last call."""
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        """Compile ``gm`` with the wrapped backend and return a callable.

        Args:
            gm: the captured graph module.
            example_inputs: inputs the graph can be run with.

        Returns:
            The backend's candidate callable, or ``gm.forward`` when the
            backend is missing, declines (returns None), produces mismatching
            results, or raises during verification.
        """
        # Callable returned by checkpoint_params; invoked after verification.
        # NOTE(review): presumably restores parameter state touched by the
        # verification runs -- confirm against utils.checkpoint_params.
        self.restore = checkpoint_params(gm)
        self.original_example_inputs = clone_inputs(example_inputs)
        self.gm = gm

        if self.backend is None:
            # Nothing to compile with; previously this crashed with TypeError.
            self.candidate = None
            return self.gm.forward

        copy_gm = copy.deepcopy(self.gm)
        # Hand the backend its own clones so that a backend which mutates its
        # example inputs in place cannot corrupt the inputs used for the
        # verification runs below.
        self.candidate = self.backend(copy_gm, self.example_inputs)

        # NOTE(review): the identity check against ``self.gm.forward`` looks
        # like it can never be true (attribute access builds a new bound
        # method each time, and the backend only sees ``copy_gm``); kept
        # unchanged to preserve behavior.
        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            # Report the wrapped backend (not this wrapper object) so the
            # message identifies which compiler produced the mismatch.
            log.warning("incorrect results of backend %s", self.backend)
            return self.gm.forward

        except Exception:
            # Verification is best-effort: fall back to the original graph.
            log.exception("error in verify_correctness")
            return self.gm.forward
        finally:
            self.restore()
175+
176+
127177
def optimize(backend, nopython=False):
128178
"""
129179
The main entrypoint of TorchDynamo. Do graph capture and call

torchdynamo/testing.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .bytecode_transformation import is_generator
1818
from .bytecode_transformation import transform_code_object
1919
from .guards import GuardedCode
20+
from .utils import same
2021

2122
unsupported = torchdynamo.eval_frame.unsupported
2223
three = 3
@@ -69,63 +70,6 @@ def reduce_to_scalar_loss(out):
6970
raise NotImplementedError("Don't know how to reduce")
7071

7172

72-
def same(a, b, cos_similarity=False, tol=1e-4, equal_nan=False):
    """Check correctness to see if a and b match.

    Args:
        a: reference value; may be a tensor, a (nested) list/tuple/
            ``ParameterList``/``torch.Size``, a dict, a plain scalar/str/
            None/device, or one of the allow-listed output objects.
        b: value to compare against ``a``; must have a matching structure.
        cos_similarity: compare tensors by cosine similarity (threshold
            0.99) instead of ``torch.allclose``.
        tol: used as both ``atol`` and ``rtol`` for ``torch.allclose``.
        equal_nan: forwarded to ``torch.allclose``.

    Returns:
        bool: True when ``a`` and ``b`` match.

    Raises:
        AssertionError: when the structure or types of ``a`` and ``b`` differ.
        RuntimeError: when ``a`` has an unsupported type.
    """
    sequence_types = (list, tuple, torch.nn.ParameterList, torch.Size)
    if isinstance(a, sequence_types):
        # Accept the same sequence types for ``b`` as for ``a`` so that, for
        # example, two ParameterLists can be compared with each other.
        assert isinstance(b, sequence_types), f"type mismatch {type(a)} {type(b)}"
        return len(a) == len(b) and all(
            same(ai, bi, cos_similarity, tol, equal_nan) for ai, bi in zip(a, b)
        )
    elif isinstance(a, dict):
        assert isinstance(b, dict)
        assert set(a.keys()) == set(
            b.keys()
        ), f"keys mismatch {set(a.keys())} == {set(b.keys())}"
        for k in a.keys():
            if not (same(a[k], b[k], cos_similarity, tol, equal_nan=equal_nan)):
                print("Accuracy failed for key name", k)
                return False
        return True
    elif isinstance(a, torch.Tensor):
        # Validate ``b`` before touching tensor-only attributes on it.
        assert isinstance(b, torch.Tensor)
        if a.is_sparse:
            assert b.is_sparse
            a = a.to_dense()
            b = b.to_dense()
        if cos_similarity:
            # TRT will bring error loss larger than current threshold. Use cosine similarity as replacement
            a = a.flatten().to(torch.float32)
            b = b.flatten().to(torch.float32)
            res = torch.nn.functional.cosine_similarity(a, b, dim=0, eps=1e-6)
            if res < 0.99:
                # detach() so the diagnostic also works for tensors that
                # require grad.
                print(f"Similarity score={res.detach().cpu().numpy()}")
            # Return a plain bool, like every other branch.
            return bool(res >= 0.99)
        else:
            return torch.allclose(a, b, atol=tol, rtol=tol, equal_nan=equal_nan)
    elif isinstance(a, (str, int, float, type(None), bool, torch.device)):
        return a == b
    elif type(a).__name__ in (
        "MaskedLMOutput",
        "Seq2SeqLMOutput",
        "CausalLMOutputWithCrossAttentions",
        "LongformerMaskedLMOutput",
        "Instances",
        "SquashedNormal",
        "Boxes",
        "Normal",
        "TanhTransform",
        "Foo",
        "Variable",
    ):
        # Allow-listed container objects are compared attribute by attribute.
        assert type(a) is type(b)
        return all(
            same(getattr(a, key), getattr(b, key), cos_similarity, tol, equal_nan)
            for key in a.__dict__.keys()
        )
    else:
        raise RuntimeError(f"unsupported type: {type(a).__name__}")
127-
128-
12973
def debug_dir():
13074
path = os.path.join(os.path.dirname(__file__), "../debug")
13175
if not os.path.exists(path):

0 commit comments

Comments
 (0)