Skip to content

Commit 395e6cd

Browse files
craymichael authored and facebook-github-bot committed
Reduce complexity of 'sgd_train_linear_model' (#1374)
Summary: Reduce complexity of 'sgd_train_linear_model' Differential Revision: D64432524
1 parent 94141d6 commit 395e6cd

File tree

1 file changed

+102
-68
lines changed
  • captum/_utils/models/linear_model

1 file changed

+102
-68
lines changed

captum/_utils/models/linear_model/train.py

Lines changed: 102 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# pyre-strict
22
import time
33
import warnings
4-
from typing import Any, Callable, Dict, List, Optional
4+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
55

66
import torch
77
import torch.nn as nn
@@ -17,6 +17,82 @@ def l2_loss(x1, x2, weights=None) -> torch.Tensor:
1717
return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0
1818

1919

20+
class ConvergenceTracker:
    """Tracks the best running average loss and detects convergence.

    Convergence is declared after ``patience`` consecutive updates in which
    the average loss fails to improve on the best value seen so far by more
    than ``threshold``.
    """

    def __init__(self, patience: int, threshold: float) -> None:
        # Lowest average loss observed so far; None until the first update.
        self.min_avg_loss: Optional[torch.Tensor] = None
        # Count of consecutive non-improving updates.
        self.convergence_counter: int = 0
        self.converged: bool = False

        self.threshold = threshold
        self.patience = patience

    def update(self, average_loss: torch.Tensor) -> bool:
        """Record a new average loss; return True once convergence is reached."""
        best = self.min_avg_loss
        if best is not None:
            # No improvement means the loss got worse, or stayed within
            # `threshold` of the best value seen so far.
            no_improvement = average_loss > best or torch.isclose(
                best, average_loss, atol=self.threshold
            )
            if no_improvement:
                self.convergence_counter += 1
                if self.convergence_counter >= self.patience:
                    self.converged = True
                    return True
            else:
                self.convergence_counter = 0

        # Keep a detached snapshot of the best average loss.
        if best is None or best >= average_loss:
            self.min_avg_loss = average_loss.clone()
        return False
44+
45+
46+
class LossWindow:
    """Fixed-size sliding window over recent loss values.

    Holds at most ``window_size`` loss tensors; ``average`` returns the mean
    of the retained losses.
    """

    def __init__(self, window_size: int) -> None:
        # Retained losses, oldest first.
        self.loss_window: List[torch.Tensor] = []
        self.window_size = window_size

    def append(self, loss: torch.Tensor) -> None:
        """Add ``loss``, evicting the oldest entries beyond ``window_size``.

        Append first, then trim: trimming to the last ``window_size`` entries
        *before* appending would leave ``window_size + 1`` elements in the
        window, skewing the average.
        """
        self.loss_window.append(loss)
        if len(self.loss_window) > self.window_size:
            del self.loss_window[: len(self.loss_window) - self.window_size]

    def average(self) -> torch.Tensor:
        """Mean of the currently retained losses."""
        return torch.mean(torch.stack(self.loss_window))
58+
59+
60+
def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None:
    """Initialize the weights (and bias) of ``model.linear`` in place.

    Args:
        model: a LinearModel whose ``linear`` layer is already constructed.
        init_scheme: ``"xavier"`` for Xavier-uniform weights, ``"zeros"`` to
            zero the weights, or None to leave the parameters untouched.
    """
    assert model.linear is not None
    if init_scheme is None:
        # No initialization requested; keep the layer's existing parameters.
        return

    assert init_scheme in ["xavier", "zeros"]

    with torch.no_grad():
        if init_scheme == "xavier":
            # pyre-fixme[16]: `Optional` has no attribute `weight`.
            torch.nn.init.xavier_uniform_(model.linear.weight)
        else:
            model.linear.weight.zero_()

        # pyre-fixme[16]: `Optional` has no attribute `bias`.
        if model.linear.bias is not None:
            model.linear.bias.zero_()
75+
76+
77+
def _get_point(
78+
datapoint: Tuple[torch.Tensor, ...],
79+
device: Optional[str] = None,
80+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
81+
if len(datapoint) == 2:
82+
x, y = datapoint
83+
w = None
84+
else:
85+
x, y, w = datapoint
86+
87+
if device is not None:
88+
x = x.to(device)
89+
y = y.to(device)
90+
if w is not None:
91+
w = w.to(device)
92+
93+
return x, y, w
94+
95+
2096
def sgd_train_linear_model(
2197
model: LinearModel,
2298
dataloader: DataLoader,
@@ -102,31 +178,16 @@ def sgd_train_linear_model(
102178
This will return the final training loss (averaged with
103179
`running_loss_window`)
104180
"""
105-
loss_window: List[torch.Tensor] = []
106-
min_avg_loss = None
107-
convergence_counter = 0
108-
converged = False
109-
110-
# pyre-fixme[3]: Return type must be annotated.
111-
# pyre-fixme[2]: Parameter must be annotated.
112-
def get_point(datapoint):
113-
if len(datapoint) == 2:
114-
x, y = datapoint
115-
w = None
116-
else:
117-
x, y, w = datapoint
118-
119-
if device is not None:
120-
x = x.to(device)
121-
y = y.to(device)
122-
if w is not None:
123-
w = w.to(device)
124-
125-
return x, y, w
181+
converge_tracker = ConvergenceTracker(patience, threshold)
126182

127183
# get a point and construct the model
128184
data_iter = iter(dataloader)
129-
x, y, w = get_point(next(data_iter))
185+
x, y, w = _get_point(next(data_iter), device)
186+
187+
if running_loss_window is None:
188+
running_loss_window = x.shape[0] * len(dataloader)
189+
190+
loss_window = LossWindow(running_loss_window)
130191

131192
model._construct_model_params(
132193
in_features=x.shape[1],
@@ -135,21 +196,8 @@ def get_point(datapoint):
135196
)
136197
model.train()
137198

138-
assert model.linear is not None
139-
140-
if init_scheme is not None:
141-
assert init_scheme in ["xavier", "zeros"]
142-
143-
with torch.no_grad():
144-
if init_scheme == "xavier":
145-
# pyre-fixme[16]: `Optional` has no attribute `weight`.
146-
torch.nn.init.xavier_uniform_(model.linear.weight)
147-
else:
148-
model.linear.weight.zero_()
149-
150-
# pyre-fixme[16]: `Optional` has no attribute `bias`.
151-
if model.linear.bias is not None:
152-
model.linear.bias.zero_()
199+
# Initialize linear model weights if applicable
200+
_init_linear_model(model, init_scheme)
153201

154202
with torch.enable_grad():
155203
optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
@@ -163,9 +211,6 @@ def get_point(datapoint):
163211
i = 0
164212
while epoch < max_epoch:
165213
while True: # for x, y, w in dataloader
166-
if running_loss_window is None:
167-
running_loss_window = x.shape[0] * len(dataloader)
168-
169214
y = y.view(x.shape[0], -1)
170215
if w is not None:
171216
w = w.view(x.shape[0], -1)
@@ -176,33 +221,20 @@ def get_point(datapoint):
176221

177222
loss = loss_fn(y, out, w)
178223
if reg_term is not None:
179-
reg = torch.norm(model.linear.weight, p=reg_term)
224+
# pyre-fixme[16]: `Optional` has no attribute `weight`.
225+
reg = torch.norm(model.linear.weight, p=reg_term) # type: ignore
180226
loss += reg.sum() * alpha
181227

182-
if len(loss_window) >= running_loss_window:
183-
loss_window = loss_window[1:]
184228
loss_window.append(loss.clone().detach())
185-
assert len(loss_window) <= running_loss_window
186-
187-
average_loss = torch.mean(torch.stack(loss_window))
188-
if min_avg_loss is not None:
189-
# if we haven't improved by at least `threshold`
190-
if average_loss > min_avg_loss or torch.isclose(
191-
min_avg_loss, average_loss, atol=threshold
192-
):
193-
convergence_counter += 1
194-
if convergence_counter >= patience:
195-
converged = True
196-
break
197-
else:
198-
convergence_counter = 0
199-
if min_avg_loss is None or min_avg_loss >= average_loss:
200-
min_avg_loss = average_loss.clone()
229+
average_loss = loss_window.average()
230+
if converge_tracker.update(average_loss):
231+
break # converged
201232

202233
if debug:
203234
print(
204-
f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
205-
+ "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
235+
f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
236+
f"Aloss={average_loss}, "
237+
f"min_avg_loss={converge_tracker.min_avg_loss}"
206238
)
207239

208240
loss.backward()
@@ -215,19 +247,19 @@ def get_point(datapoint):
215247
temp = next(data_iter, None)
216248
if temp is None:
217249
break
218-
x, y, w = get_point(temp)
250+
x, y, w = _get_point(temp, device)
219251

220-
if converged:
252+
if converge_tracker.converged:
221253
break
222254

223255
epoch += 1
224256
data_iter = iter(dataloader)
225-
x, y, w = get_point(next(data_iter))
257+
x, y, w = _get_point(next(data_iter), device)
226258

227259
t2 = time.time()
228260
return {
229261
"train_time": t2 - t1,
230-
"train_loss": torch.mean(torch.stack(loss_window)).item(),
262+
"train_loss": loss_window.average().item(),
231263
"train_iter": i,
232264
"train_epoch": epoch,
233265
}
@@ -303,7 +335,8 @@ def sklearn_train_linear_model(
303335
if not sklearn.__version__ >= "0.23.0":
304336
warnings.warn(
305337
"Must have sklearn version 0.23.0 or higher to use "
306-
"sample_weight in Lasso regression."
338+
"sample_weight in Lasso regression.",
339+
stacklevel=1,
307340
)
308341

309342
num_batches = 0
@@ -346,7 +379,8 @@ def sklearn_train_linear_model(
346379
warnings.warn(
347380
"Sample weight is not supported for the provided linear model!"
348381
" Trained model without weighting inputs. For Lasso, please"
349-
" upgrade sklearn to a version >= 0.23.0."
382+
" upgrade sklearn to a version >= 0.23.0.",
383+
stacklevel=1,
350384
)
351385

352386
t2 = time.time()

0 commit comments

Comments
 (0)