|
1 | 1 | from functools import partial
|
2 |
| -from typing import Any, List, Optional |
| 2 | +from typing import Any, Optional, Sequence |
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | from torch import nn
|
@@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel):
|
46 | 46 |
|
47 | 47 |
|
48 | 48 | class DeepLabHead(nn.Sequential):
|
49 |
| - def __init__(self, in_channels: int, num_classes: int) -> None: |
| 49 | + def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None: |
50 | 50 | super().__init__(
|
51 |
| - ASPP(in_channels, [12, 24, 36]), |
| 51 | + ASPP(in_channels, atrous_rates), |
52 | 52 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
|
53 | 53 | nn.BatchNorm2d(256),
|
54 | 54 | nn.ReLU(),
|
@@ -83,7 +83,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
83 | 83 |
|
84 | 84 |
|
85 | 85 | class ASPP(nn.Module):
|
86 |
| - def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: |
| 86 | + def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None: |
87 | 87 | super().__init__()
|
88 | 88 | modules = []
|
89 | 89 | modules.append(
|
|
0 commit comments