3939# Model evaluation
4040# ================
4141
class TrainMetrics(NamedTuple):
    """Aggregate classification metrics computed over one training epoch.

    Mirrors the leading fields of ``ValidationMetrics`` but omits the
    validation-only extras (per-class F1, confusion matrix, PR-curve data)
    that are not produced during training.
    """
    # Unique class indices present in this epoch's ground-truth labels.
    class_labels: List[int]
    # Balanced accuracy score over the epoch's predictions.
    acc_balanced: float
    # Micro- and macro-averaged F1 over all samples seen this epoch.
    f1_micro: float
    f1_macro: float
    # Micro- and macro-averaged precision.
    prec_micro: float
    prec_macro: float
    # Micro- and macro-averaged recall.
    rec_micro: float
    rec_macro: float
52+
4253class ValidationMetrics (NamedTuple ):
4354 class_labels : List [int ]
4455 acc_balanced : float
@@ -63,7 +74,7 @@ class EarlyStopping:
6374 - https://github.com/Bjarten/early-stopping-pytorch
6475 """
6576
66- def __init__ (self , patience = 3 , min_delta = 1 , min_epochs = 50 ):
77+ def __init__ (self , patience : int = 3 , min_delta : float = 1 , min_epochs : int = 50 ):
6778 self .patience = patience
6879 self .min_delta = min_delta
6980 self .counter = 0
@@ -457,24 +468,25 @@ def get_target_class(cl: int) -> str:
457468 except Exception as e :
458469 log (f"Failed to add graph to tensorboard." )
459470
460- early_stopping = EarlyStopping (patience = 5 , min_delta = 5 , min_epochs = 50 )
471+ early_stopping = EarlyStopping (patience = 5 , min_delta = 0. 5 , min_epochs = 50 )
461472 try :
462473 for epoch in range (args .start_epoch , args .epochs ):
463474 if args .distributed :
464475 train_sampler .set_epoch (epoch )
465476
466477 # train for one epoch
467- train_loss = train (train_loader , model , criterion , optimizer , epoch , device , args )
478+ train_acc1 , train_loss , train_metrics = train (train_loader , model , criterion , optimizer , epoch , device ,
479+ args )
468480
469481 # evaluate on validation set
470- acc1 , val_loss , metrics = validate (val_loader , model , criterion , args )
482+ val_acc1 , val_loss , val_metrics = validate (val_loader , model , criterion , args )
471483 scheduler .step ()
472484 early_stopping (val_loss , epoch )
473485
474486 # remember best acc@1 and save checkpoint
475- is_best = acc1 > best_acc1
476- best_acc1 = max (acc1 , best_acc1 )
477- best_metrics = metrics if metrics .f1_micro > best_metrics .f1_micro else best_metrics
487+ is_best = val_acc1 > best_acc1
488+ best_acc1 = max (val_acc1 , best_acc1 )
489+ best_metrics = val_metrics if val_metrics .f1_micro > best_metrics .f1_micro else best_metrics
478490
479491 if not args .multiprocessing_distributed or \
480492 (args .multiprocessing_distributed and args .rank % ngpus_per_node == 0 ) or \
@@ -491,31 +503,49 @@ def get_target_class(cl: int) -> str:
491503 if tensorboard_writer :
492504 tensorboard_writer .add_scalars ('Loss' , dict (train = train_loss , val = val_loss ), epoch + 1 )
493505 tensorboard_writer .add_scalars ('Metrics/Accuracy' ,
494- dict (acc = acc1 / 100.0 , balanced_acc = metrics .acc_balanced ), epoch + 1 )
495- tensorboard_writer .add_scalars ('Metrics/F1' , dict (micro = metrics .f1_micro , macro = metrics .f1_macro ),
506+ dict (val_acc = val_acc1 / 100.0 ,
507+ val_bacc = val_metrics .acc_balanced ,
508+ train_acc = train_acc1 / 100.0 ,
509+ train_bacc = train_metrics .acc_balanced ),
510+ epoch + 1 )
511+ tensorboard_writer .add_scalars ('Metrics/F1' ,
512+ dict (val_micro = val_metrics .f1_micro ,
513+ val_macro = val_metrics .f1_macro ,
514+ train_micro = train_metrics .f1_micro ,
515+ train_macro = train_metrics .f1_macro ),
496516 epoch + 1 )
497517 tensorboard_writer .add_scalars ('Metrics/Precision' ,
498- dict (micro = metrics .prec_micro , macro = metrics .prec_macro ), epoch + 1 )
499- tensorboard_writer .add_scalars ('Metrics/Recall' , dict (micro = metrics .rec_micro , macro = metrics .rec_macro ),
518+ dict (val_micro = val_metrics .prec_micro ,
519+ val_macro = val_metrics .prec_macro ,
520+ train_micro = train_metrics .prec_micro ,
521+ train_macro = train_metrics .prec_macro ),
522+ epoch + 1 )
523+ tensorboard_writer .add_scalars ('Metrics/Recall' ,
524+ dict (val_micro = val_metrics .rec_micro ,
525+ val_macro = val_metrics .rec_macro ,
526+ train_micro = train_metrics .rec_micro ,
527+ train_macro = train_metrics .rec_macro ),
500528 epoch + 1 )
501529 tensorboard_writer .add_scalars ('Metrics/F1/class' ,
502- {get_target_class (cl ): f1 for cl , f1 in metrics .f1_per_class }, epoch + 1 )
530+ {get_target_class (cl ): f1 for cl , f1 in val_metrics .f1_per_class },
531+ epoch + 1 )
503532
504533 if epoch < 10 or epoch % 5 == 0 or epoch == args .epochs - 1 :
505- class_names = [get_target_class (cl ) for cl in list ({l for l in metrics .class_labels })]
506- fig_abs , _ = plot_confusion_matrix (metrics .conf_matrix , class_names = class_names , normalize = False )
507- fig_rel , _ = plot_confusion_matrix (metrics .conf_matrix , class_names = class_names , normalize = True )
534+ class_names = [get_target_class (cl ) for cl in list ({l for l in val_metrics .class_labels })]
535+ fig_abs , _ = plot_confusion_matrix (val_metrics .conf_matrix , class_names = class_names ,
536+ normalize = False )
537+ fig_rel , _ = plot_confusion_matrix (val_metrics .conf_matrix , class_names = class_names , normalize = True )
508538 tensorboard_writer .add_figure ('Confusion matrix' , fig_abs , epoch + 1 )
509539 tensorboard_writer .add_figure ('Confusion matrix/normalized' , fig_rel , epoch + 1 )
510540
511- for cl in metrics .class_labels :
541+ for cl in val_metrics .class_labels :
512542 class_index = int (cl )
513- labels_true = metrics .labels_true == class_index
514- pred_probs = metrics .labels_probs [:, class_index ]
543+ labels_true = val_metrics .labels_true == class_index
544+ pred_probs = val_metrics .labels_probs [:, class_index ]
515545 tensorboard_writer .add_pr_curve (f'PR curve/{ get_target_class (class_index )} ' ,
516546 labels_true , pred_probs , epoch + 1 )
517547
518- tensorboard_writer .add_figure ('PR curve' , metrics .fig_pr_curve_micro , epoch + 1 )
548+ tensorboard_writer .add_figure ('PR curve' , val_metrics .fig_pr_curve_micro , epoch + 1 )
519549
520550 if early_stopping .should_stop :
521551 log (f"Early stopping at epoch { epoch + 1 } " )
@@ -540,7 +570,7 @@ def get_target_class(cl: int) -> str:
540570 })
541571
542572
543- def train (train_loader , model , criterion , optimizer , epoch , device , args ) -> float :
573+ def train (train_loader , model , criterion , optimizer , epoch , device , args ) -> Tuple [ float , float , TrainMetrics ] :
544574 batch_time = AverageMeter ('Time' , ':6.3f' )
545575 data_time = AverageMeter ('Data' , ':6.3f' )
546576 losses = AverageMeter ('Loss' , ':.4e' )
@@ -555,6 +585,11 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> flo
555585 # switch to train mode
556586 model .train ()
557587
588+ # for train metrics
589+ labels_true = np .array ([], dtype = np .int64 )
590+ labels_pred = np .array ([], dtype = np .int64 )
591+ labels_probs = []
592+
558593 end = time .time ()
559594 for i , (images , target ) in enumerate (train_loader ):
560595 # measure data loading time
@@ -579,14 +614,29 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> flo
579614 loss .backward ()
580615 optimizer .step ()
581616
617+ with torch .no_grad ():
618+ predicted_values , predicted_indices = torch .max (output .data , 1 )
619+ labels_true = np .append (labels_true , target .cpu ().numpy ())
620+ labels_pred = np .append (labels_pred , predicted_indices .cpu ().numpy ())
621+
622+ class_probs_batch = [F .softmax (el , dim = 0 ) for el in output ]
623+ labels_probs .append (class_probs_batch )
624+
582625 # measure elapsed time
583626 batch_time .update (time .time () - end )
584627 end = time .time ()
585628
586629 if i % args .print_freq == 0 :
587630 progress .display (i + 1 )
588631
589- return loss .item ()
632+ if args .distributed :
633+ acc_top1 .all_reduce ()
634+ acc_top5 .all_reduce ()
635+
636+ labels_probs = torch .cat ([torch .stack (batch ) for batch in labels_probs ]).cpu ()
637+ metrics = calculate_train_metrics (labels_true , labels_pred , labels_probs )
638+
639+ return acc_top1 .avg , loss .item (), metrics
590640
591641
592642def validate (val_loader , model , criterion , args ) -> Tuple [float , float , "ValidationMetrics" ]:
@@ -635,7 +685,7 @@ def run_validate(loader, base_progress=0) -> ValidationMetrics:
635685
636686 labels_probs = torch .cat ([torch .stack (batch ) for batch in labels_probs ]).cpu ()
637687
638- return metrics_labels_true_pred (labels_true , labels_pred , labels_probs )
688+ return calculate_validation_metrics (labels_true , labels_pred , labels_probs )
639689
640690 batch_time = AverageMeter ('Time' , ':6.3f' , Summary .NONE )
641691 losses = AverageMeter ('Loss' , ':.4e' , Summary .NONE )
@@ -786,8 +836,29 @@ def accuracy(output, target, topk=(1,)):
786836 return res
787837
788838
789- def metrics_labels_true_pred (labels_true : np .array , labels_pred : np .array ,
790- labels_probs : torch .Tensor ) -> ValidationMetrics :
def calculate_train_metrics(labels_true: np.array, labels_pred: np.array,
                            labels_probs: torch.Tensor) -> TrainMetrics:
    """Compute aggregate classification metrics for one training epoch.

    Args:
        labels_true: 1-D array of ground-truth class indices for every
            sample seen this epoch.
        labels_pred: 1-D array of predicted class indices, same length as
            ``labels_true``.
        labels_probs: per-sample class probabilities. Currently unused; kept
            so the signature stays parallel to
            ``calculate_validation_metrics`` (which needs it for PR curves).

    Returns:
        TrainMetrics holding the unique class labels, balanced accuracy,
        and micro/macro-averaged F1, precision and recall.
    """
    # np.unique yields a sorted, deterministic label list (the previous
    # set-comprehension order depended on hashing), and int() keeps the
    # declared List[int] contract instead of leaking np.int64 elements.
    unique_labels = [int(cl) for cl in np.unique(labels_true)]

    acc_balanced = balanced_accuracy_score(labels_true, labels_pred)
    f1_micro = f1_score(labels_true, labels_pred, average="micro")
    f1_macro = f1_score(labels_true, labels_pred, average="macro")
    prec_micro = precision_score(labels_true, labels_pred, average="micro")
    prec_macro = precision_score(labels_true, labels_pred, average="macro")
    rec_micro = recall_score(labels_true, labels_pred, average="micro")
    rec_macro = recall_score(labels_true, labels_pred, average="macro")

    return TrainMetrics(
        unique_labels,
        acc_balanced,
        f1_micro, f1_macro,
        prec_micro, prec_macro,
        rec_micro, rec_macro,
    )
858+
859+
860+ def calculate_validation_metrics (labels_true : np .array , labels_pred : np .array ,
861+ labels_probs : torch .Tensor ) -> ValidationMetrics :
791862 unique_labels = list ({l for l in labels_true })
792863 f1_per_class = f1_score (labels_true , labels_pred , average = None , labels = unique_labels )
793864 f1_micro = f1_score (labels_true , labels_pred , average = "micro" )
0 commit comments