 import time
 import warnings

+import datasets
 import presets
 import torch
 import torch.utils.data
 import utils
 from torch import nn
 from torch.utils.data.dataloader import default_collate
-from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
+from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler


 def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
@@ -21,7 +22,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))

     header = f"Epoch: [{epoch}]"
-    for video, target in metric_logger.log_every(data_loader, print_freq, header):
+    for video, target, _ in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
@@ -52,13 +53,25 @@ def evaluate(model, criterion, data_loader, device):
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = "Test:"
     num_processed_samples = 0
+    # Group and aggregate the outputs of each video
+    num_videos = len(data_loader.dataset.samples)
+    num_classes = len(data_loader.dataset.classes)
+    agg_preds = torch.zeros((num_videos, num_classes), dtype=torch.float32, device=device)
+    agg_targets = torch.zeros((num_videos), dtype=torch.int32, device=device)
     with torch.inference_mode():
-        for video, target in metric_logger.log_every(data_loader, 100, header):
+        for video, target, video_idx in metric_logger.log_every(data_loader, 100, header):
             video = video.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)
             output = model(video)
             loss = criterion(output, target)

+            # Use softmax to convert the outputs into prediction probabilities
+            preds = torch.softmax(output, dim=1)
+            for b in range(video.size(0)):
+                idx = video_idx[b].item()
+                agg_preds[idx] += preds[b].detach()
+                agg_targets[idx] = target[b].detach().item()
+
             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
             # FIXME need to take into account that the datasets
             # could have been padded in distributed setup
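
The block added above turns clip-level outputs into video-level scores: each clip's softmax vector is accumulated into the agg_preds row of its source video, so a video's final prediction is effectively an average over all of its clips. A minimal self-contained sketch of the same aggregation on toy tensors (all names below are illustrative, not part of the diff); index_add_ is a vectorized alternative to the per-sample Python loop used in the diff:

    import torch

    # 4 clips from 2 videos, 3 classes; video_idx maps each clip to its video
    clip_logits = torch.tensor([[2.0, 0.5, 0.1],   # clip 0 -> video 0
                                [1.8, 0.2, 0.3],   # clip 1 -> video 0
                                [0.1, 0.2, 2.5],   # clip 2 -> video 1
                                [0.3, 0.1, 2.0]])  # clip 3 -> video 1
    video_idx = torch.tensor([0, 0, 1, 1])

    agg_preds = torch.zeros((2, 3))
    agg_preds.index_add_(0, video_idx, torch.softmax(clip_logits, dim=1))

    print(agg_preds.argmax(dim=1))  # tensor([0, 2]): one prediction per video
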
@@ -95,6 +108,11 @@ def evaluate(model, criterion, data_loader, device):
             top1=metric_logger.acc1, top5=metric_logger.acc5
         )
     )
+    # Reduce agg_preds and agg_targets from all GPUs and show the result
+    agg_preds = utils.reduce_across_processes(agg_preds)
+    agg_targets = utils.reduce_across_processes(agg_targets, op=torch.distributed.ReduceOp.MAX)
+    agg_acc1, agg_acc5 = utils.accuracy(agg_preds, agg_targets, topk=(1, 5))
+    print(" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}".format(acc1=agg_acc1, acc5=agg_acc5))
     return metric_logger.acc1.global_avg
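
In a distributed run each rank only fills the agg_preds rows for the clips it processed, so the partial tables are combined before scoring: summing agg_preds merges the probability mass from all ranks, and reducing agg_targets with MAX recovers each video's label, since untouched entries stay 0 and class labels are non-negative. utils.reduce_across_processes lives in the repo's utils.py and is not shown in this diff; a plausible sketch of the behavior assumed here:

    import torch.distributed as dist

    def reduce_across_processes(t, op=dist.ReduceOp.SUM):
        # Sketch only: a no-op outside distributed mode, an all_reduce inside it
        if not (dist.is_available() and dist.is_initialized()):
            return t
        t = t.clone()
        dist.all_reduce(t, op=op)
        return t
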
@@ -110,7 +128,7 @@ def _get_cache_path(filepath, args):

 def collate_fn(batch):
     # remove audio from the batch
-    batch = [(d[0], d[2]) for d in batch]
+    batch = [(d[0], d[2], d[3]) for d in batch]
     return default_collate(batch)
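
For this collate_fn to work, each dataset item must now be a 4-tuple: it keeps the video (d[0]) and label (d[2]), still drops the audio (d[1]), and additionally keeps the video index (d[3]). That is what the KineticsWithVideoId dataset used below provides; its implementation lives in the datasets module added by this PR and is not shown here. A hypothetical sketch, assuming torchvision's Kinetics exposes its clip table as self.video_clips:

    import torchvision

    class KineticsWithVideoId(torchvision.datasets.Kinetics):
        # Sketch: also return the index of the source video, so that
        # evaluate() can group clips back into videos
        def __getitem__(self, idx):
            video, audio, info, video_idx = self.video_clips.get_clip(idx)
            label = self.samples[video_idx][1]
            if self.transform is not None:
                video = self.transform(video)
            return video, audio, label, video_idx
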
@@ -146,7 +164,7 @@ def main(args):
     else:
         if args.distributed:
             print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
-        dataset = torchvision.datasets.Kinetics(
+        dataset = datasets.KineticsWithVideoId(
             args.data_path,
             frames_per_clip=args.clip_len,
             num_classes=args.kinetics_version,
@@ -183,7 +201,7 @@ def main(args):
     else:
         if args.distributed:
             print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
-        dataset_test = torchvision.datasets.Kinetics(
+        dataset_test = datasets.KineticsWithVideoId(
             args.data_path,
             frames_per_clip=args.clip_len,
             num_classes=args.kinetics_version,
@@ -313,10 +331,10 @@ def main(args):
     print(f"Training time {total_time_str}")


-def parse_args():
+def get_args_parser(add_help=True):
     import argparse

-    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training")
+    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training", add_help=add_help)

     parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
     parser.add_argument(
@@ -387,11 +405,9 @@ def parse_args():
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

-    args = parser.parse_args()
-
-    return args
+    return parser


 if __name__ == "__main__":
-    args = parse_args()
+    args = get_args_parser().parse_args()
     main(args)
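
Returning the parser instead of already-parsed args makes the argument set reusable: other scripts can embed it through argparse's parents mechanism, which is why add_help is exposed (a parent parser must not define its own -h). A hypothetical example:

    import argparse
    from train import get_args_parser

    # Reuse all training arguments and add one more; --my-extra-flag is illustrative only
    parser = argparse.ArgumentParser(parents=[get_args_parser(add_help=False)])
    parser.add_argument("--my-extra-flag", action="store_true")
    args = parser.parse_args()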