
Commit e08c9e3

Replaced all 'no_grad()' instances with 'inference_mode()' (#4629)
1 parent fba4f42 commit e08c9e3


10 files changed (+10 lines, -10 lines)

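For context on the change itself: torch.inference_mode() disables gradient recording just like torch.no_grad(), but it additionally skips view tracking and version-counter bumps, which makes it slightly faster for pure evaluation code; tensors created under it are "inference tensors" and cannot be used later in autograd-recorded operations. Below is a minimal standalone sketch of the difference (the toy nn.Linear model and the evaluate_batch helper are illustrative assumptions, not part of this commit):

# Standalone sketch, not from the commit: a toy model stands in for the
# reference scripts' real models.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(8, 4)

# torch.no_grad(): gradients are not recorded, but the outputs are ordinary
# tensors that autograd can still track in later operations.
with torch.no_grad():
    y_no_grad = model(x)
print(y_no_grad.requires_grad)   # False
print(y_no_grad.is_inference())  # False

# torch.inference_mode(): also skips view tracking and version-counter bumps;
# the outputs are inference tensors and raise an error if fed into
# autograd-recorded computations afterwards.
with torch.inference_mode():
    y_inf = model(x)
print(y_inf.is_inference())      # True

# Both forms also work as decorators, matching the @torch.inference_mode()
# usage in references/detection/engine.py and references/similarity/train.py.
@torch.inference_mode()
def evaluate_batch(batch):
    return model(batch)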

references/classification/train.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
     header = f"Test: {log_suffix}"
 
     num_processed_samples = 0
-    with torch.no_grad():
+    with torch.inference_mode():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)

references/classification/train_quantization.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def main(args):
         print("Starting training for epoch", epoch)
         train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
         lr_scheduler.step()
-        with torch.no_grad():
+        with torch.inference_mode():
             if epoch >= args.num_observer_update_epochs:
                 print("Disabling observer for subseq epochs, epoch = ", epoch)
                 model.apply(torch.quantization.disable_observer)

references/classification/utils.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def update_parameters(self, model):
 
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
-    with torch.no_grad():
+    with torch.inference_mode():
         maxk = max(topk)
         batch_size = target.size(0)
         if target.ndim == 2:

references/detection/engine.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def _get_iou_types(model):
     return iou_types
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def evaluate(model, data_loader, device):
     n_threads = torch.get_num_threads()
     # FIXME remove this and make paste_masks_in_image run on the GPU

references/detection/utils.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def reduce_dict(input_dict, average=True):
     world_size = get_world_size()
     if world_size < 2:
         return input_dict
-    with torch.no_grad():
+    with torch.inference_mode():
         names = []
         values = []
         # sort the keys so that they are consistent across processes

references/segmentation/train.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def evaluate(model, data_loader, device, num_classes):
     confmat = utils.ConfusionMatrix(num_classes)
     metric_logger = utils.MetricLogger(delimiter=" ")
     header = "Test:"
-    with torch.no_grad():
+    with torch.inference_mode():
         for image, target in metric_logger.log_every(data_loader, 100, header):
             image, target = image.to(device), target.to(device)
             output = model(image)

references/segmentation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def update(self, a, b):
         n = self.num_classes
         if self.mat is None:
             self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
-        with torch.no_grad():
+        with torch.inference_mode():
             k = (a >= 0) & (a < n)
             inds = n * a[k].to(torch.int64) + b[k]
             self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

references/similarity/train.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def find_best_threshold(dists, targets, device):
     return best_thresh, accuracy
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def evaluate(model, loader, device):
     model.eval()
     embeds, labels = [], []

references/video_classification/train.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def evaluate(model, criterion, data_loader, device):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter=" ")
     header = "Test:"
-    with torch.no_grad():
+    with torch.inference_mode():
         for video, target in metric_logger.log_every(data_loader, 100, header):
             video = video.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)

references/video_classification/utils.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def log_every(self, iterable, print_freq, header=None):
 
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
-    with torch.no_grad():
+    with torch.inference_mode():
         maxk = max(topk)
         batch_size = target.size(0)

0 commit comments
