Skip to content

Commit 2ca721c

Browse files
SherlockNoMadpytorchmergebot
authored andcommitted
An improved version of subgraph matcher (pytorch#82090)
This new version of subgraph matcher further supports - optionally match with pattern's placeholder and output nodes - patterns with multiple outputs - filtering out non-containing matches - filtering out overlapping matches TODOs: - [x] Update replace_pattern() to use this matcher - [x] Fix cases with identical anchor - [x] Introduce wildcard matching, such Any, OneOf - [ ] Improve node comparer to match args and kwargs values Pull Request resolved: pytorch#82090 Approved by: https://github.com/ezyang
1 parent 59b1c4e commit 2ca721c

File tree

4 files changed

+621
-3
lines changed

4 files changed

+621
-3
lines changed

test/test_fx_passes.py

Lines changed: 370 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["module: fx.passes"]
22

3+
from dataclasses import dataclass
34
import operator
45
import logging
56

@@ -9,6 +10,7 @@
910
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
1011
from torch.fx.passes.operator_support import OperatorSupport
1112
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
13+
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
1214

1315
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
1416
from torch.testing._internal.jit_utils import JitTestCase
@@ -163,6 +165,8 @@ class MockOperatorSupport(OperatorSupport):
163165
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
164166
return node.op == "call_function" and node.target in {operator.add}
165167

168+
169+
@instantiate_parametrized_tests
166170
class TestFXGraphPasses(JitTestCase):
167171

168172
@parametrize("fn, expected_partition", [
@@ -270,7 +274,372 @@ def test_fuser_util_xfail(self, partition):
270274
with self.assertRaises(Exception):
271275
fuse_by_partitions(gm, partitions)
272276

273-
instantiate_parametrized_tests(TestFXGraphPasses)
277+
@dataclass
278+
class TestCase:
279+
match_output: bool
280+
match_placeholder: bool
281+
num_matches: int
282+
remove_overlapping_matches: bool = True
283+
284+
class SingleNodePattern:
285+
@staticmethod
286+
def forward(x):
287+
val = torch.neg(x)
288+
return torch.add(val, val)
289+
290+
@staticmethod
291+
def pattern(a):
292+
return torch.neg(a)
293+
294+
test_cases = [
295+
# match_output, match_placeholder, num_matches
296+
TestCase(False, False, 1),
297+
TestCase(True, False, 0),
298+
TestCase(False, True, 1),
299+
TestCase(True, True, 0)
300+
]
301+
class SimplePattern:
302+
@staticmethod
303+
def forward(x, w1, w2):
304+
m1 = torch.cat([w1, w2]).sum()
305+
m2 = torch.cat([w2, w1]).sum()
306+
m3 = torch.cat([m1, m2]).sum()
307+
return x + torch.max(m1) + torch.max(m2) + m3
308+
309+
@staticmethod
310+
def pattern(a, b):
311+
return torch.cat([a, b]).sum()
312+
313+
test_cases = [
314+
# match_output, match_placeholder, num_matches
315+
TestCase(False, False, 3),
316+
TestCase(True, False, 0),
317+
TestCase(False, True, 2),
318+
TestCase(True, True, 0)
319+
]
320+
321+
class SimpleFullGraphMatching:
322+
@staticmethod
323+
def forward(x):
324+
a = torch.neg(x)
325+
return torch.add(a, a)
326+
327+
@staticmethod
328+
def pattern(x):
329+
a = torch.neg(x)
330+
return torch.add(a, a)
331+
332+
test_cases = [
333+
# match_output, match_placeholder, num_matches
334+
TestCase(False, False, 1),
335+
TestCase(True, False, 1),
336+
TestCase(False, True, 1),
337+
TestCase(True, True, 1)
338+
]
339+
340+
class DiamondShapePatternTestCase:
341+
@staticmethod
342+
def forward(x):
343+
a = torch.neg(x)
344+
345+
a = a.relu()
346+
left = a.sigmoid()
347+
right = a.relu()
348+
out = left + right
349+
350+
return out
351+
352+
@staticmethod
353+
def pattern(a):
354+
a = a.relu()
355+
left = a.sigmoid()
356+
right = a.relu()
357+
out = left + right
358+
return out
359+
360+
test_cases = [
361+
# match_output, match_placeholder, num_matches
362+
TestCase(False, False, 1),
363+
TestCase(True, False, 1),
364+
TestCase(False, True, 0),
365+
TestCase(True, True, 0)
366+
]
367+
368+
class NonFullyContainedMatches:
369+
@staticmethod
370+
def forward(x, w1, w2, b1, b2):
371+
# fully contained matched subgraph
372+
m1 = torch.cat([w1, w2])
373+
m2 = torch.cat([x, b2])
374+
t0 = torch.addmm(b1, m1, m2.t())
375+
t0_sum = torch.sum(t0) # use of t0 is not leaking
376+
377+
# leaking matched subgraph, m3 is leaked
378+
m3 = torch.cat([w1, w2])
379+
m4 = torch.cat([x, b2])
380+
t1 = torch.addmm(b1, m3, m4.t())
381+
m3_sum = torch.sum(m3)
382+
383+
return t0_sum, m3_sum
384+
385+
@staticmethod
386+
def pattern(x, w1, w2, b1, b2):
387+
m1 = torch.cat([w1, w2])
388+
m2 = torch.cat([x, b2])
389+
return torch.addmm(b1, m1, m2.t())
390+
391+
test_cases = [
392+
# match_output, match_placeholder, num_matches
393+
TestCase(False, False, 1),
394+
395+
TestCase(True, False, 0),
396+
397+
TestCase(False, True, 1), # leaked used of placeholder is not leaking
398+
]
399+
400+
class ChainRepeatedPattern:
401+
@staticmethod
402+
def forward(x):
403+
x = torch.sigmoid(x)
404+
x = torch.sigmoid(x)
405+
x = torch.sigmoid(x)
406+
return torch.sigmoid(x)
407+
408+
@staticmethod
409+
def pattern(x):
410+
return torch.sigmoid(torch.sigmoid(x))
411+
412+
test_cases = [
413+
# match_output, match_placeholder, num_matches
414+
TestCase(False, False, 3, remove_overlapping_matches=False),
415+
TestCase(False, False, 2, remove_overlapping_matches=True),
416+
TestCase(True, False, 1),
417+
TestCase(False, True, 1),
418+
TestCase(True, True, 0)
419+
]
420+
421+
class QuantizationModel:
422+
@staticmethod
423+
def forward(x):
424+
x += 3
425+
x = x.dequantize()
426+
x = torch.sigmoid(x)
427+
x = x.to(torch.float16)
428+
return x
429+
430+
@staticmethod
431+
def pattern(x):
432+
x = x.dequantize()
433+
x = torch.sigmoid(x)
434+
x = x.to(torch.float16)
435+
return x
436+
437+
test_cases = [
438+
# match_output, match_placeholder, num_matches
439+
TestCase(False, False, 1),
440+
TestCase(True, False, 1),
441+
TestCase(False, True, 0),
442+
TestCase(True, True, 0)
443+
]
444+
445+
class MultipleOutputsWithDependency:
446+
@staticmethod
447+
def forward(x):
448+
y = x.relu()
449+
z = y.sigmoid()
450+
return z, y
451+
452+
@staticmethod
453+
def pattern(a):
454+
b = a.relu()
455+
c = b.sigmoid()
456+
return b, c # outputs have data dependency
457+
458+
test_cases = [
459+
# match_output, match_placeholder, num_matches
460+
TestCase(False, False, 1),
461+
TestCase(True, False, 0),
462+
TestCase(False, True, 1),
463+
TestCase(True, True, 0)
464+
]
465+
466+
class MultipleOutputsWithoutDependency:
467+
@staticmethod
468+
def forward(x):
469+
x = x + 1
470+
471+
# target subgraph to match
472+
x = x.relu()
473+
z = x.sum()
474+
y = x.sigmoid()
475+
476+
out = y.sigmoid() + z.sum()
477+
return out
478+
479+
@staticmethod
480+
def pattern(a):
481+
a = a.relu()
482+
b = a.sigmoid()
483+
c = a.sum()
484+
return b, c
485+
486+
test_cases = [
487+
# match_output, match_placeholder, num_matches
488+
TestCase(False, False, 1),
489+
TestCase(True, False, 0),
490+
TestCase(False, True, 0),
491+
TestCase(True, True, 0)
492+
]
493+
494+
class MultipleOutputsMultipleOverlappingMatches:
495+
@staticmethod
496+
def forward(x):
497+
x = x + 1
498+
499+
# target subgraph to match
500+
x = x.relu()
501+
z = x.sum()
502+
z1 = x.sum()
503+
y = x.sigmoid()
504+
y1 = x.sigmoid()
505+
506+
return z + z1 + y + y1
507+
508+
@staticmethod
509+
def pattern(a):
510+
a = a.relu()
511+
b = a.sigmoid()
512+
c = a.sum()
513+
return a, b, c
514+
515+
test_cases = [
516+
# match_output, match_placeholder, num_matches
517+
TestCase(False, False, 4, remove_overlapping_matches=False),
518+
TestCase(False, False, 1, remove_overlapping_matches=True),
519+
]
520+
521+
class MultipleOutputsMultipleNonOverlappingMatches:
522+
@staticmethod
523+
def forward(x):
524+
x = x + 1
525+
526+
# target subgraph to match
527+
x = x.relu()
528+
z = x.sum()
529+
y = x.sigmoid()
530+
531+
x = x.relu()
532+
z1 = x.sum()
533+
y1 = x.sigmoid()
534+
535+
return z + z1 + y + y1
536+
537+
@staticmethod
538+
def pattern(a):
539+
a = a.relu()
540+
b = a.sigmoid()
541+
c = a.sum()
542+
return b, c
543+
544+
test_cases = [
545+
# match_output, match_placeholder, num_matches
546+
TestCase(False, False, 1),
547+
]
548+
549+
class MultipleOutputsIdenticalAnchor:
550+
@staticmethod
551+
def forward(x):
552+
x = x + 1
553+
554+
# target subgraph to match
555+
x = x.relu()
556+
y = x.sigmoid()
557+
y1 = x.sigmoid()
558+
559+
return y, y1
560+
561+
@staticmethod
562+
def pattern(a):
563+
a = a.relu()
564+
b = a.sigmoid()
565+
b1 = a.sigmoid()
566+
return b, b1
567+
568+
test_cases = [
569+
# match_output, match_placeholder, num_matches
570+
# (False, False, 2), # FIXME: currently still matches to 2, should fix to 1
571+
TestCase(True, False, 1),
572+
TestCase(False, True, 0),
573+
]
574+
575+
576+
class MultipleOutputsHorizontalPattern:
577+
@staticmethod
578+
def forward(x):
579+
x = x + 1
580+
581+
# target subgraph to match
582+
y1 = x.relu()
583+
y2 = x.sigmoid()
584+
585+
return y1, y2
586+
587+
@staticmethod
588+
def pattern(a):
589+
b1 = a.relu()
590+
b2 = a.sigmoid()
591+
592+
return b1, b2
593+
594+
test_cases = [
595+
# match_output, match_placeholder, num_matches
596+
TestCase(False, False, 1),
597+
TestCase(True, False, 1),
598+
TestCase(False, True, 0),
599+
TestCase(True, True, 0)
600+
]
601+
602+
603+
@instantiate_parametrized_tests
604+
class TestFXMatcherUtils(JitTestCase):
605+
606+
@parametrize("test_model", [
607+
SingleNodePattern,
608+
SimplePattern,
609+
SimpleFullGraphMatching,
610+
DiamondShapePatternTestCase,
611+
NonFullyContainedMatches,
612+
ChainRepeatedPattern,
613+
QuantizationModel,
614+
MultipleOutputsWithDependency,
615+
MultipleOutputsWithoutDependency,
616+
MultipleOutputsMultipleOverlappingMatches,
617+
MultipleOutputsMultipleNonOverlappingMatches,
618+
MultipleOutputsIdenticalAnchor,
619+
MultipleOutputsHorizontalPattern
620+
])
621+
def test_subgraph_matcher(self, test_model):
622+
traced = symbolic_trace(test_model.forward)
623+
pattern_traced = symbolic_trace(test_model.pattern)
624+
625+
for test_case in test_model.test_cases:
626+
627+
matcher = SubgraphMatcher(pattern_traced.graph,
628+
match_output=test_case.match_output,
629+
match_placeholder=test_case.match_placeholder,
630+
remove_overlapping_matches=test_case.remove_overlapping_matches)
631+
matches = matcher.match(traced.graph)
632+
633+
assert len(matches) == test_case.num_matches
634+
635+
for match in matches:
636+
for node in pattern_traced.graph.nodes:
637+
if not test_case.match_placeholder and node.op == "placeholder":
638+
continue
639+
if not test_case.match_output and node.op == "output":
640+
continue
641+
assert node in match.nodes_map
642+
274643

275644
if __name__ == "__main__":
276645
run_tests()

0 commit comments

Comments
 (0)