Skip to content

Commit 967a099

Browse files
bottlerfacebook-github-bot
authored andcommitted
Use dataclasses inside ply_io.
Summary: Refactor ply_io to make it easier to add new features. Mostly taken from the starting code I attached to #904. Reviewed By: patricklabatut Differential Revision: D34375978 fbshipit-source-id: ec017d31f07c6f71ba6d97a0623bb10be1e81212
1 parent feb5d36 commit 967a099

File tree

1 file changed

+121
-71
lines changed

1 file changed

+121
-71
lines changed

pytorch3d/io/ply_io.py

+121-71
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
import sys
1515
import warnings
1616
from collections import namedtuple
17+
from dataclasses import asdict, dataclass
1718
from io import BytesIO, TextIOBase
18-
from typing import List, Optional, Tuple, cast
19+
from typing import List, Optional, Tuple
1920

2021
import numpy as np
2122
import torch
@@ -137,6 +138,7 @@ def __init__(self, f) -> None:
137138
self.ascii: (bool) Whether in ascii format
138139
self.big_endian: (bool) (if not ascii) whether big endian
139140
self.obj_info: (List[str]) arbitrary extra data
141+
self.comments: (List[str]) comments
140142
141143
Args:
142144
f: file-like object.
@@ -145,7 +147,8 @@ def __init__(self, f) -> None:
145147
raise ValueError("Invalid file header.")
146148
seen_format = False
147149
self.elements: List[_PlyElementType] = []
148-
self.obj_info = []
150+
self.comments: List[str] = []
151+
self.obj_info: List[str] = []
149152
while True:
150153
line = f.readline()
151154
if isinstance(line, bytes):
@@ -176,6 +179,9 @@ def __init__(self, f) -> None:
176179
continue
177180
if line.startswith("format"):
178181
raise ValueError("Invalid format line.")
182+
if line.startswith("comment "):
183+
self.comments.append(line[8:])
184+
continue
179185
if line.startswith("comment") or len(line) == 0:
180186
continue
181187
if line.startswith("element"):
@@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
781787
return header, elements
782788

783789

790+
@dataclass(frozen=True)
791+
class _VertsColumnIndices:
792+
"""
793+
Contains the relevant layout of the verts section of file being read.
794+
Members
795+
point_idxs: List[int] of 3 point columns.
796+
color_idxs: List[int] of 3 color columns if they are present,
797+
otherwise None.
798+
color_scale: value to scale colors by.
799+
normal_idxs: List[int] of 3 normals columns if they are present,
800+
otherwise None.
801+
"""
802+
803+
point_idxs: List[int]
804+
color_idxs: Optional[List[int]]
805+
color_scale: float
806+
normal_idxs: Optional[List[int]]
807+
808+
784809
def _get_verts_column_indices(
785810
vertex_head: _PlyElementType,
786-
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
811+
) -> _VertsColumnIndices:
787812
"""
788813
Get the columns of verts, verts_colors, and verts_normals in the vertex
789814
element of a parsed ply file, together with a color scale factor.
@@ -809,12 +834,7 @@ def _get_verts_column_indices(
809834
vertex_head: as returned from load_ply_raw.
810835
811836
Returns:
812-
point_idxs: List[int] of 3 point columns.
813-
color_idxs: List[int] of 3 color columns if they are present,
814-
otherwise None.
815-
color_scale: value to scale colors by.
816-
normal_idxs: List[int] of 3 normals columns if they are present,
817-
otherwise None.
837+
_VertsColumnIndices object
818838
"""
819839
point_idxs: List[Optional[int]] = [None, None, None]
820840
color_idxs: List[Optional[int]] = [None, None, None]
@@ -839,29 +859,38 @@ def _get_verts_column_indices(
839859
for idx in color_idxs
840860
):
841861
color_scale = 1.0 / 255
842-
return (
843-
point_idxs,
844-
# pyre-fixme[22]: The cast is redundant.
845-
None if None in color_idxs else cast(List[int], color_idxs),
846-
color_scale,
847-
# pyre-fixme[22]: The cast is redundant.
848-
None if None in normal_idxs else cast(List[int], normal_idxs),
862+
return _VertsColumnIndices(
863+
point_idxs=point_idxs,
864+
color_idxs=None if None in color_idxs else color_idxs,
865+
color_scale=color_scale,
866+
normal_idxs=None if None in normal_idxs else normal_idxs,
849867
)
850868

851869

852-
def _get_verts(
853-
header: _PlyHeader, elements: dict
854-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
870+
@dataclass(frozen=True)
871+
class _VertsData:
872+
"""
873+
Contains the data of the verts section of file being read.
874+
Members:
875+
verts: FloatTensor of shape (V, 3).
876+
verts_colors: None or FloatTensor of shape (V, 3).
877+
verts_normals: None or FloatTensor of shape (V, 3).
878+
"""
879+
880+
verts: torch.Tensor
881+
verts_colors: Optional[torch.Tensor] = None
882+
verts_normals: Optional[torch.Tensor] = None
883+
884+
885+
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
855886
"""
856887
Get the vertex locations, colors and normals from a parsed ply file.
857888
858889
Args:
859890
header, elements: as returned from load_ply_raw.
860891
861892
Returns:
862-
verts: FloatTensor of shape (V, 3).
863-
vertex_colors: None or FloatTensor of shape (V, 3).
864-
vertex_normals: None or FloatTensor of shape (V, 3).
893+
_VertsData object
865894
"""
866895

867896
vertex = elements.get("vertex", None)
@@ -870,16 +899,17 @@ def _get_verts(
870899
if not isinstance(vertex, list):
871900
raise ValueError("Invalid vertices in file.")
872901
vertex_head = next(head for head in header.elements if head.name == "vertex")
873-
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
874-
vertex_head
875-
)
902+
903+
column_idxs = _get_verts_column_indices(vertex_head)
876904

877905
# Case of no vertices
878906
if vertex_head.count == 0:
879907
verts = torch.zeros((0, 3), dtype=torch.float32)
880-
if color_idxs is None:
881-
return verts, None, None
882-
return verts, torch.zeros((0, 3), dtype=torch.float32), None
908+
if column_idxs.color_idxs is None:
909+
return _VertsData(verts=verts)
910+
return _VertsData(
911+
verts=verts, verts_colors=torch.zeros((0, 3), dtype=torch.float32)
912+
)
883913

884914
# Simple case where the only data is the vertices themselves
885915
if (
@@ -888,22 +918,22 @@ def _get_verts(
888918
and vertex[0].ndim == 2
889919
and vertex[0].shape[1] == 3
890920
):
891-
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
921+
return _VertsData(verts=_make_tensor(vertex[0], cols=3, dtype=torch.float32))
892922

893923
vertex_colors = None
894924
vertex_normals = None
895925

896926
if len(vertex) == 1:
897927
# This is the case where the whole vertex element has one type,
898928
# so it was read as a single array and we can index straight into it.
899-
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
900-
if color_idxs is not None:
901-
vertex_colors = color_scale * torch.tensor(
902-
vertex[0][:, color_idxs], dtype=torch.float32
929+
verts = torch.tensor(vertex[0][:, column_idxs.point_idxs], dtype=torch.float32)
930+
if column_idxs.color_idxs is not None:
931+
vertex_colors = column_idxs.color_scale * torch.tensor(
932+
vertex[0][:, column_idxs.color_idxs], dtype=torch.float32
903933
)
904-
if normal_idxs is not None:
934+
if column_idxs.normal_idxs is not None:
905935
vertex_normals = torch.tensor(
906-
vertex[0][:, normal_idxs], dtype=torch.float32
936+
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
907937
)
908938
else:
909939
# The vertex element is heterogeneous. It was read as several arrays,
@@ -918,7 +948,7 @@ def _get_verts(
918948
]
919949
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
920950
for axis in range(3):
921-
partnum, col = prop_to_partnum_col[point_idxs[axis]]
951+
partnum, col = prop_to_partnum_col[column_idxs.point_idxs[axis]]
922952
verts.numpy()[:, axis] = vertex[partnum][:, col]
923953
# Note that in the previous line, we made the assignment
924954
# as numpy arrays by casting verts. If we took the (more
@@ -928,30 +958,49 @@ def _get_verts(
928958
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
929959
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
930960
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
931-
if color_idxs is not None:
961+
if column_idxs.color_idxs is not None:
932962
vertex_colors = torch.empty(
933963
size=(vertex_head.count, 3), dtype=torch.float32
934964
)
935965
for color in range(3):
936-
partnum, col = prop_to_partnum_col[color_idxs[color]]
966+
partnum, col = prop_to_partnum_col[column_idxs.color_idxs[color]]
937967
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
938-
vertex_colors *= color_scale
939-
if normal_idxs is not None:
968+
vertex_colors *= column_idxs.color_scale
969+
if column_idxs.normal_idxs is not None:
940970
vertex_normals = torch.empty(
941971
size=(vertex_head.count, 3), dtype=torch.float32
942972
)
943973
for axis in range(3):
944-
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
974+
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
945975
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
946976

947-
return verts, vertex_colors, vertex_normals
977+
return _VertsData(
978+
verts=verts,
979+
verts_colors=vertex_colors,
980+
verts_normals=vertex_normals,
981+
)
982+
983+
984+
@dataclass(frozen=True)
985+
class _PlyData:
986+
"""
987+
Contains the data from a PLY file which has been read.
988+
Members:
989+
header: _PlyHeader of file metadata from the header
990+
verts: FloatTensor of shape (V, 3).
991+
faces: None or LongTensor of vertex indices, shape (F, 3).
992+
verts_colors: None or FloatTensor of shape (V, 3).
993+
verts_normals: None or FloatTensor of shape (V, 3).
994+
"""
995+
996+
header: _PlyHeader
997+
verts: torch.Tensor
998+
faces: Optional[torch.Tensor]
999+
verts_colors: Optional[torch.Tensor]
1000+
verts_normals: Optional[torch.Tensor]
9481001

9491002

950-
def _load_ply(
951-
f, *, path_manager: PathManager
952-
) -> Tuple[
953-
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
954-
]:
1003+
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
9551004
"""
9561005
Load the data from a .ply file.
9571006
@@ -964,14 +1013,11 @@ def _load_ply(
9641013
path_manager: PathManager for loading if f is a str.
9651014
9661015
Returns:
967-
verts: FloatTensor of shape (V, 3).
968-
faces: None or LongTensor of vertex indices, shape (F, 3).
969-
vertex_colors: None or FloatTensor of shape (V, 3).
970-
vertex_normals: None or FloatTensor of shape (V, 3).
1016+
_PlyData object
9711017
"""
9721018
header, elements = _load_ply_raw(f, path_manager=path_manager)
9731019

974-
verts, vertex_colors, vertex_normals = _get_verts(header, elements)
1020+
verts_data = _get_verts(header, elements)
9751021

9761022
face = elements.get("face", None)
9771023
if face is not None:
@@ -1007,9 +1053,9 @@ def _load_ply(
10071053
faces = torch.tensor(face_list, dtype=torch.int64)
10081054

10091055
if faces is not None:
1010-
_check_faces_indices(faces, max_index=verts.shape[0])
1056+
_check_faces_indices(faces, max_index=verts_data.verts.shape[0])
10111057

1012-
return verts, faces, vertex_colors, vertex_normals
1058+
return _PlyData(**asdict(verts_data), faces=faces, header=header)
10131059

10141060

10151061
def load_ply(
@@ -1064,11 +1110,12 @@ def load_ply(
10641110

10651111
if path_manager is None:
10661112
path_manager = PathManager()
1067-
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
1113+
data = _load_ply(f, path_manager=path_manager)
1114+
faces = data.faces
10681115
if faces is None:
10691116
faces = torch.zeros(0, 3, dtype=torch.int64)
10701117

1071-
return verts, faces
1118+
return data.verts, faces
10721119

10731120

10741121
def _write_ply_header(
@@ -1305,20 +1352,20 @@ def read(
13051352
if not endswith(path, self.known_suffixes):
13061353
return None
13071354

1308-
verts, faces, verts_colors, verts_normals = _load_ply(
1309-
f=path, path_manager=path_manager
1310-
)
1355+
data = _load_ply(f=path, path_manager=path_manager)
1356+
faces = data.faces
13111357
if faces is None:
13121358
faces = torch.zeros(0, 3, dtype=torch.int64)
13131359

13141360
texture = None
1315-
if include_textures and verts_colors is not None:
1316-
texture = TexturesVertex([verts_colors.to(device)])
1361+
if include_textures and data.verts_colors is not None:
1362+
texture = TexturesVertex([data.verts_colors.to(device)])
13171363

1318-
if verts_normals is not None:
1319-
verts_normals = [verts_normals]
1364+
verts_normals = None
1365+
if data.verts_normals is not None:
1366+
verts_normals = [data.verts_normals.to(device)]
13201367
mesh = Meshes(
1321-
verts=[verts.to(device)],
1368+
verts=[data.verts.to(device)],
13221369
faces=[faces.to(device)],
13231370
textures=texture,
13241371
verts_normals=verts_normals,
@@ -1392,14 +1439,17 @@ def read(
13921439
if not endswith(path, self.known_suffixes):
13931440
return None
13941441

1395-
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
1396-
verts = verts.to(device)
1397-
if features is not None:
1398-
features = [features.to(device)]
1399-
if normals is not None:
1400-
normals = [normals.to(device)]
1442+
data = _load_ply(f=path, path_manager=path_manager)
1443+
features = None
1444+
if data.verts_colors is not None:
1445+
features = [data.verts_colors.to(device)]
1446+
normals = None
1447+
if data.verts_normals is not None:
1448+
normals = [data.verts_normals.to(device)]
14011449

1402-
pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
1450+
pointcloud = Pointclouds(
1451+
points=[data.verts.to(device)], features=features, normals=normals
1452+
)
14031453
return pointcloud
14041454

14051455
def save(

0 commit comments

Comments
 (0)