@@ -1,7 +1,7 @@
 # pyre-strict
 import time
 import warnings
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -17,6 +17,82 @@ def l2_loss(x1, x2, weights=None) -> torch.Tensor:
     return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0


+class ConvergenceTracker:
+    def __init__(self, patience: int, threshold: float) -> None:
+        self.min_avg_loss: Optional[torch.Tensor] = None
+        self.convergence_counter: int = 0
+        self.converged: bool = False
+
+        self.threshold = threshold
+        self.patience = patience
+
+    def update(self, average_loss: torch.Tensor) -> bool:
+        if self.min_avg_loss is not None:
+            # if we haven't improved by at least `threshold`
+            if average_loss > self.min_avg_loss or torch.isclose(
+                cast(torch.Tensor, self.min_avg_loss), average_loss, atol=self.threshold
+            ):
+                self.convergence_counter += 1
+                if self.convergence_counter >= self.patience:
+                    self.converged = True
+                    return True
+            else:
+                self.convergence_counter = 0
+        if self.min_avg_loss is None or self.min_avg_loss >= average_loss:
+            self.min_avg_loss = average_loss.clone()
+        return False
+
+
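A quick sanity sketch of the new tracker's contract (loss values are made up for illustration): update returns True only once the averaged loss has failed to improve by more than threshold for patience consecutive calls.

import torch

tracker = ConvergenceTracker(patience=2, threshold=1e-4)
for v in [1.0, 0.5, 0.5, 0.5]:
    done = tracker.update(torch.tensor(v))
# the two flat 0.5 readings exhaust patience=2, so the last update returns True
assert done and tracker.converged
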
+class LossWindow:
+    def __init__(self, window_size: int) -> None:
+        self.loss_window: List[torch.Tensor] = []
+        self.window_size = window_size
+
+    def append(self, loss: torch.Tensor) -> None:
+        if len(self.loss_window) >= self.window_size:
+            self.loss_window = self.loss_window[-self.window_size :]
+        self.loss_window.append(loss)
+
+    def average(self) -> torch.Tensor:
+        return torch.mean(torch.stack(self.loss_window))
+
+
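Reviewer note on LossWindow: truncation happens before the append, so the window can hold window_size + 1 entries after a call, whereas the inline code it replaces kept at most window_size. A tiny illustration:

import torch

window = LossWindow(window_size=3)
for v in [1.0, 2.0, 3.0, 4.0]:
    window.append(torch.tensor(v))
print(len(window.loss_window))  # 4, not 3: the slice keeps the last 3, then appends
print(window.average())         # tensor(2.5000), the mean over all four retained losses
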
+def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None:
+    assert model.linear is not None
+    if init_scheme is not None:
+        assert init_scheme in ["xavier", "zeros"]
+
+        with torch.no_grad():
+            if init_scheme == "xavier":
+                # pyre-fixme[16]: `Optional` has no attribute `weight`.
+                torch.nn.init.xavier_uniform_(model.linear.weight)
+            else:
+                model.linear.weight.zero_()
+
+            # pyre-fixme[16]: `Optional` has no attribute `bias`.
+            if model.linear.bias is not None:
+                model.linear.bias.zero_()
+
+
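A hedged sketch of the three init paths; the LinearModel setup below (import path and constructor) is an assumption based on the _construct_model_params call later in this diff, not something the diff itself shows:

from captum._utils.models.linear_model.model import LinearModel

model = LinearModel(train_fn=sgd_train_linear_model)          # assumed constructor
model._construct_model_params(in_features=4, out_features=1)  # builds model.linear

_init_linear_model(model, init_scheme="xavier")  # Xavier-uniform weights, zero bias
_init_linear_model(model, init_scheme="zeros")   # zero weights and zero bias
_init_linear_model(model, init_scheme=None)      # no-op: only the assert runs
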
+def _get_point(
+    datapoint: Tuple[torch.Tensor, ...],
+    device: Optional[str] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    if len(datapoint) == 2:
+        x, y = datapoint
+        w = None
+    else:
+        x, y, w = datapoint
+
+    if device is not None:
+        x = x.to(device)
+        y = y.to(device)
+        if w is not None:
+            w = w.to(device)
+
+    return x, y, w
+
+
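The extracted unpacking helper accepts both weighted and unweighted batches; a minimal check:

import torch

x, y, w = torch.randn(8, 4), torch.randn(8, 1), torch.rand(8)
x2, y2, w2 = _get_point((x, y))                   # weights default to None
x3, y3, w3 = _get_point((x, y, w), device="cpu")  # all tensors moved to the device
assert w2 is None and w3 is not None
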
 def sgd_train_linear_model(
     model: LinearModel,
     dataloader: DataLoader,
@@ -102,31 +178,16 @@ def sgd_train_linear_model(
     This will return the final training loss (averaged with
     `running_loss_window`)
     """
-    loss_window: List[torch.Tensor] = []
-    min_avg_loss = None
-    convergence_counter = 0
-    converged = False
-
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def get_point(datapoint):
-        if len(datapoint) == 2:
-            x, y = datapoint
-            w = None
-        else:
-            x, y, w = datapoint
-
-        if device is not None:
-            x = x.to(device)
-            y = y.to(device)
-            if w is not None:
-                w = w.to(device)
-
-        return x, y, w
+    converge_tracker = ConvergenceTracker(patience, threshold)

     # get a point and construct the model
     data_iter = iter(dataloader)
-    x, y, w = get_point(next(data_iter))
+    x, y, w = _get_point(next(data_iter), device)
+
+    if running_loss_window is None:
+        running_loss_window = x.shape[0] * len(dataloader)
+
+    loss_window = LossWindow(running_loss_window)

     model._construct_model_params(
         in_features=x.shape[1],
@@ -135,21 +196,8 @@ def get_point(datapoint):
     )
     model.train()

-    assert model.linear is not None
-
-    if init_scheme is not None:
-        assert init_scheme in ["xavier", "zeros"]
-
-        with torch.no_grad():
-            if init_scheme == "xavier":
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                torch.nn.init.xavier_uniform_(model.linear.weight)
-            else:
-                model.linear.weight.zero_()
-
-            # pyre-fixme[16]: `Optional` has no attribute `bias`.
-            if model.linear.bias is not None:
-                model.linear.bias.zero_()
+    # Initialize linear model weights if applicable
+    _init_linear_model(model, init_scheme)

     with torch.enable_grad():
         optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
@@ -163,9 +211,6 @@ def get_point(datapoint):
         i = 0
         while epoch < max_epoch:
             while True:  # for x, y, w in dataloader
-                if running_loss_window is None:
-                    running_loss_window = x.shape[0] * len(dataloader)
-
                 y = y.view(x.shape[0], -1)
                 if w is not None:
                     w = w.view(x.shape[0], -1)
@@ -176,33 +221,20 @@ def get_point(datapoint):

                 loss = loss_fn(y, out, w)
                 if reg_term is not None:
-                    reg = torch.norm(model.linear.weight, p=reg_term)
+                    # pyre-fixme[16]: `Optional` has no attribute `weight`.
+                    reg = torch.norm(model.linear.weight, p=reg_term)  # type: ignore
                     loss += reg.sum() * alpha

-                if len(loss_window) >= running_loss_window:
-                    loss_window = loss_window[1:]
                 loss_window.append(loss.clone().detach())
-                assert len(loss_window) <= running_loss_window
-
-                average_loss = torch.mean(torch.stack(loss_window))
-                if min_avg_loss is not None:
-                    # if we haven't improved by at least `threshold`
-                    if average_loss > min_avg_loss or torch.isclose(
-                        min_avg_loss, average_loss, atol=threshold
-                    ):
-                        convergence_counter += 1
-                        if convergence_counter >= patience:
-                            converged = True
-                            break
-                    else:
-                        convergence_counter = 0
-                    if min_avg_loss is None or min_avg_loss >= average_loss:
-                        min_avg_loss = average_loss.clone()
+                average_loss = loss_window.average()
+                if converge_tracker.update(average_loss):
+                    break  # converged

                 if debug:
                     print(
-                        f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
-                        + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
+                        f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
+                        f"Aloss={average_loss}, "
+                        f"min_avg_loss={converge_tracker.min_avg_loss}"
                     )

                 loss.backward()
@@ -215,19 +247,19 @@ def get_point(datapoint):
                 temp = next(data_iter, None)
                 if temp is None:
                     break
-                x, y, w = get_point(temp)
+                x, y, w = _get_point(temp, device)

-            if converged:
+            if converge_tracker.converged:
                 break

             epoch += 1
             data_iter = iter(dataloader)
-            x, y, w = get_point(next(data_iter))
+            x, y, w = _get_point(next(data_iter), device)

     t2 = time.time()
     return {
         "train_time": t2 - t1,
-        "train_loss": torch.mean(torch.stack(loss_window)).item(),
+        "train_loss": loss_window.average().item(),
         "train_iter": i,
         "train_epoch": epoch,
     }
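End to end, the refactored loop is exercised roughly as below. The return dict keys match the hunk above; the keyword names are parameters the function body visibly closes over, but construct_kwargs is an assumption from the elided signature, so treat this as a sketch rather than the canonical call:

import torch
from torch.utils.data import DataLoader, TensorDataset

xs = torch.randn(64, 4)
ys = xs @ torch.randn(4, 1)
loader = DataLoader(TensorDataset(xs, ys), batch_size=16)

stats = sgd_train_linear_model(
    model,                # a LinearModel instance
    loader,
    construct_kwargs={},  # assumed: forwarded to model._construct_model_params
    max_epoch=50,
    initial_lr=0.01,
    patience=10,
    threshold=1e-4,
    init_scheme="zeros",
)
print(stats["train_loss"], stats["train_iter"], stats["train_epoch"])
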
@@ -303,7 +335,8 @@ def sklearn_train_linear_model(
     if not sklearn.__version__ >= "0.23.0":
         warnings.warn(
             "Must have sklearn version 0.23.0 or higher to use "
-            "sample_weight in Lasso regression."
+            "sample_weight in Lasso regression.",
+            stacklevel=1,
         )

     num_batches = 0
@@ -346,7 +379,8 @@ def sklearn_train_linear_model(
         warnings.warn(
             "Sample weight is not supported for the provided linear model!"
             " Trained model without weighting inputs. For Lasso, please"
-            " upgrade sklearn to a version >= 0.23.0.",
+            " upgrade sklearn to a version >= 0.23.0.",
+            stacklevel=1,
         )

     t2 = time.time()
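One more reviewer-style observation on the version gate above: sklearn.__version__ >= "0.23.0" compares strings lexicographically, so an old release such as "0.9.0" passes the check even though it predates 0.23.0. A sketch of a more robust comparison, assuming the packaging dependency is acceptable:

from packaging import version
import sklearn

assert "0.9.0" >= "0.23.0"                               # lexicographic: wrong order
assert version.parse("0.9.0") < version.parse("0.23.0")  # semantic: correct order

has_sample_weight = version.parse(sklearn.__version__) >= version.parse("0.23.0")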