Skip to content

Replace all no_grad() instances with inference_mode() in reference scripts #4629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
header = f"Test: {log_suffix}"

num_processed_samples = 0
with torch.no_grad():
with torch.inference_mode():
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
Expand Down
2 changes: 1 addition & 1 deletion references/classification/train_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main(args):
print("Starting training for epoch", epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
lr_scheduler.step()
with torch.no_grad():
with torch.inference_mode():
if epoch >= args.num_observer_update_epochs:
print("Disabling observer for subseq epochs, epoch = ", epoch)
model.apply(torch.quantization.disable_observer)
Expand Down
2 changes: 1 addition & 1 deletion references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def update_parameters(self, model):

def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
with torch.inference_mode():
maxk = max(topk)
batch_size = target.size(0)
if target.ndim == 2:
Expand Down
2 changes: 1 addition & 1 deletion references/detection/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _get_iou_types(model):
return iou_types


@torch.no_grad()
@torch.inference_mode()
def evaluate(model, data_loader, device):
n_threads = torch.get_num_threads()
# FIXME remove this and make paste_masks_in_image run on the GPU
Expand Down
2 changes: 1 addition & 1 deletion references/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def reduce_dict(input_dict, average=True):
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
with torch.inference_mode():
names = []
values = []
# sort the keys so that they are consistent across processes
Expand Down
2 changes: 1 addition & 1 deletion references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def evaluate(model, data_loader, device, num_classes):
confmat = utils.ConfusionMatrix(num_classes)
metric_logger = utils.MetricLogger(delimiter=" ")
header = "Test:"
with torch.no_grad():
with torch.inference_mode():
for image, target in metric_logger.log_every(data_loader, 100, header):
image, target = image.to(device), target.to(device)
output = model(image)
Expand Down
2 changes: 1 addition & 1 deletion references/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.no_grad():
with torch.inference_mode():
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
Expand Down
2 changes: 1 addition & 1 deletion references/similarity/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def find_best_threshold(dists, targets, device):
return best_thresh, accuracy


@torch.no_grad()
@torch.inference_mode()
def evaluate(model, loader, device):
model.eval()
embeds, labels = [], []
Expand Down
2 changes: 1 addition & 1 deletion references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = "Test:"
with torch.no_grad():
with torch.inference_mode():
for video, target in metric_logger.log_every(data_loader, 100, header):
video = video.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
Expand Down
2 changes: 1 addition & 1 deletion references/video_classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def log_every(self, iterable, print_freq, header=None):

def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
with torch.inference_mode():
maxk = max(topk)
batch_size = target.size(0)

Expand Down