diff --git a/references/video_classification/README.md b/references/video_classification/README.md
new file mode 100644
index 00000000000..525cfddd414
--- /dev/null
+++ b/references/video_classification/README.md
@@ -0,0 +1,30 @@
+# Video Classification
+
+TODO: add some info about the context, the dataset we use, etc.
+
+## Data preparation
+
+If you have already downloaded the [Kinetics400 dataset](https://deepmind.com/research/open-source/kinetics),
+please proceed directly to the next section.
+
+To download the videos, you can use https://github.com/Showmax/kinetics-downloader
+
+## Training
+
+We assume the training and validation AVI videos are stored at `/data/kinetics400/train` and
+`/data/kinetics400/val`.
+
+### Multiple GPUs
+
+Run the training on a single node with 8 GPUs:
+```bash
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --data-path=/data/kinetics400 --train-dir=train --val-dir=val --batch-size=16 --cache-dataset --sync-bn --apex
+```
+
+### Single GPU
+
+**Note:** training on a single GPU can be extremely slow.
+
+```bash
+python train.py --data-path=/data/kinetics400 --train-dir=train --val-dir=val --batch-size=8 --cache-dataset
+```
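Before launching a full run it can be worth checking that the dataset is readable, since `torchvision.datasets.Kinetics400` scans the directory tree (one subfolder per class) and builds a clip index, which is slow on the first pass — that is what the `--cache-dataset` flag above avoids repeating. A minimal sketch, not part of this patch; the path is the hypothetical one from the README:

```python
import torchvision

# Sanity check (hypothetical path): Kinetics400 expects one subdirectory per
# class under the given root and indexes fixed-length clips from each video.
dataset = torchvision.datasets.Kinetics400(
    "/data/kinetics400/train",
    frames_per_clip=16,
    step_between_clips=1,
    extensions=('avi', 'mp4'),
)
print(len(dataset), "clips indexed")
```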
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index e71c03f174f..3b5d8d8d206 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -7,13 +7,13 @@
 from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
-from torchvision import transforms
+from torchvision import transforms as T
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
 import utils
 
 from scheduler import WarmupMultiStepLR
-import transforms as T
+from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
 
 try:
     from apex import amp
@@ -119,11 +119,13 @@ def main(args):
     st = time.time()
     cache_path = _get_cache_path(traindir)
     transform_train = torchvision.transforms.Compose([
-        T.ToFloatTensorInZeroOne(),
+        ConvertBHWCtoBCHW(),
+        T.ConvertImageDtype(torch.float32),
         T.Resize((128, 171)),
         T.RandomHorizontalFlip(),
         normalize,
-        T.RandomCrop((112, 112))
+        T.RandomCrop((112, 112)),
+        ConvertBCHWtoCBHW()
     ])
 
     if args.cache_dataset and os.path.exists(cache_path):
@@ -139,7 +141,8 @@
             frames_per_clip=args.clip_len,
             step_between_clips=1,
             transform=transform_train,
-            frame_rate=15
+            frame_rate=15,
+            extensions=('avi', 'mp4')
         )
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
@@ -152,10 +155,12 @@
     cache_path = _get_cache_path(valdir)
 
     transform_test = torchvision.transforms.Compose([
-        T.ToFloatTensorInZeroOne(),
+        ConvertBHWCtoBCHW(),
+        T.ConvertImageDtype(torch.float32),
         T.Resize((128, 171)),
         normalize,
-        T.CenterCrop((112, 112))
+        T.CenterCrop((112, 112)),
+        ConvertBCHWtoCBHW()
     ])
 
     if args.cache_dataset and os.path.exists(cache_path):
@@ -171,7 +176,8 @@
             frames_per_clip=args.clip_len,
             step_between_clips=1,
             transform=transform_test,
-            frame_rate=15
+            frame_rate=15,
+            extensions=('avi', 'mp4')
         )
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
@@ -265,7 +271,7 @@ def parse_args():
     import argparse
-    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
+    parser = argparse.ArgumentParser(description='PyTorch Video Classification Training')
 
     parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset')
     parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir')
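The idea behind the new pipeline: clips come out of `Kinetics400` as `(T, H, W, C)` uint8 tensors, so `ConvertBHWCtoBCHW` lets the stock torchvision transforms treat the time dimension as a batch dimension and operate per frame, and `ConvertBCHWtoCBHW` then moves channels first, the `(C, T, H, W)` layout 3D-convolution video models expect. A minimal sketch of the training transform, assuming torchvision >= 0.8 (where these transforms accept tensor input); normalization is omitted here for brevity:

```python
import torch
from torchvision import transforms as T

from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW

# A dummy decoded clip, shaped (T, H, W, C) like the clips Kinetics400 yields.
clip = torch.randint(0, 256, (16, 180, 240, 3), dtype=torch.uint8)

transform = T.Compose([
    ConvertBHWCtoBCHW(),                 # (T, H, W, C) -> (T, C, H, W); time acts as the batch dim
    T.ConvertImageDtype(torch.float32),  # uint8 in [0, 255] -> float32 in [0, 1]
    T.Resize((128, 171)),                # applied independently to each frame
    T.RandomCrop((112, 112)),
    ConvertBCHWtoCBHW(),                 # (T, C, H, W) -> (C, T, H, W) for 3D conv models
])

print(transform(clip).shape)  # torch.Size([3, 16, 112, 112])
```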
- """ - h, w = vid.shape[-2:] - th, tw = output_size - if w == tw and h == th: - return 0, 0, h, w - i = random.randint(0, h - th) - j = random.randint(0, w - tw) - return i, j, th, tw - - def __call__(self, vid): - i, j, h, w = self.get_params(vid, self.size) - return crop(vid, i, j, h, w) - - -class CenterCrop(object): - def __init__(self, size): - self.size = size - - def __call__(self, vid): - return center_crop(vid, self.size) - - -class Resize(object): - def __init__(self, size): - self.size = size - - def __call__(self, vid): - return resize(vid, self.size) - - -class ToFloatTensorInZeroOne(object): - def __call__(self, vid): - return to_normalized_float_tensor(vid) - - -class Normalize(object): - def __init__(self, mean, std): - self.mean = mean - self.std = std - - def __call__(self, vid): - return normalize(vid, self.mean, self.std) - - -class RandomHorizontalFlip(object): - def __init__(self, p=0.5): - self.p = p - - def __call__(self, vid): - if random.random() < self.p: - return hflip(vid) - return vid - - -class Pad(object): - def __init__(self, padding, fill=0): - self.padding = padding - self.fill = fill - - def __call__(self, vid): - return pad(vid, self.padding, self.fill) + def forward(self, vid: torch.Tensor) -> torch.Tensor: + return vid.permute(1, 0, 2, 3)