
Commit 61f8266

jdsgomes and datumbox authored
reverting some recently introduced exceptions (#5659)
* reverting some recently introduced exceptions
* Update torchvision/ops/poolers.py
  Co-authored-by: Vasilis Vryniotis <[email protected]>
* address PR comments
* replace one more assert with torch._assert
* address PR comments
* make type checker happy
* Fix bug
* fix bug
* fix for wrong asserts
* attempt to make tests pass
* Fix test_ops tests
* Fix expected exception in tests
* fix typo
* fix tests and format
* fix flake8
* remove one last exception
* fix error
* remove unused import
* replace fake returns by else

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 1db8795 commit 61f8266

12 files changed: 138 additions, 126 deletions
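
The pattern throughout this commit: input-validation raise ValueError / raise TypeError statements in the detection code are replaced with torch._assert, PyTorch's traceable wrapper around Python's assert. A failed check therefore surfaces as a plain AssertionError, which is why the test updates below swap the expected exception types. A minimal sketch of the behavior (not part of the diff):

import torch

torch._assert(1 + 1 == 2, "a true condition passes silently")
try:
    torch._assert(False, "targets should not be none when in training mode")
except AssertionError as e:
    print(e)  # torch._assert raises the built-in AssertionError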

test/test_models.py

Lines changed: 4 additions & 4 deletions
@@ -745,24 +745,24 @@ def test_detection_model_validation(model_fn):
     x = [torch.rand(input_shape)]
 
     # validate that targets are present in training
-    with pytest.raises(ValueError):
+    with pytest.raises(AssertionError):
         model(x)
 
     # validate type
     targets = [{"boxes": 0.0}]
-    with pytest.raises(TypeError):
+    with pytest.raises(AssertionError):
         model(x, targets=targets)
 
     # validate boxes shape
     for boxes in (torch.rand((4,)), torch.rand((1, 5))):
         targets = [{"boxes": boxes}]
-        with pytest.raises(ValueError):
+        with pytest.raises(AssertionError):
             model(x, targets=targets)
 
     # validate that no degenerate boxes are present
     boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
     targets = [{"boxes": boxes}]
-    with pytest.raises(ValueError):
+    with pytest.raises(AssertionError):
         model(x, targets=targets)
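
Since torch._assert raises AssertionError, the tests switch from pytest.raises(ValueError) / pytest.raises(TypeError) to pytest.raises(AssertionError). The pattern in miniature (validate_targets is a hypothetical stand-in, not torchvision API):

import pytest
import torch


def validate_targets(targets):
    # hypothetical helper mirroring the validation style introduced by this commit
    torch._assert(targets is not None, "targets should not be none when in training mode")


def test_missing_targets():
    with pytest.raises(AssertionError):
        validate_targets(None)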

test/test_models_detection_anchor_utils.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def test_incorrect_anchors(self):
         image1 = torch.randn(3, 800, 800)
         image_list = ImageList(image1, [(800, 800)])
         feature_maps = [torch.randn(1, 50)]
-        pytest.raises(ValueError, anc, image_list, feature_maps)
+        pytest.raises(AssertionError, anc, image_list, feature_maps)
 
     def _init_test_anchor_generator(self):
         anchor_sizes = ((10,),)

test/test_ops.py

Lines changed: 2 additions & 2 deletions
@@ -138,13 +138,13 @@ def test_autocast(self, x_dtype, rois_dtype):
 
     def _helper_boxes_shape(self, func):
         # test boxes as Tensor[N, 5]
-        with pytest.raises(ValueError):
+        with pytest.raises(AssertionError):
             a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
             boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
             func(a, boxes, output_size=(2, 2))
 
         # test boxes as List[Tensor[N, 4]]
-        with pytest.raises(ValueError):
+        with pytest.raises(AssertionError):
             a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
             boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
             ops.roi_pool(a, [boxes], output_size=(2, 2))
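
For context, ops.roi_pool and the other ROI ops exercised by _helper_boxes_shape accept boxes either as a single Tensor[K, 5] whose first column is the batch index, or as List[Tensor[N, 4]]; the updated tests check that malformed shapes now fail with AssertionError. A quick sketch of the two valid forms (shapes illustrative):

import torch
from torchvision import ops

a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes_tensor = torch.tensor([[0, 0, 0, 3, 3]], dtype=a.dtype)  # Tensor[K, 5]: (batch_idx, x1, y1, x2, y2)
boxes_list = [torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)]     # List[Tensor[N, 4]], one entry per image

out_a = ops.roi_pool(a, boxes_tensor, output_size=(2, 2))
out_b = ops.roi_pool(a, boxes_list, output_size=(2, 2))
assert torch.equal(out_a, out_b)  # both forms describe the same ROI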

torchvision/models/detection/_utils.py

Lines changed: 12 additions & 8 deletions
@@ -159,10 +159,14 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
         return targets
 
     def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
-        if not isinstance(boxes, (list, tuple)):
-            raise TypeError(f"This function expects boxes of type list or tuple, instead got {type(boxes)}")
-        if not isinstance(rel_codes, torch.Tensor):
-            raise TypeError(f"This function expects rel_codes of type torch.Tensor, instead got {type(rel_codes)}")
+        torch._assert(
+            isinstance(boxes, (list, tuple)),
+            "This function expects boxes of type list or tuple.",
+        )
+        torch._assert(
+            isinstance(rel_codes, torch.Tensor),
+            "This function expects rel_codes of type torch.Tensor.",
+        )
         boxes_per_image = [b.size(0) for b in boxes]
         concat_boxes = torch.cat(boxes, dim=0)
         box_sum = 0

@@ -335,8 +339,7 @@ def __init__(self, high_threshold: float, low_threshold: float, allow_low_qualit
         """
         self.BELOW_LOW_THRESHOLD = -1
         self.BETWEEN_THRESHOLDS = -2
-        if low_threshold > high_threshold:
-            raise ValueError("low_threshold should be <= high_threshold")
+        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
         self.high_threshold = high_threshold
         self.low_threshold = low_threshold
         self.allow_low_quality_matches = allow_low_quality_matches

@@ -375,8 +378,9 @@ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
 
         if self.allow_low_quality_matches:
             if all_matches is None:
-                raise ValueError("all_matches should not be None")
-            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+                torch._assert(False, "all_matches should not be None")
+            else:
+                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
 
         return matches
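
Note how the raise-to-torch._assert(False, ...) rewrites pair the failing branch with an explicit else. Unlike raise, torch._assert is an ordinary function call, so type checkers and TorchScript cannot infer that control flow stops there; moving the follow-up work into else keeps the Optional narrowed on the happy path (this matches the commit's "make type checker happy" and "replace fake returns by else" notes). A minimal sketch of the pattern, with a hypothetical maybe_value:

from typing import Optional

import torch


def double(maybe_value: Optional[torch.Tensor]) -> torch.Tensor:
    out = torch.zeros(1)
    if maybe_value is None:
        # an ordinary call, so checkers don't treat this branch as terminating
        torch._assert(False, "maybe_value should not be None")
    else:
        out = maybe_value * 2  # the else keeps maybe_value narrowed to Tensor
    return out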

torchvision/models/detection/anchor_utils.py

Lines changed: 8 additions & 11 deletions
@@ -84,17 +84,14 @@ def num_anchors_per_location(self):
     def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         anchors = []
         cell_anchors = self.cell_anchors
-
-        if cell_anchors is None:
-            ValueError("cell_anchors should not be None")
-
-        if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
-            raise ValueError(
-                "Anchors should be Tuple[Tuple[int]] because each feature "
-                "map could potentially have different sizes and aspect ratios. "
-                "There needs to be a match between the number of "
-                "feature maps passed and the number of sizes / aspect ratios specified."
-            )
+        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
+        torch._assert(
+            len(grid_sizes) == len(strides) == len(cell_anchors),
+            "Anchors should be Tuple[Tuple[int]] because each feature "
+            "map could potentially have different sizes and aspect ratios. "
+            "There needs to be a match between the number of "
+            "feature maps passed and the number of sizes / aspect ratios specified.",
+        )
 
         for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
             grid_height, grid_width = size
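
A side effect worth flagging: the removed code built ValueError("cell_anchors should not be None") without raise, so the None check was silently a no-op; the torch._assert form actually fires. The before/after in miniature:

import torch

cell_anchors = None

# Before (buggy): the exception object is created and immediately discarded.
if cell_anchors is None:
    ValueError("cell_anchors should not be None")  # missing `raise`, nothing happens

# After: the check really fails.
try:
    torch._assert(cell_anchors is not None, "cell_anchors should not be None")
except AssertionError as e:
    print(e)  # -> cell_anchors should not be None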

torchvision/models/detection/faster_rcnn.py

Lines changed: 5 additions & 4 deletions
@@ -1,5 +1,6 @@
 from typing import Any, Optional, Union
 
+import torch
 import torch.nn.functional as F
 from torch import nn
 from torchvision.ops import MultiScaleRoIAlign

@@ -313,10 +314,10 @@ def __init__(self, in_channels, num_classes):
 
     def forward(self, x):
         if x.dim() == 4:
-            if list(x.shape[2:]) != [1, 1]:
-                raise ValueError(
-                    f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}"
-                )
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
         x = x.flatten(start_dim=1)
         scores = self.cls_score(x)
         bbox_deltas = self.bbox_pred(x)

torchvision/models/detection/fcos.py

Lines changed: 21 additions & 19 deletions
@@ -565,23 +565,25 @@ def forward(
             like `scores`, `labels` and `mask` (for Mask R-CNN models).
         """
         if self.training:
+
             if targets is None:
-                raise ValueError("In training mode, targets should be passed")
-            for target in targets:
-                boxes = target["boxes"]
-                if isinstance(boxes, torch.Tensor):
-                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
-                        raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
-                else:
-                    raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
+                    torch._assert(
+                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
+                    )
 
         original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
-            if len(val) != 2:
-                raise ValueError(
-                    f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}"
-                )
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
             original_image_sizes.append((val[0], val[1]))
 
         # transform the input

@@ -596,9 +598,9 @@ def forward(
                     # print the first degenerate box
                     bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                     degen_bb: List[float] = boxes[bb_idx].tolist()
-                    raise ValueError(
-                        "All bounding boxes should have positive height and width."
-                        f" Found invalid box {degen_bb} for target at index {target_idx}."
+                    torch._assert(
+                        False,
+                        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
                     )
 
         # get the features from the backbone

@@ -619,11 +621,11 @@ def forward(
         losses = {}
         detections: List[Dict[str, Tensor]] = []
         if self.training:
-            # compute the losses
             if targets is None:
-                raise ValueError("targets should not be none when in training mode")
-
-            losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                # compute the losses
+                losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
         else:
             # split outputs per level
             split_head_outputs: Dict[str, List[Tensor]] = {}
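
FCOS, GeneralizedRCNN and RetinaNet share this degenerate-box guard, now reporting the first offending box through torch._assert(False, ...). A self-contained sketch of the check, assuming the usual [x1, y1, x2, y2] box layout (the degenerate_boxes computation is reproduced from surrounding context, not shown in this diff):

import torch

boxes = torch.tensor([[1.0, 3.0, 1.0, 4.0], [2.0, 4.0, 3.0, 5.0]])  # first box has zero width
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]  # x2 <= x1 or y2 <= y1
if degenerate_boxes.any():
    # print the first degenerate box, then fail
    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
    degen_bb = boxes[bb_idx].tolist()
    torch._assert(
        False,
        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb}.",
    )  # raises AssertionError here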

torchvision/models/detection/generalized_rcnn.py

Lines changed: 18 additions & 15 deletions
@@ -59,23 +59,25 @@ def forward(self, images, targets=None):
         """
         if self.training:
             if targets is None:
-                raise ValueError("In training mode, targets should be passed")
-
-            for target in targets:
-                boxes = target["boxes"]
-                if isinstance(boxes, torch.Tensor):
-                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
-                        raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
-                else:
-                    raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    if isinstance(boxes, torch.Tensor):
+                        torch._assert(
+                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
+                        )
+                    else:
+                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
 
         original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
-            if len(val) != 2:
-                raise ValueError(
-                    f"Expecting the last two dimensions of the input tensor to be H and W, instead got {img.shape[-2:]}"
-                )
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
             original_image_sizes.append((val[0], val[1]))
 
         images, targets = self.transform(images, targets)

@@ -90,9 +92,10 @@ def forward(self, images, targets=None):
                     # print the first degenerate box
                     bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                     degen_bb: List[float] = boxes[bb_idx].tolist()
-                    raise ValueError(
+                    torch._assert(
+                        False,
                         "All bounding boxes should have positive height and width."
-                        f" Found invalid box {degen_bb} for target at index {target_idx}."
+                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                     )
 
         features = self.backbone(images.tensors)
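
End to end, a detection model called in training mode without targets now fails with AssertionError instead of ValueError. A usage sketch (model choice illustrative):

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
model.train()
images = [torch.rand(3, 224, 224)]
try:
    model(images)  # training mode, but no targets supplied
except AssertionError as e:
    print(f"validation failed: {e}")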

torchvision/models/detection/retinanet.py

Lines changed: 20 additions & 20 deletions
@@ -494,28 +494,26 @@ def forward(self, images, targets=None):
             like `scores`, `labels` and `mask` (for Mask R-CNN models).
 
         """
-        if self.training and targets is None:
-            raise ValueError("In training mode, targets should be passed")
-
         if self.training:
             if targets is None:
-                raise ValueError("In training mode, targets should be passed")
-            for target in targets:
-                boxes = target["boxes"]
-                if isinstance(boxes, torch.Tensor):
-                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
-                        raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
-                else:
-                    raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
+                    torch._assert(
+                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                        "Expected target boxes to be a tensor of shape [N, 4].",
+                    )
 
         # get the original image sizes
         original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
-            if len(val) != 2:
-                raise ValueError(
-                    f"Expecting the two last elements of the input tensors to be H and W instead got {img.shape[-2:]}"
-                )
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
             original_image_sizes.append((val[0], val[1]))
 
         # transform the input

@@ -531,9 +529,10 @@ def forward(self, images, targets=None):
                     # print the first degenerate box
                     bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                     degen_bb: List[float] = boxes[bb_idx].tolist()
-                    raise ValueError(
+                    torch._assert(
+                        False,
                         "All bounding boxes should have positive height and width."
-                        f" Found invalid box {degen_bb} for target at index {target_idx}."
+                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                     )
 
         # get the features from the backbone

@@ -554,9 +553,10 @@ def forward(self, images, targets=None):
         detections: List[Dict[str, Tensor]] = []
         if self.training:
             if targets is None:
-                raise ValueError("In training mode, targets should be passed")
-            # compute the losses
-            losses = self.compute_loss(targets, head_outputs, anchors)
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                # compute the losses
+                losses = self.compute_loss(targets, head_outputs, anchors)
         else:
             # recover level sizes
             num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
