diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index be9b108ef..c17450a80 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -684,6 +684,8 @@ def save_obj( decimal_places: Optional[int] = None, path_manager: Optional[PathManager] = None, *, + normals: Optional[torch.Tensor] = None, + faces_normals_idx: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None, texture_map: Optional[torch.Tensor] = None, @@ -698,6 +700,9 @@ def save_obj( decimal_places: Number of decimal places for saving. path_manager: Optional PathManager for interpreting f if it is a str. + normals: FloatTensor of shape (V, 3) giving the normal per vertex. + faces_normals_idx: LongTensor of shape (F, 3) giving the index into + normals for each vertex in the face. verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex. faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for each vertex in the face. @@ -712,6 +717,15 @@ def save_obj( if len(faces) and (faces.dim() != 2 or faces.size(1) != 3): message = "'faces' should either be empty or of shape (num_faces, 3)." raise ValueError(message) + + if faces_normals_idx is not None and \ + (faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3): + message = "'faces_normals_idx' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) + + if normals is not None and (normals.dim() != 2 or normals.size(1) != 3): + message = "'normals' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." @@ -728,6 +742,7 @@ def save_obj( if path_manager is None: path_manager = PathManager() + save_normals = all([n is not None for n in [normals, faces_normals_idx]]) save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]]) output_path = Path(f) @@ -742,9 +757,12 @@ def save_obj( verts, faces, decimal_places, + normals=normals, + faces_normals_idx=faces_normals_idx, verts_uvs=verts_uvs, faces_uvs=faces_uvs, save_texture=save_texture, + save_normals=save_normals, ) # Save the .mtl and .png files associated with the texture @@ -777,9 +795,12 @@ def _save( faces, decimal_places: Optional[int] = None, *, + normals: Optional[torch.Tensor] = None, + faces_normals_idx: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None, save_texture: bool = False, + save_normals: bool = False, ) -> None: if len(verts) and (verts.dim() != 2 or verts.size(1) != 3): @@ -809,6 +830,25 @@ def _save( vert = [float_str % verts[i, j] for j in range(D)] lines += "v %s\n" % " ".join(vert) + if save_normals: + if faces_normals_idx is not None and (faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3): + message = "'faces_normals_idx' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) + + if normals is not None and (normals.dim() != 2 or normals.size(1) != 3): + message = "'normals' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + + # pyre-fixme[16] # undefined attribute cpu + normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu() + + # Save verts normals after verts + if len(normals): + V, D = normals.shape + for i in range(V): + normal = [float_str % normals[i, j] for j in range(D)] + lines += "vn %s\n" % " ".join(normal) + if save_texture: if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." @@ -834,7 +874,22 @@ def _save( if len(faces): F, P = faces.shape for i in range(F): - if save_texture: + if save_texture and save_normals: + # Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx} + face = [ + "%d/%d/%d" % ( + faces[i, j] + 1, + faces_uvs[i, j] + 1, + faces_normals_idx[i, j] + 1, + ) + for j in range(P) + ] + elif save_normals: + # Format faces as {verts_idx}//{verts_normals_idx} + face = [ + "%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1) for j in range(P) + ] + elif save_texture: # Format faces as {verts_idx}/{verts_uvs_idx} face = [ "%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P) diff --git a/tests/test_io_obj.py b/tests/test_io_obj.py index 6b67932d9..a5f2d122a 100644 --- a/tests/test_io_obj.py +++ b/tests/test_io_obj.py @@ -895,6 +895,59 @@ def check_item(x, y): with self.assertRaisesRegex(ValueError, "same type of texture"): join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas]) + def test_save_obj_with_normal(self): + verts = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + faces = torch.tensor( + [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 + ) + normals = torch.tensor( + [[0.02, 0.5, 0.73], [0.3, 0.03, 0.361], [0.32, 0.12, 0.47], [0.36, 0.17, 0.9], + [0.40, 0.7, 0.19], [1.0, 0.00, 0.000], [0.00, 1.00, 0.00], [0.00, 0.00, 1.0]], + dtype=torch.float32, + ) + faces_normals_idx = torch.tensor( + [[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64 + ) + + with TemporaryDirectory() as temp_dir: + obj_file = os.path.join(temp_dir, "mesh.obj") + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + normals=normals, + faces_normals_idx=faces_normals_idx, + ) + + expected_obj_file = "\n".join( + [ + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "vn 0.02 0.50 0.73", + "vn 0.30 0.03 0.36", + "vn 0.32 0.12 0.47", + "vn 0.36 0.17 0.90", + "vn 0.40 0.70 0.19", + "vn 1.00 0.00 0.00", + "vn 0.00 1.00 0.00", + "vn 0.00 0.00 1.00", + "f 1//1 3//2 2//3", + "f 1//3 2//4 3//5", + "f 4//5 3//6 2//7", + "f 4//7 2//8 1//1", + ] + ) + + # Check the obj file is saved correctly + actual_file = open(obj_file, "r") + self.assertEqual(actual_file.read(), expected_obj_file) + def test_save_obj_with_texture(self): verts = torch.tensor( [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], @@ -962,6 +1015,84 @@ def test_save_obj_with_texture(self): texture_image = load_rgb_image("mesh.png", temp_dir) self.assertClose(texture_image, texture_map) + def test_save_obj_with_normal_and_texture(self): + verts = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + faces = torch.tensor( + [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 + ) + normals = torch.tensor( + [[0.02, 0.5, 0.73], [0.3, 0.03, 0.361], [0.32, 0.12, 0.47], [0.36, 0.17, 0.9]], + dtype=torch.float32, + ) + faces_normals_idx = faces + verts_uvs = torch.tensor( + [[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]], + dtype=torch.float32, + ) + faces_uvs = faces + texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0 + + with TemporaryDirectory() as temp_dir: + obj_file = os.path.join(temp_dir, "mesh.obj") + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + normals=normals, + faces_normals_idx=faces_normals_idx, + verts_uvs=verts_uvs, + faces_uvs=faces_uvs, + texture_map=texture_map, + ) + + expected_obj_file = "\n".join( + [ + "", + "mtllib mesh.mtl", + "usemtl mesh", + "", + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "vn 0.02 0.50 0.73", + "vn 0.30 0.03 0.36", + "vn 0.32 0.12 0.47", + "vn 0.36 0.17 0.90", + "vt 0.02 0.50", + "vt 0.30 0.03", + "vt 0.32 0.12", + "vt 0.36 0.17", + "f 1/1/1 3/3/3 2/2/2", + "f 1/1/1 2/2/2 3/3/3", + "f 4/4/4 3/3/3 2/2/2", + "f 4/4/4 2/2/2 1/1/1", + ] + ) + expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""]) + + # Check there are only 3 files in the temp dir + tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"] + tempfiles_dir = os.listdir(temp_dir) + self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir)) + + # Check the obj file is saved correctly + actual_file = open(obj_file, "r") + self.assertEqual(actual_file.read(), expected_obj_file) + + # Check the mtl file is saved correctly + mtl_file_name = os.path.join(temp_dir, "mesh.mtl") + mtl_file = open(mtl_file_name, "r") + self.assertEqual(mtl_file.read(), expected_mtl_file) + + # Check the texture image file is saved correctly + texture_image = load_rgb_image("mesh.png", temp_dir) + self.assertClose(texture_image, texture_map) + def test_save_obj_with_texture_errors(self): verts = torch.tensor( [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],