@@ -590,7 +590,23 @@ def rotate(
590
590
def pad_image_tensor (
591
591
img : torch .Tensor ,
592
592
padding : Union [int , List [int ]],
593
- fill : Optional [Union [int , float ]] = 0 ,
593
+ fill : Optional [Union [int , float , List [float ]]] = None ,
594
+ padding_mode : str = "constant" ,
595
+ ) -> torch .Tensor :
596
+ if fill is None :
597
+ # This JIT workaround
598
+ return _pad_with_scalar_fill (img , padding , fill = None , padding_mode = padding_mode )
599
+ elif isinstance (fill , (int , float )) or len (fill ) == 1 :
600
+ fill_number = fill [0 ] if isinstance (fill , list ) else fill
601
+ return _pad_with_scalar_fill (img , padding , fill = fill_number , padding_mode = padding_mode )
602
+ else :
603
+ return _pad_with_vector_fill (img , padding , fill = fill , padding_mode = padding_mode )
604
+
605
+
606
+ def _pad_with_scalar_fill (
607
+ img : torch .Tensor ,
608
+ padding : Union [int , List [int ]],
609
+ fill : Optional [Union [int , float ]] = None ,
594
610
padding_mode : str = "constant" ,
595
611
) -> torch .Tensor :
596
612
num_channels , height , width = img .shape [- 3 :]
@@ -613,13 +629,13 @@ def pad_image_tensor(
613
629
def _pad_with_vector_fill (
614
630
img : torch .Tensor ,
615
631
padding : Union [int , List [int ]],
616
- fill : Sequence [float ] = [ 0.0 ],
632
+ fill : List [float ],
617
633
padding_mode : str = "constant" ,
618
634
) -> torch .Tensor :
619
635
if padding_mode != "constant" :
620
636
raise ValueError (f"Padding mode '{ padding_mode } ' is not supported if fill is not scalar" )
621
637
622
- output = pad_image_tensor (img , padding , fill = 0 , padding_mode = "constant" )
638
+ output = _pad_with_scalar_fill (img , padding , fill = 0 , padding_mode = "constant" )
623
639
left , right , top , bottom = _parse_pad_padding (padding )
624
640
fill = torch .tensor (fill , dtype = img .dtype , device = img .device ).view (- 1 , 1 , 1 )
625
641
@@ -638,8 +654,14 @@ def pad_mask(
638
654
mask : torch .Tensor ,
639
655
padding : Union [int , List [int ]],
640
656
padding_mode : str = "constant" ,
641
- fill : Optional [Union [int , float ]] = 0 ,
657
+ fill : Optional [Union [int , float , List [ float ]]] = None ,
642
658
) -> torch .Tensor :
659
+ if fill is None :
660
+ fill = 0
661
+
662
+ if isinstance (fill , list ):
663
+ raise ValueError ("Non-scalar fill value is not supported" )
664
+
643
665
if mask .ndim < 3 :
644
666
mask = mask .unsqueeze (0 )
645
667
needs_squeeze = True
@@ -692,10 +714,11 @@ def pad(
692
714
if not isinstance (padding , int ):
693
715
padding = list (padding )
694
716
695
- # TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
696
- if isinstance (fill , (int , float )) or fill is None :
697
- return pad_image_tensor (inpt , padding , fill = fill , padding_mode = padding_mode )
698
- return _pad_with_vector_fill (inpt , padding , fill = fill , padding_mode = padding_mode )
717
+ # This cast does Sequence -> List and is required to make mypy happy
718
+ if not (fill is None or isinstance (fill , (int , float ))):
719
+ fill = list (fill )
720
+
721
+ return pad_image_tensor (inpt , padding , fill = fill , padding_mode = padding_mode )
699
722
700
723
701
724
crop_image_tensor = _FT .crop
0 commit comments