4
4
import os
5
5
import fnmatch
6
6
7
+ from torchvision .models .feature_extraction import create_feature_extractor , get_graph_node_names , NodePathTracer
8
+
7
9
import timm
8
10
from timm import list_models , create_model , set_scriptable , has_model_default_key , is_model_default_key , \
9
11
get_model_default_value
10
- from timm .models .fx_features import NodePathTracer
12
+ from timm .models .fx_features import _leaf_modules , _autowrap_functions
11
13
12
14
if hasattr (torch ._C , '_jit_set_profiling_executor' ):
13
15
# legacy executor is too slow to compile large models for unit tests
@@ -312,12 +314,14 @@ def test_model_forward_fx(model_name, batch_size):
312
314
if max (input_size ) > MAX_FWD_SIZE :
313
315
pytest .skip ("Fixed input size model > limit." )
314
316
315
- tracer = NodePathTracer ()
316
- graph = tracer .trace (model )
317
- model = torch .fx .GraphModule (model , graph )
317
+ train_nodes , eval_nodes = get_graph_node_names (
318
+ model , tracer_kwargs = {'leaf_modules' : list (_leaf_modules ), 'autowrap_functions' : list (_autowrap_functions )})
319
+ model = create_feature_extractor (
320
+ model , train_return_nodes = [train_nodes [- 1 ]], eval_return_nodes = [eval_nodes [- 1 ]],
321
+ tracer_kwargs = {'leaf_modules' : list (_leaf_modules ), 'autowrap_functions' : list (_autowrap_functions )})
318
322
319
323
inputs = torch .randn ((batch_size , * input_size ))
320
- outputs = model (inputs )
324
+ outputs = model (inputs )[ eval_nodes [ - 1 ]]
321
325
322
326
assert outputs .shape [0 ] == batch_size
323
327
assert not torch .isnan (outputs ).any (), 'Output included NaNs'
@@ -336,12 +340,30 @@ def test_model_backward_fx(model_name, batch_size):
336
340
model .train ()
337
341
num_params = sum ([x .numel () for x in model .parameters ()])
338
342
339
- tracer = NodePathTracer ()
343
+ input_size = _get_input_size (model = model , target = TARGET_FWD_SIZE )
344
+ if max (input_size ) > MAX_FWD_SIZE :
345
+ pytest .skip ("Fixed input size model > limit." )
346
+
347
+ # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
348
+ # If so, we need to return all of them in order to check all grads
349
+ # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
350
+ # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
351
+ tracer = NodePathTracer (leaf_modules = list (_leaf_modules ), autowrap_functions = list (_autowrap_functions ))
340
352
graph = tracer .trace (model )
341
- model = torch .fx .GraphModule (model , graph )
353
+ graph_nodes = list (reversed (graph .nodes ))
354
+ output_node_names = [n .name for n in graph_nodes [0 ]._input_nodes .keys ()]
355
+ graph_node_names = [n .name for n in graph_nodes ]
356
+ output_node_indices = [- graph_node_names .index (node_name ) for node_name in output_node_names ]
357
+ train_nodes , eval_nodes = get_graph_node_names (
358
+ model , tracer_kwargs = {'leaf_modules' : list (_leaf_modules ), 'autowrap_functions' : list (_autowrap_functions )})
359
+ train_return_nodes = [train_nodes [ix ] for ix in output_node_indices ]
360
+
361
+ model = create_feature_extractor (
362
+ model , train_return_nodes = train_return_nodes , eval_return_nodes = [eval_nodes [- 1 ]],
363
+ tracer_kwargs = {'leaf_modules' : list (_leaf_modules ), 'autowrap_functions' : list (_autowrap_functions )})
342
364
343
365
inputs = torch .randn ((batch_size , * input_size ))
344
- outputs = model (inputs )
366
+ outputs = tuple ( model (inputs ). values () )
345
367
if isinstance (outputs , tuple ):
346
368
outputs = torch .cat (outputs )
347
369
outputs .mean ().backward ()
@@ -354,9 +376,14 @@ def test_model_backward_fx(model_name, batch_size):
354
376
assert not torch .isnan (outputs ).any (), 'Output included NaNs'
355
377
356
378
379
+ EXCLUDE_FX_JIT_FILTERS = [
380
+ 'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
381
+ ]
382
+
357
383
@pytest .mark .timeout (120 )
358
384
@pytest .mark .parametrize (
359
- 'model_name' , list_models (exclude_filters = EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS , name_matches_cfg = True ))
385
+ 'model_name' , list_models (
386
+ exclude_filters = EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS , name_matches_cfg = True ))
360
387
@pytest .mark .parametrize ('batch_size' , [1 ])
361
388
def test_model_forward_fx_torchscript (model_name , batch_size ):
362
389
"""Symbolically trace each model, script it, and run single forward pass"""
@@ -368,12 +395,18 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
368
395
model = create_model (model_name , pretrained = False )
369
396
model .eval ()
370
397
371
- tracer = NodePathTracer ()
372
- graph = tracer .trace (model )
373
- model = torch .fx .GraphModule (model , graph )
398
+ input_size = _get_input_size (model = model , target = TARGET_FWD_SIZE )
399
+ if max (input_size ) > MAX_FWD_SIZE :
400
+ pytest .skip ("Fixed input size model > limit." )
401
+
402
+ train_nodes , eval_nodes = get_graph_node_names (
403
+ model , tracer_kwargs = {'leaf_modules' : list (_leaf_modules ), 'autowrap_functions' : list (_autowrap_functions )})
404
+ model = create_feature_extractor (
405
+ model , train_return_nodes = [train_nodes [- 1 ]], eval_return_nodes = [eval_nodes [- 1 ]],
406
+ tracer_kwargs = {'leaf_modules' : list (_leaf_modules ), 'autowrap_functions' : list (_autowrap_functions )})
374
407
375
408
model = torch .jit .script (model )
376
- outputs = model (torch .randn ((batch_size , * input_size )))
409
+ outputs = model (torch .randn ((batch_size , * input_size )))[ train_nodes [ - 1 ]]
377
410
378
411
assert outputs .shape [0 ] == batch_size
379
- assert not torch .isnan (outputs ).any (), 'Output included NaNs'
412
+ assert not torch .isnan (outputs ).any (), 'Output included NaNs'
0 commit comments