@@ -144,38 +144,49 @@ def __init__(self, use_gpu=False) -> None:
144144
145145
146146class CoefficientNet (nn .Module ):
147- def __init__ (self , in_features = 1 ) -> None :
147+ def __init__ (self , in_features : int = 1 ) -> None :
148148 super ().__init__ ()
149149 self .fc1 = nn .Linear (in_features , 1 , bias = False )
150150 self .fc1 .weight .data .fill_ (0.01 )
151151
152- def forward (self , x ) :
152+ def forward (self , x : Tensor ) -> Tensor :
153153 x = self .fc1 (x )
154154 return x
155155
156156
157157class BasicLinearNet (nn .Module ):
158- def __init__ (self , in_features , hidden_nodes , out_features ) -> None :
158+ def __init__ (
159+ self ,
160+ in_features : int ,
161+ hidden_nodes : int ,
162+ out_features : int ,
163+ ) -> None :
159164 super ().__init__ ()
160165 self .linear1 = nn .Linear (in_features , hidden_nodes )
161166 self .linear2 = nn .Linear (hidden_nodes , out_features )
162167
163- def forward (self , input ) :
168+ def forward (self , input : Tensor ) -> Tensor :
164169 x = torch .tanh (self .linear1 (input ))
165170 return torch .tanh (self .linear2 (x ))
166171
167172
168173class MultLinearNet (nn .Module ):
169- def __init__ (self , in_features , hidden_nodes , out_features , num_inputs ) -> None :
174+ def __init__ (
175+ self ,
176+ in_features : int ,
177+ hidden_nodes : int ,
178+ out_features : int ,
179+ num_inputs : int ,
180+ ) -> None :
170181 super ().__init__ ()
171182 self .pre = nn .Linear (in_features * num_inputs , in_features )
172183 self .linear1 = nn .Linear (in_features , hidden_nodes )
173184 self .linear2 = nn .Linear (hidden_nodes , out_features )
174185
175- def forward (self , * inputs ) :
186+ def forward (self , * inputs : Tensor ) -> Tensor :
176187 """
177- The signature of inputs is List[torch. Tensor] ,
178- where torch. Tensor has the dimensions [num_inputs x in_features].
188+ The signature of inputs is a Tuple of Tensor,
189+ where the Tensor has the dimensions [num_inputs x in_features].
179190 It first concacenates the list and a linear layer to reduce the
180191 dimension.
181192 """
@@ -193,11 +204,11 @@ class Linear(nn.Module):
193204 those implementations.
194205 """
195206
196- def __init__ (self , in_features ) :
207+ def __init__ (self , in_features : int ) -> None :
197208 super ().__init__ ()
198209 self .linear = nn .Linear (in_features , 1 , bias = False )
199210
200- def forward (self , input ) :
211+ def forward (self , input : Tensor ) -> Tensor :
201212 return self .linear (input )
202213
203214
@@ -206,11 +217,11 @@ class UnpackLinear(nn.Module):
206217 the analogue of `Linear` which unpacks inputs, serving the same purpose.
207218 """
208219
209- def __init__ (self , in_features , num_inputs ) -> None :
220+ def __init__ (self , in_features : int , num_inputs : int ) -> None :
210221 super ().__init__ ()
211222 self .linear = nn .Linear (in_features * num_inputs , 1 , bias = False )
212223
213- def forward (self , * inputs ) :
224+ def forward (self , * inputs : Tensor ) -> Tensor :
214225 return self .linear (torch .cat (inputs , dim = 1 ))
215226
216227
0 commit comments