Commit 8872467

wip - pre-rebase
1 parent ca0e895 commit 8872467

25 files changed: +185 -734 lines changed

tests/test_models.py

Lines changed: 47 additions & 14 deletions
@@ -4,10 +4,12 @@
 import os
 import fnmatch
 
+from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
+
 import timm
 from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
     get_model_default_value
-from timm.models.fx_features import NodePathTracer
+from timm.models.fx_features import _leaf_modules, _autowrap_functions
 
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
@@ -312,12 +314,14 @@ def test_model_forward_fx(model_name, batch_size):
     if max(input_size) > MAX_FWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
-    tracer = NodePathTracer()
-    graph = tracer.trace(model)
-    model = torch.fx.GraphModule(model, graph)
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    model = create_feature_extractor(
+        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
 
     inputs = torch.randn((batch_size, *input_size))
-    outputs = model(inputs)
+    outputs = model(inputs)[eval_nodes[-1]]
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
@@ -336,12 +340,30 @@ def test_model_backward_fx(model_name, batch_size):
     model.train()
     num_params = sum([x.numel() for x in model.parameters()])
 
-    tracer = NodePathTracer()
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+    # If so, we need to return all of them in order to check all grads
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
     graph = tracer.trace(model)
-    model = torch.fx.GraphModule(model, graph)
+    graph_nodes = list(reversed(graph.nodes))
+    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+    graph_node_names = [n.name for n in graph_nodes]
+    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
+
+    model = create_feature_extractor(
+        model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
 
     inputs = torch.randn((batch_size, *input_size))
-    outputs = model(inputs)
+    outputs = tuple(model(inputs).values())
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
     outputs.mean().backward()
@@ -354,9 +376,14 @@ def test_model_backward_fx(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
 
+EXCLUDE_FX_JIT_FILTERS = [
+    'beit_*'  # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
+]
+
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize(
-    'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
+    'model_name', list_models(
+        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_fx_torchscript(model_name, batch_size):
     """Symbolically trace each model, script it, and run single forward pass"""
@@ -368,12 +395,18 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    tracer = NodePathTracer()
-    graph = tracer.trace(model)
-    model = torch.fx.GraphModule(model, graph)
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    model = create_feature_extractor(
+        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
 
     model = torch.jit.script(model)
-    outputs = model(torch.randn((batch_size, *input_size)))
+    outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
 
     assert outputs.shape[0] == batch_size
-    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
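
Usage note (not part of the diff): the tests above now lean on torchvision's FX helpers instead of building a GraphModule by hand, forwarding timm's leaf-module and autowrap registries through tracer_kwargs. A minimal sketch of the same pattern outside the test suite, assuming torchvision >= 0.11; the model name and choice of return node are illustrative only:

import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from timm import create_model
from timm.models.fx_features import _leaf_modules, _autowrap_functions

model = create_model('resnet18', pretrained=False).eval()
# timm registries tell the FX tracer which modules/functions to treat as opaque leaves
tracer_kwargs = {'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
extractor = create_feature_extractor(model, return_nodes=[eval_nodes[-1]], tracer_kwargs=tracer_kwargs)
features = extractor(torch.randn(1, 3, 224, 224))  # dict keyed by node name
out = features[eval_nodes[-1]]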

timm/models/cait.py

Lines changed: 4 additions & 4 deletions
@@ -95,11 +95,11 @@ def forward(self, x):
         q = q * self.scale
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
 
-        attn = torch.matmul(q, k.transpose(-2, -1))
+        attn = (q @ k.transpose(-2, -1))
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C)
+        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
         x_cls = self.proj(x_cls)
         x_cls = self.proj_drop(x_cls)
 
@@ -158,7 +158,7 @@ def forward(self, x):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
 
-        attn = torch.matmul(q, k.transpose(-2, -1))
+        attn = (q @ k.transpose(-2, -1))
 
         attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
@@ -167,7 +167,7 @@ def forward(self, x):
         attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
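
Side note: these hunks only change notation, not math. For batched attention tensors, a @ b and torch.matmul(a, b) dispatch to the same batched matrix multiply; a quick sanity check (shapes here are arbitrary):

import torch

q = torch.randn(2, 4, 16, 32)  # (batch, heads, tokens, head_dim)
k = torch.randn(2, 4, 16, 32)
assert torch.allclose(q @ k.transpose(-2, -1), torch.matmul(q, k.transpose(-2, -1)))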

timm/models/coat.py

Lines changed: 6 additions & 5 deletions
@@ -19,6 +19,7 @@
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from .registry import register_model
+from .layers.trace_utils import _assert
 
 
 __all__ = [
@@ -105,7 +106,7 @@ def __init__(self, Ch, h, window):
     def forward(self, q, v, size: Tuple[int, int]):
         B, h, N, Ch = q.shape
         H, W = size
-        torch._assert(N == 1 + H * W, '')
+        _assert(N == 1 + H * W, '')
 
         # Convolutional relative position encoding.
         q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
@@ -149,8 +150,8 @@ def forward(self, x, size: Tuple[int, int]):
 
         # Factorized attention.
         k_softmax = k.softmax(dim=2)
-        factor_att = torch.matmul(k_softmax.transpose(-1, -2), v)
-        factor_att = torch.matmul(q, factor_att)
+        factor_att = k_softmax.transpose(-1, -2) @ v
+        factor_att = q @ factor_att
 
         # Convolutional relative position encoding.
         crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]
@@ -177,7 +178,7 @@ def __init__(self, dim, k=3):
     def forward(self, x, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
-        torch._assert(N == 1 + H * W, '')
+        _assert(N == 1 + H * W, '')
 
         # Extract CLS token and image tokens.
         cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]
@@ -275,7 +276,7 @@ def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
         """ Feature map interpolation. """
         B, N, C = x.shape
         H, W = size
-        torch._assert(N == 1 + H * W, '')
+        _assert(N == 1 + H * W, '')
 
         cls_token = x[:, :1, :]
         img_tokens = x[:, 1:, :]
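
The _assert helper imported from .layers.trace_utils is not shown in this diff; a plausible sketch (purely hypothetical, for illustration) is a thin wrapper that keeps the shape check usable as a plain, traceable function even on torch builds without torch._assert:

# hypothetical sketch of a trace_utils-style _assert; actual implementation not shown here
try:
    from torch import _assert  # present in recent torch releases
except ImportError:
    def _assert(condition: bool, message: str = ''):
        # plain-Python fallback for older torch versions
        assert condition, message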

timm/models/convit.py

Lines changed: 5 additions & 5 deletions
@@ -57,7 +57,7 @@ def _cfg(url='', **kwargs):
 }
 
 
-@register_leaf_module  # FX can't symbolically trace control flow in forward method
+@register_leaf_module  # reason: FX can't symbolically trace control flow in forward method
 class GPSA(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                  locality_strength=1.):
@@ -84,7 +84,7 @@ def forward(self, x):
             self.rel_indices = self.get_rel_indices(N)
         attn = self.get_attention(x)
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
@@ -95,7 +95,7 @@ def get_attention(self, x):
         q, k = qk[0], qk[1]
         pos_score = self.rel_indices.expand(B, -1, -1, -1)
         pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
-        patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        patch_score = (q @ k.transpose(-2, -1)) * self.scale
         patch_score = patch_score.softmax(dim=-1)
         pos_score = pos_score.softmax(dim=-1)
 
@@ -180,11 +180,11 @@ def forward(self, x):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
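
register_leaf_module (imported from timm.models.fx_features; the import is outside this hunk) marks GPSA so the FX tracer records the whole module as a single call_module node rather than tracing into its data-dependent forward. The registry itself is not part of this diff; a hypothetical sketch of the pattern implied by the tests' use of list(_leaf_modules):

# hypothetical sketch of the leaf-module registry; actual fx_features code not shown here
_leaf_modules = set()

def register_leaf_module(module_cls):
    # record the class so it can later be passed as tracer_kwargs={'leaf_modules': list(_leaf_modules)}
    _leaf_modules.add(module_cls)
    return module_cls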

timm/models/crossvit.py

Lines changed: 30 additions & 10 deletions
@@ -22,6 +22,7 @@
 Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 
 """
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -31,8 +32,9 @@
 from typing import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_autowrap_function
 from .helpers import build_model_with_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_
+from .layers import DropPath, to_2tuple, trunc_normal_, _assert
 from .registry import register_model
 from .vision_transformer import Mlp, Block
 
@@ -116,8 +118,10 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x).flatten(2).transpose(1, 2)
         return x
 
@@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
     return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
 
 
+@register_autowrap_function
+def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False):  # annotations for torchscript
+    """
+    Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
+    Args:
+        x (Tensor): input image
+        ss (tuple[int, int]): height and width to scale to
+        crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
+    Returns:
+        Tensor: the "scaled" image batch tensor
+    """
+    H, W = x.shape[-2:]
+    if H != ss[0] or W != ss[1]:
+        if crop_scale and ss[0] <= H and ss[1] <= W:
+            cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+            x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+        else:
+            x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
+    return x
+
+
 class CrossViT(nn.Module):
     """ Vision Transformer with support for patch or hybrid CNN input stage
     """
@@ -342,17 +367,12 @@ def reset_classifier(self, num_classes, global_pool=''):
             range(self.num_branches)])
 
     def forward_features(self, x):
-        B, C, H, W = x.shape
+        B = x.shape[0]
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):
             x_ = x
             ss = self.img_size_scaled[i]
-            if H != ss[0] or W != ss[1]:
-                if self.crop_scale and ss[0] <= H and ss[1] <= W:
-                    cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
-                    x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
-                else:
-                    x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+            x_ = scale_image(x_, ss, self.crop_scale)
             x_ = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
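
Factoring the resize logic into scale_image and registering it with register_autowrap_function keeps the shape-dependent branches out of the traced graph: FX records a single call_function node for scale_image instead of tracing through the if H != ss[0] ... branches. A small usage sketch, assuming the function is importable from timm.models.crossvit as added above (sizes are arbitrary):

import torch
from timm.models.crossvit import scale_image

x = torch.randn(1, 3, 240, 240)
# target fits inside the input and crop_scale=True, so the center-crop branch is taken
assert scale_image(x, (224, 224), crop_scale=True).shape[-2:] == (224, 224)
# target is larger than the input, so it falls back to bicubic interpolation
assert scale_image(x, (256, 256), crop_scale=True).shape[-2:] == (256, 256)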
