14
14
import sys
15
15
import warnings
16
16
from collections import namedtuple
17
+ from dataclasses import asdict , dataclass
17
18
from io import BytesIO , TextIOBase
18
- from typing import List , Optional , Tuple , cast
19
+ from typing import List , Optional , Tuple
19
20
20
21
import numpy as np
21
22
import torch
@@ -137,6 +138,7 @@ def __init__(self, f) -> None:
137
138
self.ascii: (bool) Whether in ascii format
138
139
self.big_endian: (bool) (if not ascii) whether big endian
139
140
self.obj_info: (List[str]) arbitrary extra data
141
+ self.comments: (List[str]) comments
140
142
141
143
Args:
142
144
f: file-like object.
@@ -145,7 +147,8 @@ def __init__(self, f) -> None:
145
147
raise ValueError ("Invalid file header." )
146
148
seen_format = False
147
149
self .elements : List [_PlyElementType ] = []
148
- self .obj_info = []
150
+ self .comments : List [str ] = []
151
+ self .obj_info : List [str ] = []
149
152
while True :
150
153
line = f .readline ()
151
154
if isinstance (line , bytes ):
@@ -176,6 +179,9 @@ def __init__(self, f) -> None:
176
179
continue
177
180
if line .startswith ("format" ):
178
181
raise ValueError ("Invalid format line." )
182
+ if line .startswith ("comment " ):
183
+ self .comments .append (line [8 :])
184
+ continue
179
185
if line .startswith ("comment" ) or len (line ) == 0 :
180
186
continue
181
187
if line .startswith ("element" ):
@@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
781
787
return header , elements
782
788
783
789
790
+ @dataclass (frozen = True )
791
+ class _VertsColumnIndices :
792
+ """
793
+ Contains the relevant layout of the verts section of file being read.
794
+ Members
795
+ point_idxs: List[int] of 3 point columns.
796
+ color_idxs: List[int] of 3 color columns if they are present,
797
+ otherwise None.
798
+ color_scale: value to scale colors by.
799
+ normal_idxs: List[int] of 3 normals columns if they are present,
800
+ otherwise None.
801
+ """
802
+
803
+ point_idxs : List [int ]
804
+ color_idxs : Optional [List [int ]]
805
+ color_scale : float
806
+ normal_idxs : Optional [List [int ]]
807
+
808
+
784
809
def _get_verts_column_indices (
785
810
vertex_head : _PlyElementType ,
786
- ) -> Tuple [ List [ int ], Optional [ List [ int ]], float , Optional [ List [ int ]]] :
811
+ ) -> _VertsColumnIndices :
787
812
"""
788
813
Get the columns of verts, verts_colors, and verts_normals in the vertex
789
814
element of a parsed ply file, together with a color scale factor.
@@ -809,12 +834,7 @@ def _get_verts_column_indices(
809
834
vertex_head: as returned from load_ply_raw.
810
835
811
836
Returns:
812
- point_idxs: List[int] of 3 point columns.
813
- color_idxs: List[int] of 3 color columns if they are present,
814
- otherwise None.
815
- color_scale: value to scale colors by.
816
- normal_idxs: List[int] of 3 normals columns if they are present,
817
- otherwise None.
837
+ _VertsColumnIndices object
818
838
"""
819
839
point_idxs : List [Optional [int ]] = [None , None , None ]
820
840
color_idxs : List [Optional [int ]] = [None , None , None ]
@@ -839,29 +859,38 @@ def _get_verts_column_indices(
839
859
for idx in color_idxs
840
860
):
841
861
color_scale = 1.0 / 255
842
- return (
843
- point_idxs ,
844
- # pyre-fixme[22]: The cast is redundant.
845
- None if None in color_idxs else cast (List [int ], color_idxs ),
846
- color_scale ,
847
- # pyre-fixme[22]: The cast is redundant.
848
- None if None in normal_idxs else cast (List [int ], normal_idxs ),
862
+ return _VertsColumnIndices (
863
+ point_idxs = point_idxs ,
864
+ color_idxs = None if None in color_idxs else color_idxs ,
865
+ color_scale = color_scale ,
866
+ normal_idxs = None if None in normal_idxs else normal_idxs ,
849
867
)
850
868
851
869
852
- def _get_verts (
853
- header : _PlyHeader , elements : dict
854
- ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [torch .Tensor ]]:
870
+ @dataclass (frozen = True )
871
+ class _VertsData :
872
+ """
873
+ Contains the data of the verts section of file being read.
874
+ Members:
875
+ verts: FloatTensor of shape (V, 3).
876
+ verts_colors: None or FloatTensor of shape (V, 3).
877
+ verts_normals: None or FloatTensor of shape (V, 3).
878
+ """
879
+
880
+ verts : torch .Tensor
881
+ verts_colors : Optional [torch .Tensor ] = None
882
+ verts_normals : Optional [torch .Tensor ] = None
883
+
884
+
885
+ def _get_verts (header : _PlyHeader , elements : dict ) -> _VertsData :
855
886
"""
856
887
Get the vertex locations, colors and normals from a parsed ply file.
857
888
858
889
Args:
859
890
header, elements: as returned from load_ply_raw.
860
891
861
892
Returns:
862
- verts: FloatTensor of shape (V, 3).
863
- vertex_colors: None or FloatTensor of shape (V, 3).
864
- vertex_normals: None or FloatTensor of shape (V, 3).
893
+ _VertsData object
865
894
"""
866
895
867
896
vertex = elements .get ("vertex" , None )
@@ -870,16 +899,17 @@ def _get_verts(
870
899
if not isinstance (vertex , list ):
871
900
raise ValueError ("Invalid vertices in file." )
872
901
vertex_head = next (head for head in header .elements if head .name == "vertex" )
873
- point_idxs , color_idxs , color_scale , normal_idxs = _get_verts_column_indices (
874
- vertex_head
875
- )
902
+
903
+ column_idxs = _get_verts_column_indices (vertex_head )
876
904
877
905
# Case of no vertices
878
906
if vertex_head .count == 0 :
879
907
verts = torch .zeros ((0 , 3 ), dtype = torch .float32 )
880
- if color_idxs is None :
881
- return verts , None , None
882
- return verts , torch .zeros ((0 , 3 ), dtype = torch .float32 ), None
908
+ if column_idxs .color_idxs is None :
909
+ return _VertsData (verts = verts )
910
+ return _VertsData (
911
+ verts = verts , verts_colors = torch .zeros ((0 , 3 ), dtype = torch .float32 )
912
+ )
883
913
884
914
# Simple case where the only data is the vertices themselves
885
915
if (
@@ -888,22 +918,22 @@ def _get_verts(
888
918
and vertex [0 ].ndim == 2
889
919
and vertex [0 ].shape [1 ] == 3
890
920
):
891
- return _make_tensor (vertex [0 ], cols = 3 , dtype = torch .float32 ), None , None
921
+ return _VertsData ( verts = _make_tensor (vertex [0 ], cols = 3 , dtype = torch .float32 ))
892
922
893
923
vertex_colors = None
894
924
vertex_normals = None
895
925
896
926
if len (vertex ) == 1 :
897
927
# This is the case where the whole vertex element has one type,
898
928
# so it was read as a single array and we can index straight into it.
899
- verts = torch .tensor (vertex [0 ][:, point_idxs ], dtype = torch .float32 )
900
- if color_idxs is not None :
901
- vertex_colors = color_scale * torch .tensor (
902
- vertex [0 ][:, color_idxs ], dtype = torch .float32
929
+ verts = torch .tensor (vertex [0 ][:, column_idxs . point_idxs ], dtype = torch .float32 )
930
+ if column_idxs . color_idxs is not None :
931
+ vertex_colors = column_idxs . color_scale * torch .tensor (
932
+ vertex [0 ][:, column_idxs . color_idxs ], dtype = torch .float32
903
933
)
904
- if normal_idxs is not None :
934
+ if column_idxs . normal_idxs is not None :
905
935
vertex_normals = torch .tensor (
906
- vertex [0 ][:, normal_idxs ], dtype = torch .float32
936
+ vertex [0 ][:, column_idxs . normal_idxs ], dtype = torch .float32
907
937
)
908
938
else :
909
939
# The vertex element is heterogeneous. It was read as several arrays,
@@ -918,7 +948,7 @@ def _get_verts(
918
948
]
919
949
verts = torch .empty (size = (vertex_head .count , 3 ), dtype = torch .float32 )
920
950
for axis in range (3 ):
921
- partnum , col = prop_to_partnum_col [point_idxs [axis ]]
951
+ partnum , col = prop_to_partnum_col [column_idxs . point_idxs [axis ]]
922
952
verts .numpy ()[:, axis ] = vertex [partnum ][:, col ]
923
953
# Note that in the previous line, we made the assignment
924
954
# as numpy arrays by casting verts. If we took the (more
@@ -928,30 +958,49 @@ def _get_verts(
928
958
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
929
959
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
930
960
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
931
- if color_idxs is not None :
961
+ if column_idxs . color_idxs is not None :
932
962
vertex_colors = torch .empty (
933
963
size = (vertex_head .count , 3 ), dtype = torch .float32
934
964
)
935
965
for color in range (3 ):
936
- partnum , col = prop_to_partnum_col [color_idxs [color ]]
966
+ partnum , col = prop_to_partnum_col [column_idxs . color_idxs [color ]]
937
967
vertex_colors .numpy ()[:, color ] = vertex [partnum ][:, col ]
938
- vertex_colors *= color_scale
939
- if normal_idxs is not None :
968
+ vertex_colors *= column_idxs . color_scale
969
+ if column_idxs . normal_idxs is not None :
940
970
vertex_normals = torch .empty (
941
971
size = (vertex_head .count , 3 ), dtype = torch .float32
942
972
)
943
973
for axis in range (3 ):
944
- partnum , col = prop_to_partnum_col [normal_idxs [axis ]]
974
+ partnum , col = prop_to_partnum_col [column_idxs . normal_idxs [axis ]]
945
975
vertex_normals .numpy ()[:, axis ] = vertex [partnum ][:, col ]
946
976
947
- return verts , vertex_colors , vertex_normals
977
+ return _VertsData (
978
+ verts = verts ,
979
+ verts_colors = vertex_colors ,
980
+ verts_normals = vertex_normals ,
981
+ )
982
+
983
+
984
+ @dataclass (frozen = True )
985
+ class _PlyData :
986
+ """
987
+ Contains the data from a PLY file which has been read.
988
+ Members:
989
+ header: _PlyHeader of file metadata from the header
990
+ verts: FloatTensor of shape (V, 3).
991
+ faces: None or LongTensor of vertex indices, shape (F, 3).
992
+ verts_colors: None or FloatTensor of shape (V, 3).
993
+ verts_normals: None or FloatTensor of shape (V, 3).
994
+ """
995
+
996
+ header : _PlyHeader
997
+ verts : torch .Tensor
998
+ faces : Optional [torch .Tensor ]
999
+ verts_colors : Optional [torch .Tensor ]
1000
+ verts_normals : Optional [torch .Tensor ]
948
1001
949
1002
950
- def _load_ply (
951
- f , * , path_manager : PathManager
952
- ) -> Tuple [
953
- torch .Tensor , Optional [torch .Tensor ], Optional [torch .Tensor ], Optional [torch .Tensor ]
954
- ]:
1003
+ def _load_ply (f , * , path_manager : PathManager ) -> _PlyData :
955
1004
"""
956
1005
Load the data from a .ply file.
957
1006
@@ -964,14 +1013,11 @@ def _load_ply(
964
1013
path_manager: PathManager for loading if f is a str.
965
1014
966
1015
Returns:
967
- verts: FloatTensor of shape (V, 3).
968
- faces: None or LongTensor of vertex indices, shape (F, 3).
969
- vertex_colors: None or FloatTensor of shape (V, 3).
970
- vertex_normals: None or FloatTensor of shape (V, 3).
1016
+ _PlyData object
971
1017
"""
972
1018
header , elements = _load_ply_raw (f , path_manager = path_manager )
973
1019
974
- verts , vertex_colors , vertex_normals = _get_verts (header , elements )
1020
+ verts_data = _get_verts (header , elements )
975
1021
976
1022
face = elements .get ("face" , None )
977
1023
if face is not None :
@@ -1007,9 +1053,9 @@ def _load_ply(
1007
1053
faces = torch .tensor (face_list , dtype = torch .int64 )
1008
1054
1009
1055
if faces is not None :
1010
- _check_faces_indices (faces , max_index = verts .shape [0 ])
1056
+ _check_faces_indices (faces , max_index = verts_data . verts .shape [0 ])
1011
1057
1012
- return verts , faces , vertex_colors , vertex_normals
1058
+ return _PlyData ( ** asdict ( verts_data ) , faces = faces , header = header )
1013
1059
1014
1060
1015
1061
def load_ply (
@@ -1064,11 +1110,12 @@ def load_ply(
1064
1110
1065
1111
if path_manager is None :
1066
1112
path_manager = PathManager ()
1067
- verts , faces , _ , _ = _load_ply (f , path_manager = path_manager )
1113
+ data = _load_ply (f , path_manager = path_manager )
1114
+ faces = data .faces
1068
1115
if faces is None :
1069
1116
faces = torch .zeros (0 , 3 , dtype = torch .int64 )
1070
1117
1071
- return verts , faces
1118
+ return data . verts , faces
1072
1119
1073
1120
1074
1121
def _write_ply_header (
@@ -1305,20 +1352,20 @@ def read(
1305
1352
if not endswith (path , self .known_suffixes ):
1306
1353
return None
1307
1354
1308
- verts , faces , verts_colors , verts_normals = _load_ply (
1309
- f = path , path_manager = path_manager
1310
- )
1355
+ data = _load_ply (f = path , path_manager = path_manager )
1356
+ faces = data .faces
1311
1357
if faces is None :
1312
1358
faces = torch .zeros (0 , 3 , dtype = torch .int64 )
1313
1359
1314
1360
texture = None
1315
- if include_textures and verts_colors is not None :
1316
- texture = TexturesVertex ([verts_colors .to (device )])
1361
+ if include_textures and data . verts_colors is not None :
1362
+ texture = TexturesVertex ([data . verts_colors .to (device )])
1317
1363
1318
- if verts_normals is not None :
1319
- verts_normals = [verts_normals ]
1364
+ verts_normals = None
1365
+ if data .verts_normals is not None :
1366
+ verts_normals = [data .verts_normals .to (device )]
1320
1367
mesh = Meshes (
1321
- verts = [verts .to (device )],
1368
+ verts = [data . verts .to (device )],
1322
1369
faces = [faces .to (device )],
1323
1370
textures = texture ,
1324
1371
verts_normals = verts_normals ,
@@ -1392,14 +1439,17 @@ def read(
1392
1439
if not endswith (path , self .known_suffixes ):
1393
1440
return None
1394
1441
1395
- verts , faces , features , normals = _load_ply (f = path , path_manager = path_manager )
1396
- verts = verts .to (device )
1397
- if features is not None :
1398
- features = [features .to (device )]
1399
- if normals is not None :
1400
- normals = [normals .to (device )]
1442
+ data = _load_ply (f = path , path_manager = path_manager )
1443
+ features = None
1444
+ if data .verts_colors is not None :
1445
+ features = [data .verts_colors .to (device )]
1446
+ normals = None
1447
+ if data .verts_normals is not None :
1448
+ normals = [data .verts_normals .to (device )]
1401
1449
1402
- pointcloud = Pointclouds (points = [verts ], features = features , normals = normals )
1450
+ pointcloud = Pointclouds (
1451
+ points = [data .verts .to (device )], features = features , normals = normals
1452
+ )
1403
1453
return pointcloud
1404
1454
1405
1455
def save (
0 commit comments