🐛 Describe the bug
This is the issue I discussed with @datumbox.
According to #4540, it looks like IntermediateLayerGetter will be replaced with the FX-based feature extractor.
My open-source ML library, torchdistill, is built on PyTorch / torchvision and relies heavily on PyTorch forward hooks to extract intermediate features for knowledge distillation without modifying model implementations.
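For context, here is a minimal sketch of the general hook-based technique, not torchdistill's actual implementation: forward hooks are registered on named submodules and store their outputs, so the model code itself never changes. The helper `capture_features` and the toy `nn.Sequential` model are hypothetical and exist only for illustration.

```python
import torch
from torch import nn


def capture_features(model, module_names):
    """Hypothetical helper: hook the named submodules and collect their outputs."""
    features = {}
    handles = []
    modules = dict(model.named_modules())
    for name in module_names:
        # Bind the current name as a default argument so each hook writes to its own key.
        # The output is stored as-is (not detached) so a distillation loss can backprop through it.
        def hook(module, inputs, output, name=name):
            features[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    return features, handles


# Toy model standing in for a teacher or student network.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
features, handles = capture_features(model, ["1"])  # "1" is the ReLU submodule

x = torch.rand(3, 4)
logits = model(x)  # the model runs unchanged; the hook fills `features`
print(features["1"].shape)  # torch.Size([3, 8])

for handle in handles:
    handle.remove()  # detach the hooks once the features are no longer needed
```

Removing the handles matters: otherwise the hooks keep firing on every later forward pass.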
However, the FX-based feature extractor disables some nn.Module features, such as forward hooks, of the modules it wraps.
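The likely mechanism can be sketched with plain `torch.fx`, assuming `create_feature_extractor` builds its `GraphModule` the same way `torch.fx.symbolic_trace` does in this respect (an assumption, not verified against every torchvision version). The default tracer traces through `nn.Sequential` containers instead of calling them, so in the resulting `GraphModule` the container attribute (here `block`, playing the role of `layer1`) is only a plain `nn.Module` holding the leaf submodules, and it is never invoked. `ToyNet` is a hypothetical stand-in model.

```python
import torch
from torch import nn
import torch.fx


class ToyNet(nn.Module):
    """Hypothetical stand-in for a backbone with an nn.Sequential stage."""

    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 8), nn.ReLU())  # plays the role of layer1
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.block(x))


calls = []


def make_hook(tag):
    # Forward hook that only records that it was invoked.
    def hook(module, inputs, output):
        calls.append(tag)
    return hook


model = ToyNet()
gm = torch.fx.symbolic_trace(model)

# nn.Sequential is not a leaf for the default tracer, so it is traced through;
# the GraphModule keeps a plain nn.Module named "block" that only holds the leaves.
print(type(gm.block).__name__)  # Module

gm.block.register_forward_hook(make_hook("gm.block"))  # container is never called
getattr(gm.block, "0").register_forward_hook(make_hook("gm.block.0"))  # leaf Linear is called

gm(torch.rand(3, 4))
print(calls)  # ['gm.block.0']
```

Only the hook on the leaf fires; the hook on the container is registered without error but is silently ignored.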
Here is a minimal example to demonstrate the issue.
import torch
from torch import nn
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models._utils import IntermediateLayerGetter


def printnorm(self, input, output):
    # input is a tuple of packed inputs
    # output is a Tensor. output.data is the Tensor we are interested in
    print('Inside ' + self.__class__.__name__ + ' forward')
    print('')
    print('input: ', type(input))
    print('input[0]: ', type(input[0]))
    print('output: ', type(output))
    print('')
    print('input size:', input[0].size())
    print('output size:', output.data.size())
    print('output norm:', output.data.norm())


resnet50 = models.resnet50()
return_layers = {'layer4': 'out'}
backbone_fx = create_feature_extractor(resnet50, return_layers)
backbone_org = IntermediateLayerGetter(resnet50, return_layers)
backbone_fx.layer1.register_forward_hook(printnorm)
backbone_org.layer1.register_forward_hook(printnorm)
x = torch.rand(2, 3, 224, 224)
# The FX-based feature extractor doesn't call the forward hook
z_fx = backbone_fx(x)
# IntermediateLayerGetter calls the forward hook and prints messages
z_org = backbone_org(x)
'''
Inside Sequential forward
input: <class 'tuple'>
input[0]: <class 'torch.Tensor'>
output: <class 'torch.Tensor'>
input size: torch.Size([2, 64, 56, 56])
output size: torch.Size([2, 256, 56, 56])
output norm: tensor(1973.4232)
'''
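As a possible stopgap until hooks are supported, a container can be kept as a leaf during tracing, so that it is still invoked through `nn.Module.__call__` and its hooks fire. The sketch below uses a custom `torch.fx.Tracer` subclass (`KeepAsLeafTracer`, hypothetical) on the same kind of hypothetical toy model; it is plain `torch.fx`, not a torchvision option, and whether it can be combined with `create_feature_extractor` is an assumption that would need checking.

```python
import torch
from torch import nn
import torch.fx


class ToyNet(nn.Module):
    """Hypothetical stand-in for a backbone with an nn.Sequential stage."""

    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.block(x))


class KeepAsLeafTracer(torch.fx.Tracer):
    """Hypothetical tracer that treats the named submodules as leaves."""

    def __init__(self, leaf_names):
        super().__init__()
        self.leaf_names = set(leaf_names)

    def is_leaf_module(self, m, module_qualified_name):
        # Keep the requested containers as single call_module nodes.
        if module_qualified_name in self.leaf_names:
            return True
        return super().is_leaf_module(m, module_qualified_name)


model = ToyNet()
graph = KeepAsLeafTracer({"block"}).trace(model)
gm = torch.fx.GraphModule(model, graph)

calls = []
# gm.block is now the original nn.Sequential, so this hook fires.
gm.block.register_forward_hook(
    lambda module, inputs, output: calls.append(type(module).__name__)
)
gm(torch.rand(3, 4))
print(calls)  # ['Sequential']
```

The trade-off is that nodes inside a leaf disappear from the graph, so they can no longer be requested as return nodes.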
Could you please enable these features (e.g., forward hooks) for modules wrapped by the FX-based feature extractor?
Thank you!
Versions
Collecting environment information...
PyTorch version: 1.10.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.8.10 (default, Nov 26 2021, 20:14:08) [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-94-generic-x86_64-with-glibc2.29
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch-msssim==0.2.1
[pip3] torch==1.10.0
[pip3] torchdistill==0.1.4
[pip3] torchvision==0.11.1
[conda] Could not collect