Skip to content

Commit 197f1d6

Browse files
bottlerfacebook-github-bot
authored andcommitted
save_ply binary
Summary: Make save_ply save to binary instead of ascii. An option makes the previous functionality available. save_ply's API accepts a stream, but this is undocumented; that stream must now be a binary stream not a text stream. Avoiding warnings about making tensors from immutable numpy arrays. Possible performance improvement when reading binary files. Fix reading zero-length binary lists. Reviewed By: nikhilaravi Differential Revision: D22333118 fbshipit-source-id: b423dfd3da46e047bead200255f47a7707306811
1 parent ebe2693 commit 197f1d6

File tree

2 files changed

+102
-44
lines changed

2 files changed

+102
-44
lines changed

pytorch3d/io/ply_io.py

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99
import warnings
1010
from collections import namedtuple
11+
from io import BytesIO
1112
from typing import Optional, Tuple
1213

1314
import numpy as np
@@ -386,11 +387,18 @@ def _read_ply_fixed_size_element_binary(
386387
np_type = ply_type.np_type
387388
type_size = ply_type.size
388389
needed_length = definition.count * len(definition.properties)
389-
needed_bytes = needed_length * type_size
390-
bytes_data = f.read(needed_bytes)
391-
if len(bytes_data) != needed_bytes:
392-
raise ValueError("Not enough data for %s." % definition.name)
393-
data = np.frombuffer(bytes_data, dtype=np_type)
390+
if isinstance(f, BytesIO):
391+
# np.fromfile is faster but won't work on a BytesIO
392+
needed_bytes = needed_length * type_size
393+
bytes_data = bytearray(needed_bytes)
394+
n_bytes_read = f.readinto(bytes_data)
395+
if n_bytes_read != needed_bytes:
396+
raise ValueError("Not enough data for %s." % definition.name)
397+
data = np.frombuffer(bytes_data, dtype=np_type)
398+
else:
399+
data = np.fromfile(f, dtype=np_type, count=needed_length)
400+
if data.shape[0] != needed_length:
401+
raise ValueError("Not enough data for %s." % definition.name)
394402

395403
if (sys.byteorder == "big") != big_endian:
396404
data = data.byteswap()
@@ -447,6 +455,8 @@ def _try_read_ply_constant_list_binary(
447455
If every element has the same size, 2D numpy array corresponding to the
448456
data. The rows are the different values. Otherwise None.
449457
"""
458+
if definition.count == 0:
459+
return []
450460
property = definition.properties[0]
451461
endian_str = ">" if big_endian else "<"
452462
length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char
@@ -689,6 +699,7 @@ def _save_ply(
689699
verts: torch.Tensor,
690700
faces: torch.LongTensor,
691701
verts_normals: torch.Tensor,
702+
ascii: bool,
692703
decimal_places: Optional[int] = None,
693704
) -> None:
694705
"""
@@ -699,52 +710,75 @@ def _save_ply(
699710
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
700711
faces: LongTensor of shsape (F, 3) giving faces.
701712
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
702-
decimal_places: Number of decimal places for saving.
713+
ascii: (bool) whether to use the ascii ply format.
714+
decimal_places: Number of decimal places for saving if ascii=True.
703715
"""
704716
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
705717
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
706718
assert not len(verts_normals) or (
707719
verts_normals.dim() == 2 and verts_normals.size(1) == 3
708720
)
709721

710-
print("ply\nformat ascii 1.0", file=f)
711-
print(f"element vertex {verts.shape[0]}", file=f)
712-
print("property float x", file=f)
713-
print("property float y", file=f)
714-
print("property float z", file=f)
722+
if ascii:
723+
f.write(b"ply\nformat ascii 1.0\n")
724+
elif sys.byteorder == "big":
725+
f.write(b"ply\nformat binary_big_endian 1.0\n")
726+
else:
727+
f.write(b"ply\nformat binary_little_endian 1.0\n")
728+
f.write(f"element vertex {verts.shape[0]}\n".encode("ascii"))
729+
f.write(b"property float x\n")
730+
f.write(b"property float y\n")
731+
f.write(b"property float z\n")
715732
if verts_normals.numel() > 0:
716-
print("property float nx", file=f)
717-
print("property float ny", file=f)
718-
print("property float nz", file=f)
719-
print(f"element face {faces.shape[0]}", file=f)
720-
print("property list uchar int vertex_index", file=f)
721-
print("end_header", file=f)
733+
f.write(b"property float nx\n")
734+
f.write(b"property float ny\n")
735+
f.write(b"property float nz\n")
736+
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
737+
f.write(b"property list uchar int vertex_index\n")
738+
f.write(b"end_header\n")
722739

723740
if not (len(verts) or len(faces)):
724741
warnings.warn("Empty 'verts' and 'faces' arguments provided")
725742
return
726743

727-
if decimal_places is None:
728-
float_str = "%f"
744+
vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
745+
if ascii:
746+
if decimal_places is None:
747+
float_str = "%f"
748+
else:
749+
float_str = "%" + ".%df" % decimal_places
750+
np.savetxt(f, vert_data, float_str)
729751
else:
730-
float_str = "%" + ".%df" % decimal_places
731-
732-
vert_data = torch.cat((verts, verts_normals), dim=1)
733-
np.savetxt(f, vert_data.detach().numpy(), float_str)
752+
assert vert_data.dtype == np.float32
753+
if isinstance(f, BytesIO):
754+
# tofile only works with real files, but is faster than this.
755+
f.write(vert_data.tobytes())
756+
else:
757+
vert_data.tofile(f)
734758

735759
faces_array = faces.detach().numpy()
736760

737761
_check_faces_indices(faces, max_index=verts.shape[0])
738762

739763
if len(faces_array):
740-
np.savetxt(f, faces_array, "3 %d %d %d")
764+
if ascii:
765+
np.savetxt(f, faces_array, "3 %d %d %d")
766+
else:
767+
# rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
768+
faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
769+
faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
770+
if isinstance(f, BytesIO):
771+
f.write(faces_uints.tobytes())
772+
else:
773+
faces_uints.tofile(f)
741774

742775

743776
def save_ply(
744777
f,
745778
verts: torch.Tensor,
746779
faces: Optional[torch.LongTensor] = None,
747780
verts_normals: Optional[torch.Tensor] = None,
781+
ascii: bool = False,
748782
decimal_places: Optional[int] = None,
749783
) -> None:
750784
"""
@@ -755,7 +789,8 @@ def save_ply(
755789
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
756790
faces: LongTensor of shape (F, 3) giving faces.
757791
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
758-
decimal_places: Number of decimal places for saving.
792+
ascii: (bool) whether to use the ascii ply format.
793+
decimal_places: Number of decimal places for saving if ascii=True.
759794
"""
760795

761796
verts_normals = (
@@ -781,5 +816,5 @@ def save_ply(
781816
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
782817
raise ValueError(message)
783818

784-
with _open_file(f, "w") as f:
785-
_save_ply(f, verts, faces, verts_normals, decimal_places)
819+
with _open_file(f, "wb") as f:
820+
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)

tests/test_ply_io.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import struct
44
import unittest
55
from io import BytesIO, StringIO
6+
from tempfile import TemporaryFile
67

78
import torch
89
from common_testing import TestCaseMixin
@@ -144,7 +145,7 @@ def test_save_ply_invalid_shapes(self):
144145
with self.assertRaises(ValueError) as error:
145146
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
146147
faces = torch.LongTensor([[0, 1, 2]])
147-
save_ply(StringIO(), verts, faces)
148+
save_ply(BytesIO(), verts, faces)
148149
expected_message = (
149150
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
150151
)
@@ -154,7 +155,7 @@ def test_save_ply_invalid_shapes(self):
154155
with self.assertRaises(ValueError) as error:
155156
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
156157
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
157-
save_ply(StringIO(), verts, faces)
158+
save_ply(BytesIO(), verts, faces)
158159
expected_message = (
159160
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
160161
)
@@ -165,14 +166,14 @@ def test_save_ply_invalid_indices(self):
165166
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
166167
faces = torch.LongTensor([[0, 1, 2]])
167168
with self.assertWarnsRegex(UserWarning, message_regex):
168-
save_ply(StringIO(), verts, faces)
169+
save_ply(BytesIO(), verts, faces)
169170

170171
faces = torch.LongTensor([[-1, 0, 1]])
171172
with self.assertWarnsRegex(UserWarning, message_regex):
172-
save_ply(StringIO(), verts, faces)
173+
save_ply(BytesIO(), verts, faces)
173174

174175
def _test_save_load(self, verts, faces):
175-
f = StringIO()
176+
f = BytesIO()
176177
save_ply(f, verts, faces)
177178
f.seek(0)
178179
# raise Exception(f.getvalue())
@@ -193,7 +194,7 @@ def test_normals_save(self):
193194
normals = torch.tensor(
194195
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
195196
)
196-
file = StringIO()
197+
file = BytesIO()
197198
save_ply(file, verts=verts, faces=faces, verts_normals=normals)
198199
file.close()
199200

@@ -237,15 +238,31 @@ def test_empty_save_load(self):
237238

238239
def test_simple_save(self):
239240
verts = torch.tensor(
240-
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
241+
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32
241242
)
242243
faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
243-
file = StringIO()
244-
save_ply(file, verts=verts, faces=faces)
245-
file.seek(0)
246-
verts2, faces2 = load_ply(file)
247-
self.assertClose(verts, verts2)
248-
self.assertClose(faces, faces2)
244+
for filetype in BytesIO, TemporaryFile:
245+
lengths = {}
246+
for ascii in [True, False]:
247+
file = filetype()
248+
save_ply(file, verts=verts, faces=faces, ascii=ascii)
249+
lengths[ascii] = file.tell()
250+
251+
file.seek(0)
252+
verts2, faces2 = load_ply(file)
253+
self.assertClose(verts, verts2)
254+
self.assertClose(faces, faces2)
255+
256+
file.seek(0)
257+
if ascii:
258+
file.read().decode("ascii")
259+
else:
260+
with self.assertRaises(UnicodeDecodeError):
261+
file.read().decode("ascii")
262+
263+
if filetype is TemporaryFile:
264+
file.close()
265+
self.assertLess(lengths[False], lengths[True], "ascii should be longer")
249266

250267
def test_load_simple_binary(self):
251268
for big_endian in [True, False]:
@@ -488,15 +505,21 @@ def test_bad_ply_syntax(self):
488505

489506
@staticmethod
490507
def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
491-
return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places)
508+
return lambda: save_ply(
509+
BytesIO(),
510+
verts=verts,
511+
faces=faces,
512+
ascii=True,
513+
decimal_places=decimal_places,
514+
)
492515

493516
@staticmethod
494517
def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
495-
f = StringIO()
496-
save_ply(f, verts, faces, decimal_places)
518+
f = BytesIO()
519+
save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places)
497520
s = f.getvalue()
498521
# Recreate stream so it's unaffected by how it was created.
499-
return lambda: load_ply(StringIO(s))
522+
return lambda: load_ply(BytesIO(s))
500523

501524
@staticmethod
502525
def bm_save_simple_ply_with_init(V: int, F: int):

0 commit comments

Comments
 (0)