Skip to content

Updated video classification ref example with new transforms #2935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions references/video_classification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Video Classification

TODO: Add some info about the context, dataset we use etc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bjuncek can you enhance this with information that you used to train on the datasets? We could potentially refer to submitit


## Data preparation

If you already have downloaded [Kinetics400 dataset](https://deepmind.com/research/open-source/kinetics),
please proceed directly to the next section.

To download videos, one can use https://github.com/Showmax/kinetics-downloader

## Training

We assume the training and validation AVI videos are stored at `/data/kinectics400/train` and
`/data/kinectics400/val`.

### Multiple GPUs

Run the training on a single node with 8 GPUs:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe @bjuncek trained those models on 64 GPUs, it might be good to add a mention here in a follow-up PR

```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=16 --cache-dataset --sync-bn --apex
```



### Single GPU

**Note:** training on a single gpu can be extremely slow.


```bash
python train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=8 --cache-dataset
```


24 changes: 15 additions & 9 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from torch import nn
import torchvision
import torchvision.datasets.video_utils
from torchvision import transforms
from torchvision import transforms as T
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler

import utils

from scheduler import WarmupMultiStepLR
import transforms as T
from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW

try:
from apex import amp
Expand Down Expand Up @@ -119,11 +119,13 @@ def main(args):
st = time.time()
cache_path = _get_cache_path(traindir)
transform_train = torchvision.transforms.Compose([
T.ToFloatTensorInZeroOne(),
ConvertBHWCtoBCHW(),
T.ConvertImageDtype(torch.float32),
T.Resize((128, 171)),
T.RandomHorizontalFlip(),
normalize,
T.RandomCrop((112, 112))
T.RandomCrop((112, 112)),
ConvertBCHWtoCBHW()
])

if args.cache_dataset and os.path.exists(cache_path):
Expand All @@ -139,7 +141,8 @@ def main(args):
frames_per_clip=args.clip_len,
step_between_clips=1,
transform=transform_train,
frame_rate=15
frame_rate=15,
extensions=('avi', 'mp4', )
)
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
Expand All @@ -152,10 +155,12 @@ def main(args):
cache_path = _get_cache_path(valdir)

transform_test = torchvision.transforms.Compose([
T.ToFloatTensorInZeroOne(),
ConvertBHWCtoBCHW(),
T.ConvertImageDtype(torch.float32),
T.Resize((128, 171)),
normalize,
T.CenterCrop((112, 112))
T.CenterCrop((112, 112)),
ConvertBCHWtoCBHW()
])

if args.cache_dataset and os.path.exists(cache_path):
Expand All @@ -171,7 +176,8 @@ def main(args):
frames_per_clip=args.clip_len,
step_between_clips=1,
transform=transform_test,
frame_rate=15
frame_rate=15,
extensions=('avi', 'mp4',)
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
Expand Down Expand Up @@ -265,7 +271,7 @@ def main(args):

def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser = argparse.ArgumentParser(description='PyTorch Video Classification Training')

parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset')
parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir')
Expand Down
126 changes: 11 additions & 115 deletions references/video_classification/transforms.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,18 @@
import torch
import random
import torch.nn as nn


def crop(vid, i, j, h, w):
return vid[..., i:(i + h), j:(j + w)]
class ConvertBHWCtoBCHW(nn.Module):
"""Convert tensor from (B, H, W, C) to (B, C, H, W)
"""

def forward(self, vid: torch.Tensor) -> torch.Tensor:
return vid.permute(0, 3, 1, 2)

def center_crop(vid, output_size):
h, w = vid.shape[-2:]
th, tw = output_size

i = int(round((h - th) / 2.))
j = int(round((w - tw) / 2.))
return crop(vid, i, j, th, tw)
class ConvertBCHWtoCBHW(nn.Module):
"""Convert tensor from (B, C, H, W) to (C, B, H, W)
"""


def hflip(vid):
return vid.flip(dims=(-1,))


# NOTE: for those functions, which generally expect mini-batches, we keep them
# as non-minibatch so that they are applied as if they were 4d (thus image).
# this way, we only apply the transformation in the spatial domain
def resize(vid, size, interpolation='bilinear'):
# NOTE: using bilinear interpolation because we don't work on minibatches
# at this level
scale = None
if isinstance(size, int):
scale = float(size) / min(vid.shape[-2:])
size = None
return torch.nn.functional.interpolate(
vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)


def pad(vid, padding, fill=0, padding_mode="constant"):
# NOTE: don't want to pad on temporal dimension, so let as non-batch
# (4d) before padding. This works as expected
return torch.nn.functional.pad(vid, padding, value=fill, mode=padding_mode)


def to_normalized_float_tensor(vid):
return vid.permute(3, 0, 1, 2).to(torch.float32) / 255


def normalize(vid, mean, std):
shape = (-1,) + (1,) * (vid.dim() - 1)
mean = torch.as_tensor(mean).reshape(shape)
std = torch.as_tensor(std).reshape(shape)
return (vid - mean) / std


# Class interface

class RandomCrop(object):
def __init__(self, size):
self.size = size

@staticmethod
def get_params(vid, output_size):
"""Get parameters for ``crop`` for a random crop.
"""
h, w = vid.shape[-2:]
th, tw = output_size
if w == tw and h == th:
return 0, 0, h, w
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw

def __call__(self, vid):
i, j, h, w = self.get_params(vid, self.size)
return crop(vid, i, j, h, w)


class CenterCrop(object):
def __init__(self, size):
self.size = size

def __call__(self, vid):
return center_crop(vid, self.size)


class Resize(object):
def __init__(self, size):
self.size = size

def __call__(self, vid):
return resize(vid, self.size)


class ToFloatTensorInZeroOne(object):
def __call__(self, vid):
return to_normalized_float_tensor(vid)


class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std

def __call__(self, vid):
return normalize(vid, self.mean, self.std)


class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p

def __call__(self, vid):
if random.random() < self.p:
return hflip(vid)
return vid


class Pad(object):
def __init__(self, padding, fill=0):
self.padding = padding
self.fill = fill

def __call__(self, vid):
return pad(vid, self.padding, self.fill)
def forward(self, vid: torch.Tensor) -> torch.Tensor:
return vid.permute(1, 0, 2, 3)