149 changes: 149 additions & 0 deletions tests/test_models.py
@@ -4,9 +4,16 @@
import os
import fnmatch

try:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False

import timm
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
get_model_default_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions

if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
@@ -297,3 +304,145 @@ def test_model_forward_features(model_name, batch_size):
assert e == o.shape[1]
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()


@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run a single forward pass through the resulting GraphModule,
checking that the GraphModule's output matches that of the original Module.
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")

model = create_model(model_name, pretrained=False)
model.eval()

input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")

# This block does some juggling to handle models that produce multiple outputs in train mode.
# We trace once and inspect the graph to find the nodes that feed the original FX output node,
# recording their indices relative to the end of the graph. Those indices then pick the matching
# eval-mode entries from the name lists returned by torchvision's get_graph_node_names.
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]

fx_model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
fx_outputs = tuple(fx_model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)

assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'


@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run a single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")

input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
pytest.skip("Fixed input size model > limit.")

model = create_model(model_name, pretrained=False, num_classes=42)
model.train()

num_params = sum([x.numel() for x in model.parameters()])

# This block does some juggling to handle models that produce multiple outputs in train mode.
# We trace once and inspect the graph to find the nodes that feed the original FX output node,
# recording their indices relative to the end of the graph. Those indices then pick the matching
# train-mode entries from the name lists returned by torchvision's get_graph_node_names.
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]

model = create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

inputs = torch.randn((batch_size, *input_size))
outputs = tuple(model(inputs).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])

assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
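Worth noting for reviewers: the train/eval node split exercised above is a torchvision API feature in its own right — a single GraphModule can return different nodes depending on whether it is in train or eval mode. A minimal sketch, with 'resnet18' chosen purely for illustration:

```python
import timm
import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

model = timm.create_model('resnet18', pretrained=False)
train_nodes, eval_nodes = get_graph_node_names(model)

# One extractor, two sets of return nodes: which set is returned
# follows the module's train/eval mode at call time.
fx_model = create_feature_extractor(
    model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]])

fx_model.eval()
out = fx_model(torch.randn(1, 3, 224, 224))  # dict keyed by the eval node name
```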

# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'beit_*',
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
]

@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run a single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")

input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")

with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()

train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))[eval_nodes[-1]]

assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
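For reference, the end-to-end pattern these tests exercise can be reproduced outside of pytest. A minimal sketch, assuming torch >= 1.10 and torchvision >= 0.11 are available; 'resnet18' is illustrative only:

```python
import timm
import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

from timm.models.fx_features import _autowrap_functions, _leaf_modules

model = timm.create_model('resnet18', pretrained=False).eval()
tracer_kwargs = {'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}

# Node names in execution order, for train and eval mode respectively
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)

# Extract the last eval-mode node; with a list of return_nodes, the output
# dict is keyed by the node names themselves
fx_model = create_feature_extractor(model, return_nodes=[eval_nodes[-1]], tracer_kwargs=tracer_kwargs)
out = fx_model(torch.randn(1, 3, 224, 224))[eval_nodes[-1]]
print(out.shape)  # torch.Size([1, 1000]) for the default classifier head
```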
7 changes: 4 additions & 3 deletions timm/models/coat.py
@@ -19,6 +19,7 @@
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .layers import _assert


__all__ = [
@@ -105,7 +106,7 @@ def __init__(self, Ch, h, window):
def forward(self, q, v, size: Tuple[int, int]):
B, h, N, Ch = q.shape
H, W = size
assert N == 1 + H * W
_assert(N == 1 + H * W, '')

# Convolutional relative position encoding.
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
@@ -177,7 +178,7 @@ def __init__(self, dim, k=3):
def forward(self, x, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
assert N == 1 + H * W
_assert(N == 1 + H * W, '')

# Extract CLS token and image tokens.
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
@@ -275,7 +276,7 @@ def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
""" Feature map interpolation. """
B, N, C = x.shape
H, W = size
assert N == 1 + H * W
_assert(N == 1 + H * W, '')

cls_token = x[:, :1, :]
img_tokens = x[:, 1:, :]
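The assert → _assert swaps above are needed because, under symbolic tracing, N, H and W are fx.Proxy objects: the comparison still traces to a Proxy, but the bare assert then forces a bool() on it and raises. A sketch of one way to get the effect of timm's _assert helper (the actual implementation in timm.models.layers may differ, e.g. by delegating to torch._assert where available):

```python
import torch.fx


@torch.fx.wrap  # recorded as a single call_function node, body never traced
def _assert(condition: bool, message: str = ''):
    assert condition, message
```

Because the call is wrapped, the assert's control flow never sees a Proxy during tracing, while eager execution keeps the exact same failure behaviour.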
2 changes: 2 additions & 0 deletions timm/models/convit.py
@@ -30,6 +30,7 @@
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed
from .fx_features import register_notrace_module

import torch
import torch.nn as nn
@@ -56,6 +57,7 @@ def _cfg(url='', **kwargs):
}


@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.):
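The decorator applied to GPSA above is the general escape hatch for whole modules whose forward method FX cannot trace. A hypothetical illustration (MyGatedBlock is invented for this sketch, not part of timm):

```python
import torch
import torch.nn as nn

from timm.models.fx_features import register_notrace_module


@register_notrace_module  # reason: data-dependent control flow in forward
class MyGatedBlock(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # branching on tensor values breaks symbolic tracing, so the whole
        # module is treated as a leaf and never descended into
        if x.mean() > 0:
            return x * 2
        return x
```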
40 changes: 30 additions & 10 deletions timm/models/crossvit.py
@@ -22,6 +22,7 @@
Modified from timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

"""
from typing import Tuple

import torch
import torch.nn as nn
@@ -31,8 +32,9 @@
from typing import List

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model
from .vision_transformer import Mlp, Block

@@ -116,8 +118,10 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
_assert(H == self.img_size[0],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
_assert(W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x).flatten(2).transpose(1, 2)
return x

@@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]


@register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
"""
Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
Args:
x (Tensor): input image
ss (tuple[int, int]): height and width to scale to
crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
Returns:
Tensor: the "scaled" image batch tensor
"""
H, W = x.shape[-2:]
if H != ss[0] or W != ss[1]:
if crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
return x


class CrossViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
@@ -342,17 +367,12 @@ def reset_classifier(self, num_classes, global_pool=''):
range(self.num_branches)])

def forward_features(self, x):
B, C, H, W = x.shape
B = x.shape[0]
xs = []
for i, patch_embed in enumerate(self.patch_embed):
x_ = x
ss = self.img_size_scaled[i]
if H != ss[0] or W != ss[1]:
if self.crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
x_ = scale_image(x_, ss, self.crop_scale)
x_ = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
cls_tokens = cls_tokens.expand(B, -1, -1)
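Factoring scale_image out of forward_features shows the complementary pattern for free functions: move the data-dependent branch into its own function and register it, so FX records one opaque call instead of tracing the branch. The same idea with an invented helper:

```python
import torch

from timm.models.fx_features import register_notrace_function


@register_notrace_function  # autowrapped: kept as one node in the traced graph
def center_crop_if_larger(x: torch.Tensor, size: int) -> torch.Tensor:
    # shape-dependent control flow that would otherwise raise under tracing
    H, W = x.shape[-2:]
    if H > size and W > size:
        top, left = (H - size) // 2, (W - size) // 2
        x = x[:, :, top:top + size, left:left + size]
    return x
```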
74 changes: 74 additions & 0 deletions timm/models/fx_features.py
@@ -0,0 +1,74 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, Type
from torch import nn

from .features import _get_feature_info

try:
from torchvision.models.feature_extraction import create_feature_extractor
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False

# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

# NOTE: modules from timm.models.layers that we want to treat as leaf modules go here directly,
# but modules defined elsewhere in timm.models should use the registration mechanism below
_leaf_modules = {
BatchNormAct2d, # reason: flow control for jit scripting
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# reason: get_same_padding uses max(), which raises a control flow error under symbolic tracing
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand received Proxy in `size` argument
}

try:
from .layers import InplaceAbn
_leaf_modules.add(InplaceAbn)
except ImportError:
pass


def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""
_leaf_modules.add(module)
return module


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
"""
Decorator for functions which ought not to be traced through
"""
_autowrap_functions.add(func)
return func


class FeatureGraphNet(nn.Module):
def __init__(self, model, out_indices, out_map=None):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = create_feature_extractor(
model, return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

def forward(self, x):
return list(self.graph_module(x).values())
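A usage sketch for FeatureGraphNet, assuming the wrapped model exposes feature_info the way timm backbones do; 'resnet18' and the out_indices are illustrative only:

```python
import timm
import torch

from timm.models.fx_features import FeatureGraphNet

backbone = timm.create_model('resnet18', pretrained=False)
fx_net = FeatureGraphNet(backbone, out_indices=(1, 2, 3, 4))

feats = fx_net(torch.randn(1, 3, 224, 224))  # list of feature maps, shallow to deep
for f in feats:
    print(f.shape)
```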

3 changes: 3 additions & 0 deletions timm/models/helpers.py
@@ -14,6 +14,7 @@


from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .layers import Conv2dSame, Linear

@@ -477,6 +478,8 @@ def build_model_with_cfg(
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
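With the 'fx' branch above in place, FX-based extraction should be reachable through the ordinary features_only path. A sketch under the assumption that feature_cls is forwarded into feature_cfg by the model entrypoint; depending on version it may instead need to be passed inside an explicit feature_cfg dict:

```python
import timm

# assumption: create_model forwards feature_cls and out_indices into feature_cfg
model = timm.create_model(
    'resnet50', pretrained=False, features_only=True,
    feature_cls='fx', out_indices=(1, 2, 3, 4))
```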