Skip to content

Commit 092400f

Browse files
dhbloofacebook-github-bot
authored andcommitted
allow saving vertex normal in save_obj (#1511)
Summary: Although we can load per-vertex normals in `load_obj`, saving per-vertex normals is not supported in `save_obj`. This patch fixes this by allowing passing per-vertex normal data in `save_obj`: ``` python def save_obj( f: PathOrStr, verts, faces, decimal_places: Optional[int] = None, path_manager: Optional[PathManager] = None, *, verts_normals: Optional[torch.Tensor] = None, faces_normals: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None, texture_map: Optional[torch.Tensor] = None, ) -> None: """ Save a mesh to an .obj file. Args: f: File (str or path) to which the mesh should be written. verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. path_manager: Optional PathManager for interpreting f if it is a str. verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex. faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals for each vertex in the face. verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex. faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for each vertex in the face. texture_map: FloatTensor of shape (H, W, 3) representing the texture map for the mesh which will be saved as an image. The values are expected to be in the range [0, 1], """ ``` Pull Request resolved: #1511 Reviewed By: shapovalov Differential Revision: D45086045 Pulled By: bottler fbshipit-source-id: 666efb0d2c302df6cf9f2f6601d83a07856bf32f
1 parent ec87284 commit 092400f

File tree

2 files changed

+262
-32
lines changed

2 files changed

+262
-32
lines changed

pytorch3d/io/obj_io.py

+107-20
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,8 @@ def save_obj(
684684
decimal_places: Optional[int] = None,
685685
path_manager: Optional[PathManager] = None,
686686
*,
687+
normals: Optional[torch.Tensor] = None,
688+
faces_normals_idx: Optional[torch.Tensor] = None,
687689
verts_uvs: Optional[torch.Tensor] = None,
688690
faces_uvs: Optional[torch.Tensor] = None,
689691
texture_map: Optional[torch.Tensor] = None,
@@ -698,6 +700,10 @@ def save_obj(
698700
decimal_places: Number of decimal places for saving.
699701
path_manager: Optional PathManager for interpreting f if
700702
it is a str.
703+
normals: FloatTensor of shape (V, 3) giving normals for faces_normals_idx
704+
to index into.
705+
faces_normals_idx: LongTensor of shape (F, 3) giving the index into
706+
normals for each vertex in the face.
701707
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
702708
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
703709
each vertex in the face.
@@ -713,6 +719,22 @@ def save_obj(
713719
message = "'faces' should either be empty or of shape (num_faces, 3)."
714720
raise ValueError(message)
715721

722+
if (normals is None) != (faces_normals_idx is None):
723+
message = "'normals' and 'faces_normals_idx' must both be None or neither."
724+
raise ValueError(message)
725+
726+
if faces_normals_idx is not None and (
727+
faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3
728+
):
729+
message = (
730+
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
731+
)
732+
raise ValueError(message)
733+
734+
if normals is not None and (normals.dim() != 2 or normals.size(1) != 3):
735+
message = "'normals' should either be empty or of shape (num_verts, 3)."
736+
raise ValueError(message)
737+
716738
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
717739
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
718740
raise ValueError(message)
@@ -742,9 +764,12 @@ def save_obj(
742764
verts,
743765
faces,
744766
decimal_places,
767+
normals=normals,
768+
faces_normals_idx=faces_normals_idx,
745769
verts_uvs=verts_uvs,
746770
faces_uvs=faces_uvs,
747771
save_texture=save_texture,
772+
save_normals=normals is not None,
748773
)
749774

750775
# Save the .mtl and .png files associated with the texture
@@ -777,9 +802,12 @@ def _save(
777802
faces,
778803
decimal_places: Optional[int] = None,
779804
*,
805+
normals: Optional[torch.Tensor] = None,
806+
faces_normals_idx: Optional[torch.Tensor] = None,
780807
verts_uvs: Optional[torch.Tensor] = None,
781808
faces_uvs: Optional[torch.Tensor] = None,
782809
save_texture: bool = False,
810+
save_normals: bool = False,
783811
) -> None:
784812

785813
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
@@ -798,18 +826,26 @@ def _save(
798826

799827
lines = ""
800828

801-
if len(verts):
802-
if decimal_places is None:
803-
float_str = "%f"
804-
else:
805-
float_str = "%" + ".%df" % decimal_places
829+
if decimal_places is None:
830+
float_str = "%f"
831+
else:
832+
float_str = "%" + ".%df" % decimal_places
806833

834+
if len(verts):
807835
V, D = verts.shape
808836
for i in range(V):
809837
vert = [float_str % verts[i, j] for j in range(D)]
810838
lines += "v %s\n" % " ".join(vert)
811839

840+
if save_normals:
841+
assert normals is not None
842+
assert faces_normals_idx is not None
843+
lines += _write_normals(normals, faces_normals_idx, float_str)
844+
812845
if save_texture:
846+
assert faces_uvs is not None
847+
assert verts_uvs is not None
848+
813849
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
814850
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
815851
raise ValueError(message)
@@ -818,7 +854,6 @@ def _save(
818854
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
819855
raise ValueError(message)
820856

821-
# pyre-fixme[16] # undefined attribute cpu
822857
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
823858

824859
# Save verts uvs after verts
@@ -828,25 +863,77 @@ def _save(
828863
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
829864
lines += "vt %s\n" % " ".join(uv)
830865

866+
f.write(lines)
867+
831868
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
832869
warnings.warn("Faces have invalid indices")
833870

834871
if len(faces):
835-
F, P = faces.shape
836-
for i in range(F):
837-
if save_texture:
838-
# Format faces as {verts_idx}/{verts_uvs_idx}
872+
_write_faces(
873+
f,
874+
faces,
875+
faces_uvs if save_texture else None,
876+
faces_normals_idx if save_normals else None,
877+
)
878+
879+
880+
def _write_normals(
881+
normals: torch.Tensor, faces_normals_idx: torch.Tensor, float_str: str
882+
) -> str:
883+
if faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3:
884+
message = (
885+
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
886+
)
887+
raise ValueError(message)
888+
889+
if normals.dim() != 2 or normals.size(1) != 3:
890+
message = "'normals' should either be empty or of shape (num_verts, 3)."
891+
raise ValueError(message)
892+
893+
normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu()
894+
895+
lines = []
896+
V, D = normals.shape
897+
for i in range(V):
898+
normal = [float_str % normals[i, j] for j in range(D)]
899+
lines.append("vn %s\n" % " ".join(normal))
900+
return "".join(lines)
901+
902+
903+
def _write_faces(
904+
f,
905+
faces: torch.Tensor,
906+
faces_uvs: Optional[torch.Tensor],
907+
faces_normals_idx: Optional[torch.Tensor],
908+
) -> None:
909+
F, P = faces.shape
910+
for i in range(F):
911+
if faces_normals_idx is not None:
912+
if faces_uvs is not None:
913+
# Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
839914
face = [
840-
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
915+
"%d/%d/%d"
916+
% (
917+
faces[i, j] + 1,
918+
faces_uvs[i, j] + 1,
919+
faces_normals_idx[i, j] + 1,
920+
)
921+
for j in range(P)
841922
]
842923
else:
843-
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
844-
845-
if i + 1 < F:
846-
lines += "f %s\n" % " ".join(face)
847-
848-
elif i + 1 == F:
849-
# No newline at the end of the file.
850-
lines += "f %s" % " ".join(face)
924+
# Format faces as {verts_idx}//{verts_normals_idx}
925+
face = [
926+
"%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1)
927+
for j in range(P)
928+
]
929+
elif faces_uvs is not None:
930+
# Format faces as {verts_idx}/{verts_uvs_idx}
931+
face = ["%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)]
932+
else:
933+
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
851934

852-
f.write(lines)
935+
if i + 1 < F:
936+
f.write("f %s\n" % " ".join(face))
937+
else:
938+
# No newline at the end of the file.
939+
f.write("f %s" % " ".join(face))

0 commit comments

Comments
 (0)