File tree Expand file tree Collapse file tree 1 file changed +0
-5
lines changed Expand file tree Collapse file tree 1 file changed +0
-5
lines changed Original file line number Diff line number Diff line change @@ -44,35 +44,30 @@ def __init__(
44
44
self .memory_efficient = memory_efficient
45
45
46
46
def bn_function (self , inputs : List [Tensor ]) -> Tensor :
47
- # type: (List[Tensor]) -> Tensor
48
47
concated_features = torch .cat (inputs , 1 )
49
48
bottleneck_output = self .conv1 (self .relu1 (self .norm1 (concated_features ))) # noqa: T484
50
49
return bottleneck_output
51
50
52
51
# todo: rewrite when torchscript supports any
53
52
def any_requires_grad (self , input : List [Tensor ]) -> Tensor :
54
- # type: (List[Tensor]) -> bool
55
53
for tensor in input :
56
54
if tensor .requires_grad :
57
55
return True
58
56
return False
59
57
60
58
@torch .jit .unused # noqa: T484
61
59
def call_checkpoint_bottleneck (self , input : List [Tensor ]) -> Tensor :
62
- # type: (List[Tensor]) -> Tensor
63
60
def closure (* inputs ):
64
61
return self .bn_function (inputs )
65
62
66
63
return cp .checkpoint (closure , * input )
67
64
68
65
@torch .jit ._overload_method # noqa: F811
69
66
def forward (self , input : List [Tensor ]) -> Tensor :
70
- # type: (List[Tensor]) -> (Tensor)
71
67
pass
72
68
73
69
@torch .jit ._overload_method # noqa: F811
74
70
def forward (self , input : Tensor ) -> Tensor :
75
- # type: (Tensor) -> (Tensor)
76
71
pass
77
72
78
73
# torchscript does not yet support *args, so we overload method
You can’t perform that action at this time.
0 commit comments