
Commit 325e114

cicichen01 authored and facebook-github-bot committed
Specify Types to enable future function splits (#1253)
Summary: As titled. There are many type errors with method splits due to vague types. Add type annotations as a first step. Differential Revision: D55169019
1 parent 80c7ce5 commit 325e114
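
As an aside, and not part of the commit itself, here is a minimal sketch of why explicit parameter types matter once a method is later split into helpers: the extracted helper keeps a checkable signature instead of an implicit Any. The Example class and _project helper below are hypothetical, for illustration only.

from torch import Tensor
import torch.nn as nn


class Example(nn.Module):  # hypothetical module, not from the commit
    def __init__(self, in_features: int) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, 1)

    # Helper extracted from forward(). Because `x` is annotated as Tensor,
    # a type checker can verify the call to self.fc(x) and to .relu() below.
    def _project(self, x: Tensor) -> Tensor:
        return self.fc(x)

    def forward(self, x: Tensor) -> Tensor:
        return self._project(x).relu()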

File tree

1 file changed (+23, -12 lines)


tests/influence/_utils/common.py

Lines changed: 23 additions & 12 deletions
@@ -144,38 +144,49 @@ def __init__(self, use_gpu=False) -> None:
 
 
 class CoefficientNet(nn.Module):
-    def __init__(self, in_features=1) -> None:
+    def __init__(self, in_features: int = 1) -> None:
         super().__init__()
         self.fc1 = nn.Linear(in_features, 1, bias=False)
         self.fc1.weight.data.fill_(0.01)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.fc1(x)
         return x
 
 
 class BasicLinearNet(nn.Module):
-    def __init__(self, in_features, hidden_nodes, out_features) -> None:
+    def __init__(
+        self,
+        in_features: int,
+        hidden_nodes: int,
+        out_features: int,
+    ) -> None:
         super().__init__()
         self.linear1 = nn.Linear(in_features, hidden_nodes)
         self.linear2 = nn.Linear(hidden_nodes, out_features)
 
-    def forward(self, input):
+    def forward(self, input: Tensor) -> Tensor:
         x = torch.tanh(self.linear1(input))
         return torch.tanh(self.linear2(x))
 
 
 class MultLinearNet(nn.Module):
-    def __init__(self, in_features, hidden_nodes, out_features, num_inputs) -> None:
+    def __init__(
+        self,
+        in_features: int,
+        hidden_nodes: int,
+        out_features: int,
+        num_inputs: int,
+    ) -> None:
         super().__init__()
         self.pre = nn.Linear(in_features * num_inputs, in_features)
         self.linear1 = nn.Linear(in_features, hidden_nodes)
         self.linear2 = nn.Linear(hidden_nodes, out_features)
 
-    def forward(self, *inputs):
+    def forward(self, *inputs: Tensor) -> Tensor:
         """
-        The signature of inputs is List[torch.Tensor],
-        where torch.Tensor has the dimensions [num_inputs x in_features].
+        The signature of inputs is a Tuple of Tensor,
+        where the Tensor has the dimensions [num_inputs x in_features].
         It first concacenates the list and a linear layer to reduce the
         dimension.
         """
@@ -193,11 +204,11 @@ class Linear(nn.Module):
     those implementations.
     """
 
-    def __init__(self, in_features):
+    def __init__(self, in_features: int) -> None:
         super().__init__()
         self.linear = nn.Linear(in_features, 1, bias=False)
 
-    def forward(self, input):
+    def forward(self, input: Tensor) -> Tensor:
         return self.linear(input)
 
 
@@ -206,11 +217,11 @@ class UnpackLinear(nn.Module):
     the analogue of `Linear` which unpacks inputs, serving the same purpose.
     """
 
-    def __init__(self, in_features, num_inputs) -> None:
+    def __init__(self, in_features: int, num_inputs: int) -> None:
         super().__init__()
         self.linear = nn.Linear(in_features * num_inputs, 1, bias=False)
 
-    def forward(self, *inputs):
+    def forward(self, *inputs: Tensor) -> Tensor:
         return self.linear(torch.cat(inputs, dim=1))
 
 
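For reference, a minimal usage sketch of the annotated test modules, assuming torch is installed and tests/influence/_utils/common.py is importable from the repository root; the tensor shapes below are illustrative, not taken from the commit.

import torch
from tests.influence._utils.common import BasicLinearNet, UnpackLinear

# BasicLinearNet takes a single [batch x in_features] tensor.
net = BasicLinearNet(in_features=4, hidden_nodes=8, out_features=2)
out = net(torch.randn(3, 4))  # shape [3, 2]

# UnpackLinear accepts several tensors and concatenates them along dim=1.
unpack = UnpackLinear(in_features=4, num_inputs=2)
score = unpack(torch.randn(3, 4), torch.randn(3, 4))  # shape [3, 1]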