Commit d425f00 (1 parent: cc53cd0)

Start doc revamp for detection models (#5876)

* Start doc revamp for detection models
* Minor cleanup
* Use list of tuples for metrics

File tree: 4 files changed (+76, -19 lines)

docs/source/conf.py (13 additions, 10 deletions)

@@ -22,6 +22,7 @@
 
 import os
 import textwrap
+from copy import copy
 from pathlib import Path
 
 import pytorch_sphinx_theme

@@ -330,7 +331,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             # the `meta` dict contains another embedded `metrics` dict. To
             # simplify the table generation below, we create the
             # `meta_with_metrics` dict, where the metrics dict has been "flattened"
-            meta = field.meta
+            meta = copy(field.meta)
             metrics = meta.pop("metrics", {})
             meta_with_metrics = dict(meta, **metrics)
 

@@ -346,17 +347,18 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             lines.append("")
 
 
-def generate_classification_table():
-
-    weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("_Weights")]
+def generate_weights_table(module, table_name, metrics):
+    weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
     weights = [w for weight_enum in weight_enums for w in weight_enum]
 
-    column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
+    metrics_keys, metrics_names = zip(*metrics)
+    column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
+    column_names = [f"**{name}**" for name in column_names]  # Add bold
+
     content = [
         (
             f":class:`{w} <{type(w).__name__}>`",
-            w.meta["metrics"]["acc@1"],
-            w.meta["metrics"]["acc@5"],
+            *(w.meta["metrics"][metric] for metric in metrics_keys),
            f"{w.meta['num_params']/1e6:.1f}M",
            f"`link <{w.meta['recipe']}>`__",
        )

@@ -366,13 +368,14 @@ def generate_classification_table():
 
     generated_dir = Path("generated")
     generated_dir.mkdir(exist_ok=True)
-    with open(generated_dir / "classification_table.rst", "w+") as table_file:
+    with open(generated_dir / f"{table_name}_table.rst", "w+") as table_file:
         table_file.write(".. table::\n")
-        table_file.write("    :widths: 100 10 10 20 10\n\n")
+        table_file.write(f"    :widths: 100 {'20 ' * len(metrics_names)} 20 10\n\n")
         table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
 
 
-generate_classification_table()
+generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
+generate_weights_table(module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")])
 
 
 def setup(app):
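
To make the refactor above easier to follow, here is a hedged, standalone sketch of the same table-building logic, runnable outside a Sphinx build. `preview_weights_table` is a hypothetical helper name; the sketch assumes the torchvision release this commit targets, where every `*_Weights` enum member carries a `meta` dict with `"metrics"`, `"num_params"` and `"recipe"` keys, and it simply prints rows instead of writing an RST file.

```python
# Minimal sketch, not the Sphinx hook itself: it mirrors the
# generate_weights_table() signature added in conf.py but only prints rows.
import torchvision.models as M


def preview_weights_table(module, metrics):
    # `metrics` is a list of (meta key, column name) tuples, as in the diff.
    weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
    weights = [w for weight_enum in weight_enums for w in weight_enum]

    metrics_keys, metrics_names = zip(*metrics)
    print(["Weight", *metrics_names, "Params", "Recipe"])
    for w in weights:
        print([
            str(w),
            *(w.meta["metrics"][key] for key in metrics_keys),
            f"{w.meta['num_params'] / 1e6:.1f}M",
            w.meta["recipe"],
        ])


# Same call pattern as the new lines at the bottom of conf.py:
preview_weights_table(M.detection, metrics=[("box_map", "Box MAP")])
```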

docs/source/models/retinanet.rst (new file, 23 additions)

@@ -0,0 +1,23 @@
+RetinaNet
+=========
+
+.. currentmodule:: torchvision.models.detection
+
+The RetinaNet model is based on the `Focal Loss for Dense Object Detection
+<https://arxiv.org/abs/1708.02002>`__ paper.
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate a RetinaNet model, with or
+without pre-trained weights. All the model builders internally rely on the
+``torchvision.models.detection.retinanet.RetinaNet`` base class. Please refer to the `source code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_ for
+more details about this class.
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    retinanet_resnet50_fpn
+    retinanet_resnet50_fpn_v2
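
Not part of the diff, but for context, a short usage sketch of the two builders listed above; it assumes torchvision >= 0.13, where these names are importable from `torchvision.models.detection`:

```python
import torch
from torchvision.models.detection import (
    RetinaNet_ResNet50_FPN_Weights,
    retinanet_resnet50_fpn,
)

weights = RetinaNet_ResNet50_FPN_Weights.DEFAULT  # best available COCO weights
model = retinanet_resnet50_fpn(weights=weights)
model.eval()

# Inputs are a list of CHW tensors in the 0-1 range; sizes may differ per image.
images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
with torch.no_grad():
    predictions = model(images)  # list of dicts with "boxes", "scores", "labels"
```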

docs/source/models_new.rst (16 additions, 1 deletion)

@@ -59,4 +59,19 @@ Accuracies are reported on ImageNet
 Object Detection, Instance Segmentation and Person Keypoint Detection
 =====================================================================
 
-TODO: Something similar to classification models: list of models + table of weights
+.. currentmodule:: torchvision.models.detection
+
+The following detection models are available, with or without pre-trained
+weights:
+
+.. toctree::
+    :maxdepth: 1
+
+    models/retinanet
+
+Table of all available detection weights
+----------------------------------------
+
+Box MAPs are reported on COCO
+
+.. include:: generated/detection_table.rst
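
As a hedged illustration of where the generated table's numbers come from (assuming the `meta["metrics"]` layout used elsewhere in this commit), each Box MAP cell is read straight from the corresponding weights entry:

```python
from torchvision.models.detection import RetinaNet_ResNet50_FPN_Weights

w = RetinaNet_ResNet50_FPN_Weights.COCO_V1
print(w.meta["metrics"]["box_map"])  # the value rendered in the "Box MAP" column
print(w.meta["num_params"], w.meta["recipe"])  # the "Params" and "Recipe" columns
```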

torchvision/models/detection/retinanet.py (24 additions, 8 deletions)

@@ -727,7 +727,7 @@ def retinanet_resnet50_fpn(
     """
     Constructs a RetinaNet model with a ResNet-50-FPN backbone.
 
-    Reference: `"Focal Loss for Dense Object Detection" <https://arxiv.org/abs/1708.02002>`_.
+    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
 
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.

@@ -763,13 +763,21 @@ def retinanet_resnet50_fpn(
         >>> predictions = model(x)
 
     Args:
-        weights (RetinaNet_ResNet50_FPN_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
         num_classes (int, optional): number of output classes of the model (including the background)
-        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
         trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
             passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
+        :members:
     """
     weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
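
The expanded `weights` documentation above maps onto a workflow like the following hedged sketch (torchvision >= 0.13 assumed; `dog.jpg` is a hypothetical local image, and `weights.transforms()` / `weights.meta["categories"]` are the bundled preprocessing and label names of the multi-weight API):

```python
import torch
from torchvision.io import read_image
from torchvision.models.detection import (
    RetinaNet_ResNet50_FPN_Weights,
    retinanet_resnet50_fpn,
)

weights = RetinaNet_ResNet50_FPN_Weights.COCO_V1
model = retinanet_resnet50_fpn(weights=weights).eval()
preprocess = weights.transforms()  # inference transforms bundled with the weights

img = read_image("dog.jpg")  # hypothetical uint8 CHW image
with torch.no_grad():
    prediction = model([preprocess(img)])[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
```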
@@ -811,19 +819,27 @@ def retinanet_resnet50_fpn_v2(
     """
     Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
 
-    Reference: `"Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection"
+    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
     <https://arxiv.org/abs/1912.02424>`_.
 
     :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
 
     Args:
-        weights (RetinaNet_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
         num_classes (int, optional): number of output classes of the model (including the background)
-        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
         trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
             passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
+        :members:
     """
     weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
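
Finally, a hedged note on the `verify()` calls kept at the top of both builders: in the torchvision versions I have seen, they let callers pass either an enum member, its string name, or None, roughly as sketched below.

```python
from torchvision.models.detection import (
    RetinaNet_ResNet50_FPN_V2_Weights,
    retinanet_resnet50_fpn_v2,
)

# Enum member, string name resolved by verify(), or no pre-trained weights at all.
model_a = retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1)
model_b = retinanet_resnet50_fpn_v2(weights="DEFAULT")
model_c = retinanet_resnet50_fpn_v2(weights=None, weights_backbone=None)  # random init
```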
