8
8
import sys
9
9
import warnings
10
10
from collections import namedtuple
11
+ from io import BytesIO
11
12
from typing import Optional , Tuple
12
13
13
14
import numpy as np
@@ -386,11 +387,18 @@ def _read_ply_fixed_size_element_binary(
386
387
np_type = ply_type .np_type
387
388
type_size = ply_type .size
388
389
needed_length = definition .count * len (definition .properties )
389
- needed_bytes = needed_length * type_size
390
- bytes_data = f .read (needed_bytes )
391
- if len (bytes_data ) != needed_bytes :
392
- raise ValueError ("Not enough data for %s." % definition .name )
393
- data = np .frombuffer (bytes_data , dtype = np_type )
390
+ if isinstance (f , BytesIO ):
391
+ # np.fromfile is faster but won't work on a BytesIO
392
+ needed_bytes = needed_length * type_size
393
+ bytes_data = bytearray (needed_bytes )
394
+ n_bytes_read = f .readinto (bytes_data )
395
+ if n_bytes_read != needed_bytes :
396
+ raise ValueError ("Not enough data for %s." % definition .name )
397
+ data = np .frombuffer (bytes_data , dtype = np_type )
398
+ else :
399
+ data = np .fromfile (f , dtype = np_type , count = needed_length )
400
+ if data .shape [0 ] != needed_length :
401
+ raise ValueError ("Not enough data for %s." % definition .name )
394
402
395
403
if (sys .byteorder == "big" ) != big_endian :
396
404
data = data .byteswap ()
@@ -447,6 +455,8 @@ def _try_read_ply_constant_list_binary(
447
455
If every element has the same size, 2D numpy array corresponding to the
448
456
data. The rows are the different values. Otherwise None.
449
457
"""
458
+ if definition .count == 0 :
459
+ return []
450
460
property = definition .properties [0 ]
451
461
endian_str = ">" if big_endian else "<"
452
462
length_format = endian_str + _PLY_TYPES [property .list_size_type ].struct_char
@@ -689,6 +699,7 @@ def _save_ply(
689
699
verts : torch .Tensor ,
690
700
faces : torch .LongTensor ,
691
701
verts_normals : torch .Tensor ,
702
+ ascii : bool ,
692
703
decimal_places : Optional [int ] = None ,
693
704
) -> None :
694
705
"""
@@ -699,52 +710,75 @@ def _save_ply(
699
710
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
700
711
faces: LongTensor of shsape (F, 3) giving faces.
701
712
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
702
- decimal_places: Number of decimal places for saving.
713
+ ascii: (bool) whether to use the ascii ply format.
714
+ decimal_places: Number of decimal places for saving if ascii=True.
703
715
"""
704
716
assert not len (verts ) or (verts .dim () == 2 and verts .size (1 ) == 3 )
705
717
assert not len (faces ) or (faces .dim () == 2 and faces .size (1 ) == 3 )
706
718
assert not len (verts_normals ) or (
707
719
verts_normals .dim () == 2 and verts_normals .size (1 ) == 3
708
720
)
709
721
710
- print ("ply\n format ascii 1.0" , file = f )
711
- print (f"element vertex { verts .shape [0 ]} " , file = f )
712
- print ("property float x" , file = f )
713
- print ("property float y" , file = f )
714
- print ("property float z" , file = f )
722
+ if ascii :
723
+ f .write (b"ply\n format ascii 1.0\n " )
724
+ elif sys .byteorder == "big" :
725
+ f .write (b"ply\n format binary_big_endian 1.0\n " )
726
+ else :
727
+ f .write (b"ply\n format binary_little_endian 1.0\n " )
728
+ f .write (f"element vertex { verts .shape [0 ]} \n " .encode ("ascii" ))
729
+ f .write (b"property float x\n " )
730
+ f .write (b"property float y\n " )
731
+ f .write (b"property float z\n " )
715
732
if verts_normals .numel () > 0 :
716
- print ( "property float nx" , file = f )
717
- print ( "property float ny" , file = f )
718
- print ( "property float nz" , file = f )
719
- print (f"element face { faces .shape [0 ]} " , file = f )
720
- print ( "property list uchar int vertex_index" , file = f )
721
- print ( "end_header" , file = f )
733
+ f . write ( b "property float nx\n " )
734
+ f . write ( b "property float ny\n " )
735
+ f . write ( b "property float nz\n " )
736
+ f . write (f"element face { faces .shape [0 ]} \n " . encode ( "ascii" ) )
737
+ f . write ( b "property list uchar int vertex_index\n " )
738
+ f . write ( b "end_header\n " )
722
739
723
740
if not (len (verts ) or len (faces )):
724
741
warnings .warn ("Empty 'verts' and 'faces' arguments provided" )
725
742
return
726
743
727
- if decimal_places is None :
728
- float_str = "%f"
744
+ vert_data = torch .cat ((verts , verts_normals ), dim = 1 ).detach ().numpy ()
745
+ if ascii :
746
+ if decimal_places is None :
747
+ float_str = "%f"
748
+ else :
749
+ float_str = "%" + ".%df" % decimal_places
750
+ np .savetxt (f , vert_data , float_str )
729
751
else :
730
- float_str = "%" + ".%df" % decimal_places
731
-
732
- vert_data = torch .cat ((verts , verts_normals ), dim = 1 )
733
- np .savetxt (f , vert_data .detach ().numpy (), float_str )
752
+ assert vert_data .dtype == np .float32
753
+ if isinstance (f , BytesIO ):
754
+ # tofile only works with real files, but is faster than this.
755
+ f .write (vert_data .tobytes ())
756
+ else :
757
+ vert_data .tofile (f )
734
758
735
759
faces_array = faces .detach ().numpy ()
736
760
737
761
_check_faces_indices (faces , max_index = verts .shape [0 ])
738
762
739
763
if len (faces_array ):
740
- np .savetxt (f , faces_array , "3 %d %d %d" )
764
+ if ascii :
765
+ np .savetxt (f , faces_array , "3 %d %d %d" )
766
+ else :
767
+ # rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
768
+ faces_uints = np .full ((len (faces_array ), 13 ), 3 , dtype = np .uint8 )
769
+ faces_uints [:, 1 :] = faces_array .astype (np .uint32 ).view (np .uint8 )
770
+ if isinstance (f , BytesIO ):
771
+ f .write (faces_uints .tobytes ())
772
+ else :
773
+ faces_uints .tofile (f )
741
774
742
775
743
776
def save_ply (
744
777
f ,
745
778
verts : torch .Tensor ,
746
779
faces : Optional [torch .LongTensor ] = None ,
747
780
verts_normals : Optional [torch .Tensor ] = None ,
781
+ ascii : bool = False ,
748
782
decimal_places : Optional [int ] = None ,
749
783
) -> None :
750
784
"""
@@ -755,7 +789,8 @@ def save_ply(
755
789
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
756
790
faces: LongTensor of shape (F, 3) giving faces.
757
791
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
758
- decimal_places: Number of decimal places for saving.
792
+ ascii: (bool) whether to use the ascii ply format.
793
+ decimal_places: Number of decimal places for saving if ascii=True.
759
794
"""
760
795
761
796
verts_normals = (
@@ -781,5 +816,5 @@ def save_ply(
781
816
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
782
817
raise ValueError (message )
783
818
784
- with _open_file (f , "w " ) as f :
785
- _save_ply (f , verts , faces , verts_normals , decimal_places )
819
+ with _open_file (f , "wb " ) as f :
820
+ _save_ply (f , verts , faces , verts_normals , ascii , decimal_places )
0 commit comments