Skip to content

Commit 67dce75

Browse files
committed
feat: parameterize atrous_rates for deeplabv3 in DeepLabHead.
1 parent 16967f0 commit 67dce75

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchvision/models/segmentation/deeplabv3.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Any, List, Optional
2+
from typing import Any, Optional, Sequence
33

44
import torch
55
from torch import nn
@@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel):
4646

4747

4848
class DeepLabHead(nn.Sequential):
49-
def __init__(self, in_channels: int, num_classes: int) -> None:
49+
def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None:
5050
super().__init__(
51-
ASPP(in_channels, [12, 24, 36]),
51+
ASPP(in_channels, atrous_rates),
5252
nn.Conv2d(256, 256, 3, padding=1, bias=False),
5353
nn.BatchNorm2d(256),
5454
nn.ReLU(),
@@ -83,7 +83,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8383

8484

8585
class ASPP(nn.Module):
86-
def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
86+
def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
8787
super().__init__()
8888
modules = []
8989
modules.append(

0 commit comments

Comments
 (0)