Skip to content

Commit f593bfd

Browse files
patricklabatutfacebook-github-bot
authored andcommitted
More type annotations
Summary: More type annotations: device, shaders, pluggable I/O, stats in NeRF project, cameras, textures, etc... Reviewed By: nikhilaravi Differential Revision: D29327396 fbshipit-source-id: cdf0ceaaa010e22423088752688c8dd81f1acc3c
1 parent 542e2e7 commit f593bfd

File tree

15 files changed

+196
-148
lines changed

15 files changed

+196
-148
lines changed

projects/nerf/nerf/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def download_data(
136136
dataset_names: Optional[List[str]] = None,
137137
data_root: str = DEFAULT_DATA_ROOT,
138138
url_root: str = DEFAULT_URL_ROOT,
139-
):
139+
) -> None:
140140
"""
141141
Downloads the relevant dataset files.
142142

projects/nerf/nerf/stats.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self) -> None:
2929
self.history = []
3030
self.reset()
3131

32-
def reset(self):
32+
def reset(self) -> None:
3333
"""
3434
Reset the running average meter.
3535
"""
@@ -38,7 +38,7 @@ def reset(self):
3838
self.sum = 0
3939
self.count = 0
4040

41-
def update(self, val: float, n: int = 1, epoch: int = 0):
41+
def update(self, val: float, n: int = 1, epoch: int = 0) -> None:
4242
"""
4343
Updates the average meter with a value `val`.
4444
@@ -123,7 +123,7 @@ def __init__(
123123
self.plot_file = plot_file
124124
self.hard_reset(epoch=epoch)
125125

126-
def reset(self):
126+
def reset(self) -> None:
127127
"""
128128
Called before an epoch to clear current epoch buffers.
129129
"""
@@ -138,7 +138,7 @@ def reset(self):
138138
# Set a new timestamp.
139139
self._epoch_start = time.time()
140140

141-
def hard_reset(self, epoch: int = -1):
141+
def hard_reset(self, epoch: int = -1) -> None:
142142
"""
143143
Erases all logged data.
144144
"""
@@ -149,7 +149,7 @@ def hard_reset(self, epoch: int = -1):
149149
self.stats = {}
150150
self.reset()
151151

152-
def new_epoch(self):
152+
def new_epoch(self) -> None:
153153
"""
154154
Initializes a new epoch.
155155
"""
@@ -166,7 +166,7 @@ def _gather_value(self, val):
166166
val = float(val.sum())
167167
return val
168168

169-
def update(self, preds: dict, stat_set: str = "train"):
169+
def update(self, preds: dict, stat_set: str = "train") -> None:
170170
"""
171171
Update the internal logs with metrics of a training step.
172172
@@ -211,7 +211,7 @@ def update(self, preds: dict, stat_set: str = "train"):
211211
if val is not None:
212212
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
213213

214-
def print(self, max_it: Optional[int] = None, stat_set: str = "train"):
214+
def print(self, max_it: Optional[int] = None, stat_set: str = "train") -> None:
215215
"""
216216
Print the current values of all stored stats.
217217
@@ -247,7 +247,7 @@ def plot_stats(
247247
viz: Visdom = None,
248248
visdom_env: Optional[str] = None,
249249
plot_file: Optional[str] = None,
250-
):
250+
) -> None:
251251
"""
252252
Plot the line charts of the history of the stats.
253253

pytorch3d/io/experimental_gltf_io.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,13 @@
4242
from collections import deque
4343
from enum import IntEnum
4444
from io import BytesIO
45-
from pathlib import Path
4645
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, cast
4746

4847
import numpy as np
4948
import torch
5049
from iopath.common.file_io import PathManager
5150
from PIL import Image
52-
from pytorch3d.io.utils import _open_file
51+
from pytorch3d.io.utils import PathOrStr, _open_file
5352
from pytorch3d.renderer.mesh import TexturesBase, TexturesUV, TexturesVertex
5453
from pytorch3d.structures import Meshes, join_meshes_as_scene
5554
from pytorch3d.transforms import Transform3d, quaternion_to_matrix
@@ -498,7 +497,7 @@ def load(self, include_textures: bool) -> List[Tuple[Optional[str], Meshes]]:
498497

499498

500499
def load_meshes(
501-
path: Union[str, Path],
500+
path: PathOrStr,
502501
path_manager: PathManager,
503502
include_textures: bool = True,
504503
) -> List[Tuple[Optional[str], Meshes]]:
@@ -544,7 +543,7 @@ def __init__(self) -> None:
544543

545544
def read(
546545
self,
547-
path: Union[str, Path],
546+
path: PathOrStr,
548547
include_textures: bool,
549548
device,
550549
path_manager: PathManager,
@@ -566,7 +565,7 @@ def read(
566565
def save(
567566
self,
568567
data: Meshes,
569-
path: Union[str, Path],
568+
path: PathOrStr,
570569
path_manager: PathManager,
571570
binary: Optional[bool],
572571
**kwargs,

pytorch3d/io/obj_io.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from PIL import Image
1919
from pytorch3d.common.types import Device
2020
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
21-
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
21+
from pytorch3d.io.utils import (
22+
PathOrStr,
23+
_check_faces_indices,
24+
_make_tensor,
25+
_open_file,
26+
)
2227
from pytorch3d.renderer import TexturesAtlas, TexturesUV
2328
from pytorch3d.structures import Meshes, join_meshes_as_batch
2429

@@ -213,7 +218,7 @@ def load_obj(
213218
None.
214219
"""
215220
data_dir = "./"
216-
if isinstance(f, (str, bytes, os.PathLike)):
221+
if isinstance(f, (str, bytes, Path)):
217222
data_dir = os.path.dirname(f)
218223
if path_manager is None:
219224
path_manager = PathManager()
@@ -297,7 +302,7 @@ def __init__(self) -> None:
297302

298303
def read(
299304
self,
300-
path: Union[str, Path],
305+
path: PathOrStr,
301306
include_textures: bool,
302307
device: Device,
303308
path_manager: PathManager,
@@ -322,7 +327,7 @@ def read(
322327
def save(
323328
self,
324329
data: Meshes,
325-
path: Union[str, Path],
330+
path: PathOrStr,
326331
path_manager: PathManager,
327332
binary: Optional[bool],
328333
decimal_places: Optional[int] = None,
@@ -650,7 +655,7 @@ def _load_obj(
650655

651656

652657
def save_obj(
653-
f: Union[str, os.PathLike],
658+
f: PathOrStr,
654659
verts,
655660
faces,
656661
decimal_places: Optional[int] = None,

pytorch3d/io/off_io.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
http://www.geomview.org/docs/html/OFF.html .
1414
"""
1515
import warnings
16-
from pathlib import Path
1716
from typing import Optional, Tuple, Union, cast
1817

1918
import numpy as np
2019
import torch
2120
from iopath.common.file_io import PathManager
22-
from pytorch3d.io.utils import _check_faces_indices, _open_file
21+
from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _open_file
2322
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
2423
from pytorch3d.structures import Meshes
2524

@@ -424,7 +423,7 @@ def __init__(self) -> None:
424423

425424
def read(
426425
self,
427-
path: Union[str, Path],
426+
path: PathOrStr,
428427
include_textures: bool,
429428
device,
430429
path_manager: PathManager,
@@ -460,7 +459,7 @@ def read(
460459
def save(
461460
self,
462461
data: Meshes,
463-
path: Union[str, Path],
462+
path: PathOrStr,
464463
path_manager: PathManager,
465464
binary: Optional[bool],
466465
decimal_places: Optional[int] = None,

pytorch3d/io/pluggable_formats.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8-
from pathlib import Path
8+
import pathlib
99
from typing import Optional, Tuple, Union
1010

1111
from iopath.common.file_io import PathManager
12+
from pytorch3d.common.types import Device
13+
from pytorch3d.io.utils import PathOrStr
1214
from pytorch3d.structures import Meshes, Pointclouds
1315

1416

@@ -20,14 +22,14 @@
2022
"""
2123

2224

23-
def endswith(path, suffixes: Tuple[str, ...]) -> bool:
25+
def endswith(path: PathOrStr, suffixes: Tuple[str, ...]) -> bool:
2426
"""
2527
Returns whether the path ends with one of the given suffixes.
2628
If `path` is not actually a path, returns True. This is useful
2729
for allowing interpreters to bypass inappropriate paths, but
2830
always accepting streams.
2931
"""
30-
if isinstance(path, Path):
32+
if isinstance(path, pathlib.Path):
3133
return path.suffix.lower() in suffixes
3234
if isinstance(path, str):
3335
return path.lower().endswith(suffixes)
@@ -42,9 +44,9 @@ class MeshFormatInterpreter:
4244

4345
def read(
4446
self,
45-
path: Union[str, Path],
47+
path: PathOrStr,
4648
include_textures: bool,
47-
device,
49+
device: Device,
4850
path_manager: PathManager,
4951
**kwargs,
5052
) -> Optional[Meshes]:
@@ -68,7 +70,7 @@ def read(
6870
def save(
6971
self,
7072
data: Meshes,
71-
path: Union[str, Path],
73+
path: PathOrStr,
7274
path_manager: PathManager,
7375
binary: Optional[bool],
7476
**kwargs,
@@ -96,7 +98,7 @@ class PointcloudFormatInterpreter:
9698
"""
9799

98100
def read(
99-
self, path: Union[str, Path], device, path_manager: PathManager, **kwargs
101+
self, path: PathOrStr, device: Device, path_manager: PathManager, **kwargs
100102
) -> Optional[Pointclouds]:
101103
"""
102104
Read the data from the specified file and return it as
@@ -117,7 +119,7 @@ def read(
117119
def save(
118120
self,
119121
data: Pointclouds,
120-
path: Union[str, Path],
122+
path: PathOrStr,
121123
path_manager: PathManager,
122124
binary: Optional[bool],
123125
**kwargs,

pytorch3d/io/ply_io.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
import warnings
1616
from collections import namedtuple
1717
from io import BytesIO, TextIOBase
18-
from pathlib import Path
1918
from typing import List, Optional, Tuple, Union, cast
2019

2120
import numpy as np
2221
import torch
2322
from iopath.common.file_io import PathManager
24-
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
23+
from pytorch3d.io.utils import (
24+
PathOrStr,
25+
_check_faces_indices,
26+
_make_tensor,
27+
_open_file,
28+
)
2529
from pytorch3d.renderer import TexturesVertex
2630
from pytorch3d.structures import Meshes, Pointclouds
2731

@@ -1237,7 +1241,7 @@ def __init__(self) -> None:
12371241

12381242
def read(
12391243
self,
1240-
path: Union[str, Path],
1244+
path: PathOrStr,
12411245
include_textures: bool,
12421246
device,
12431247
path_manager: PathManager,
@@ -1269,7 +1273,7 @@ def read(
12691273
def save(
12701274
self,
12711275
data: Meshes,
1272-
path: Union[str, Path],
1276+
path: PathOrStr,
12731277
path_manager: PathManager,
12741278
binary: Optional[bool],
12751279
decimal_places: Optional[int] = None,
@@ -1318,7 +1322,7 @@ def __init__(self) -> None:
13181322

13191323
def read(
13201324
self,
1321-
path: Union[str, Path],
1325+
path: PathOrStr,
13221326
device,
13231327
path_manager: PathManager,
13241328
**kwargs,
@@ -1339,7 +1343,7 @@ def read(
13391343
def save(
13401344
self,
13411345
data: Pointclouds,
1342-
path: Union[str, Path],
1346+
path: PathOrStr,
13431347
path_manager: PathManager,
13441348
binary: Optional[bool],
13451349
decimal_places: Optional[int] = None,

pytorch3d/io/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import contextlib
88
import pathlib
99
import warnings
10-
from typing import IO, ContextManager, Optional
10+
from typing import IO, ContextManager, Optional, Union
1111

1212
import numpy as np
1313
import torch
@@ -25,6 +25,9 @@ def nullcontext(x):
2525
yield x
2626

2727

28+
PathOrStr = Union[pathlib.Path, str]
29+
30+
2831
def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
2932
if isinstance(f, str):
3033
f = path_manager.open(f, mode)

0 commit comments

Comments
 (0)