diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index f58c5d26a66..a92ddfe3b7a 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, List, Optional +from typing import Any, Optional, Sequence import torch from torch import nn @@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel): class DeepLabHead(nn.Sequential): - def __init__(self, in_channels: int, num_classes: int) -> None: + def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None: super().__init__( - ASPP(in_channels, [12, 24, 36]), + ASPP(in_channels, atrous_rates), nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(), @@ -83,7 +83,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ASPP(nn.Module): - def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: + def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None: super().__init__() modules = [] modules.append(