1- from torch_tensorrt .dynamo . partitioning import partition
1+ from torch_tensorrt .dynamo import partitioning
22from torch .testing ._internal .common_utils import run_tests , TestCase
33from utils import lower_graph_testing
44import torch
55from copy import deepcopy
66import numpy as np
77
88
9- class TestPartitioning (TestCase ):
9+ class TestFastPartitioning (TestCase ):
1010 def test_partition_fully_supported_one_op (self ):
1111 class FullySupportedOneOp (torch .nn .Module ):
1212 def __init__ (self , * args , ** kwargs ) -> None :
@@ -16,7 +16,7 @@ def forward(self, x, y):
1616 return torch .ops .aten .add .Tensor (x , y )
1717
1818 fx_graph = torch .fx .symbolic_trace (FullySupportedOneOp ())
19- partitioned_graph = partition (deepcopy (fx_graph ))
19+ partitioned_graph = partitioning . fast_partition (deepcopy (fx_graph ))
2020 self .assertEquals (
2121 len (
2222 [
@@ -42,7 +42,9 @@ def forward(self, x, y):
4242 return pow_
4343
4444 fx_graph = torch .fx .symbolic_trace (FullySupportedMultiOp ())
45- partitioned_graph = partition (deepcopy (fx_graph ), min_block_size = 2 )
45+ partitioned_graph = partitioning .fast_partition (
46+ deepcopy (fx_graph ), min_block_size = 2
47+ )
4648 self .assertEquals (
4749 len (
4850 [
@@ -69,7 +71,9 @@ def forward(self, x, y):
6971 return pow_
7072
7173 fx_graph = torch .fx .symbolic_trace (PartiallySupportedMultiOp ())
72- partitioned_graph = partition (deepcopy (fx_graph ), min_block_size = 2 )
74+ partitioned_graph = partitioning .fast_partition (
75+ deepcopy (fx_graph ), min_block_size = 2
76+ )
7377 self .assertEquals (
7478 len (
7579 [
@@ -118,6 +122,7 @@ def forward(self, x, y):
118122 min_block_size = 2 ,
119123 torch_executed_ops = {"torch.ops.aten.add.Tensor" },
120124 testing_partitioning = True ,
125+ use_fast_partitioner = True ,
121126 )
122127
123128 self .assertEquals (
@@ -144,5 +149,124 @@ def forward(self, x, y):
144149 )
145150
146151
152+ class TestGlobalPartitioning (TestCase ):
153+ def test_partition_fully_supported_one_op (self ):
154+ class FullySupportedOneOp (torch .nn .Module ):
155+ def __init__ (self , * args , ** kwargs ) -> None :
156+ super ().__init__ (* args , ** kwargs )
157+
158+ def forward (self , x , y ):
159+ return torch .ops .aten .add .Tensor (x , y )
160+
161+ fx_graph = torch .fx .symbolic_trace (FullySupportedOneOp ())
162+ partitioned_graph = partitioning .global_partition (deepcopy (fx_graph ))
163+ self .assertEquals (
164+ len (list (partitioned_graph .named_children ())),
165+ 0 ,
166+ "Single operators should not be segmented" ,
167+ )
168+
169+ def test_partition_fully_supported_multi_op (self ):
170+ class FullySupportedMultiOp (torch .nn .Module ):
171+ def __init__ (self , * args , ** kwargs ) -> None :
172+ super ().__init__ (* args , ** kwargs )
173+
174+ def forward (self , x , y ):
175+ sum_ = torch .ops .aten .sub .Tensor (x , y )
176+ concat_ = torch .ops .aten .cat .default (x , sum_ )
177+ relu_ = torch .ops .aten .relu .default (concat_ )
178+ pow_ = torch .ops .aten .pow .Tensor_Scalar (relu_ , 2 )
179+ return pow_
180+
181+ fx_graph = torch .fx .symbolic_trace (FullySupportedMultiOp ())
182+ partitioned_graph = partitioning .global_partition (
183+ deepcopy (fx_graph ), min_block_size = 2
184+ )
185+ self .assertEquals (
186+ len (list (partitioned_graph .named_children ())),
187+ 1 ,
188+ "All operators are supported, there should be one segment" ,
189+ )
190+
191+ def test_partition_partially_supported_multi_op (self ):
192+ class PartiallySupportedMultiOp (torch .nn .Module ):
193+ def __init__ (self , * args , ** kwargs ) -> None :
194+ super ().__init__ (* args , ** kwargs )
195+
196+ def forward (self , x , y ):
197+ sum_1 = torch .ops .aten .add .Tensor (x , y )
198+ sum_2 = torch .ops .aten .add .Tensor (x , sum_1 )
199+ sum_ = np .sum (sum_1 ) + np .sum (sum_2 )
200+ relu_ = torch .ops .aten .relu .default (sum_ )
201+ pow_ = torch .ops .aten .pow .Tensor_Scalar (relu_ , 2 )
202+ return pow_
203+
204+ fx_graph = torch .fx .symbolic_trace (PartiallySupportedMultiOp ())
205+ partitioned_graph = partitioning .global_partition (
206+ deepcopy (fx_graph ), min_block_size = 2
207+ )
208+ self .assertEquals (
209+ len (list (partitioned_graph .named_children ())),
210+ 2 ,
211+ "Unsupported operators interleave supported ones, expected 2 segments" ,
212+ )
213+
214+ def test_partition_partially_supported_with_torch_executed_ops (self ):
215+ class PartiallySupportedMultiOp (torch .nn .Module ):
216+ def __init__ (self , * args , ** kwargs ) -> None :
217+ super ().__init__ (* args , ** kwargs )
218+
219+ def forward (self , x , y ):
220+ sum_1 = torch .ops .aten .add .Tensor (x , y )
221+ sum_2 = torch .ops .aten .add .Tensor (x , sum_1 )
222+ sum_ = torch .ops .aten .add .Tensor (sum_1 , sum_2 )
223+ relu_ = torch .ops .aten .relu .default (sum_ )
224+ pow_ = torch .ops .aten .pow .Tensor_Scalar (relu_ , 2 )
225+ return pow_
226+
227+ unexpected_ops = {torch .ops .aten .add .Tensor }
228+
229+ inputs = [
230+ torch .randint (
231+ 1 ,
232+ 10 ,
233+ (5 ,),
234+ ),
235+ torch .randint (
236+ 1 ,
237+ 10 ,
238+ (5 ,),
239+ ),
240+ ]
241+
242+ fx_graph = torch .fx .symbolic_trace (PartiallySupportedMultiOp ())
243+ (unexpected_ops_seen , _ , partitioned_graphs ,) = lower_graph_testing (
244+ fx_graph ,
245+ inputs ,
246+ unexpected_ops = unexpected_ops ,
247+ min_block_size = 2 ,
248+ torch_executed_ops = {"torch.ops.aten.add.Tensor" },
249+ testing_partitioning = True ,
250+ use_fast_partitioner = False ,
251+ )
252+
253+ self .assertEquals (
254+ len (unexpected_ops_seen ),
255+ 0 ,
256+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
257+ )
258+
259+ self .assertEquals (
260+ len (partitioned_graphs ),
261+ 1 ,
262+ "Without control flow breaks, there should only be a single graph" ,
263+ )
264+ self .assertEquals (
265+ len (list (partitioned_graphs [0 ].named_children ())),
266+ 1 ,
267+ "Certain operators are set to run in Torch, expected 1 segment" ,
268+ )
269+
270+
147271if __name__ == "__main__" :
148272 run_tests ()
0 commit comments