1212import torch
1313import torch .nn as nn
1414from sklearn .metrics import f1_score
15- from torch import Tensor
15+ from torch import BoolTensor , Tensor
1616from torch .nn .functional import softmax
1717from torch .utils .data import DataLoader
1818from torch .utils .tensorboard import SummaryWriter
@@ -577,22 +577,20 @@ def np_one_hot(targets: np.ndarray, n_classes: int = None) -> np.ndarray:
577577 return np .eye (n_classes )[targets ]
578578
579579
580- def masked_std (
581- x : torch .Tensor , mask : torch .BoolTensor , dim : int = 0 , eps : float = 1e-12
582- ) -> torch .Tensor :
580+ def masked_std (x : Tensor , mask : BoolTensor , dim : int = 0 , eps : float = 1e-12 ) -> Tensor :
583581 """Compute the standard deviation of a tensor, ignoring masked values.
584582
585583 Args:
586- x (torch. Tensor): Tensor to compute standard deviation of.
587- mask (torch. BoolTensor): Same shape as x with True where x is valid and False
584+ x (Tensor): Tensor to compute standard deviation of.
585+ mask (BoolTensor): Same shape as x with True where x is valid and False
588586 where x should be masked. Mask should not be all False in any column of
589587 dimension dim to avoid NaNs.
590588 dim (int, optional): Dimension to take std of. Defaults to 0.
591589 eps (float, optional): Small positive number to ensure std is differentiable.
592590 Defaults to 1e-12.
593591
594592 Returns:
595- torch. Tensor: Same shape as x, except dimension dim reduced.
593+ Tensor: Same shape as x, except dimension dim reduced.
596594 """
597595 mean = masked_mean (x , mask , dim = dim )
598596 squared_diff = (x - mean .unsqueeze (dim = dim )) ** 2
@@ -601,18 +599,42 @@ def masked_std(
601599 return std
602600
603601
604- def masked_mean (x : torch . Tensor , mask : torch . BoolTensor , dim : int = 0 ) -> torch . Tensor :
602+ def masked_mean (x : Tensor , mask : BoolTensor , dim : int = 0 ) -> Tensor :
605603 """Compute the mean of a tensor, ignoring masked values.
606604
607605 Args:
608- x (torch. Tensor): Tensor to compute standard deviation of.
609- mask (torch. BoolTensor): Same shape as x with True where x is valid and False
606+ x (Tensor): Tensor to compute mean of.
607+ mask (BoolTensor): Same shape as x with True where x is valid and False
610608 where x should be masked. Mask should not be all False in any column of
611609 dimension dim to avoid NaNs from zero division.
612610 dim (int, optional): Dimension to take mean of. Defaults to 0.
613611
614612 Returns:
615- torch. Tensor: Same shape as x, except dimension dim reduced.
613+ Tensor: Same shape as x, except dimension dim reduced.
616614 """
617- x_nan = x .masked_fill (~ mask , float ("nan" ))
615+ # for safety, we could add this assert but might impact performance
616+ # assert (
617+ # mask.sum(dim=dim).ne(0).all()
618+ # ), "mask should not be all False in any column, causes zero division"
619+ x_nan = x .float ().masked_fill (~ mask , float ("nan" ))
618620 return x_nan .nanmean (dim = dim )
621+
622+
623+ def masked_max (x : Tensor , mask : BoolTensor , dim : int = 0 ) -> Tensor :
624+ """Compute the max of a tensor along dimension dim, ignoring values at indices where
625+ mask is False. See masked_mean docstring for Args details.
626+ """
627+ # replace padded values with +/-inf to make sure min()/max() ignore them
628+ x_inf = x .float ().masked_fill (~ mask , float ("-inf" ))
629+ # 1st ret val = max, 2nd ret val = max indices
630+ x_max , _ = x_inf .max (dim = dim )
631+ return x_max
632+
633+
634+ def masked_min (x : Tensor , mask : BoolTensor , dim : int = 0 ) -> Tensor :
635+ """Compute the min of a tensor along dimension dim, ignoring values at indices where
636+ mask is False. See masked_mean docstring for Args details.
637+ """
638+ x_inf = x .float ().masked_fill (~ mask , float ("inf" ))
639+ x_min , _ = x_inf .min (dim = dim )
640+ return x_min
0 commit comments