|
7 | 7 | from pathlib import Path
|
8 | 8 | import torch
|
9 | 9 |
|
10 |
| -from pytorch3d.io import load_obj, save_obj |
| 10 | +from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj |
| 11 | +from pytorch3d.structures import Meshes, Textures, join_meshes |
11 | 12 |
|
| 13 | +from common_testing import TestCaseMixin |
12 | 14 |
|
13 |
| -class TestMeshObjIO(unittest.TestCase): |
| 15 | + |
| 16 | +class TestMeshObjIO(TestCaseMixin, unittest.TestCase): |
14 | 17 | def test_load_obj_simple(self):
|
15 | 18 | obj_file = "\n".join(
|
16 | 19 | [
|
@@ -517,6 +520,88 @@ def test_load_obj_missing_mtl_noload(self):
|
517 | 520 | self.assertTrue(aux.material_colors is None)
|
518 | 521 | self.assertTrue(aux.texture_images is None)
|
519 | 522 |
|
| 523 | + def test_join_meshes(self): |
| 524 | + """ |
| 525 | + Test that join_meshes and load_objs_as_meshes are consistent with single |
| 526 | + meshes. |
| 527 | + """ |
| 528 | + |
| 529 | + def check_triple(mesh, mesh3): |
| 530 | + """ |
| 531 | + Verify that mesh3 is three copies of mesh. |
| 532 | + """ |
| 533 | + |
| 534 | + def check_item(x, y): |
| 535 | + self.assertEqual(x is None, y is None) |
| 536 | + if x is not None: |
| 537 | + self.assertClose(torch.cat([x, x, x]), y) |
| 538 | + |
| 539 | + check_item(mesh.verts_padded(), mesh3.verts_padded()) |
| 540 | + check_item(mesh.faces_padded(), mesh3.faces_padded()) |
| 541 | + if mesh.textures is not None: |
| 542 | + check_item( |
| 543 | + mesh.textures.maps_padded(), mesh3.textures.maps_padded() |
| 544 | + ) |
| 545 | + check_item( |
| 546 | + mesh.textures.faces_uvs_padded(), |
| 547 | + mesh3.textures.faces_uvs_padded(), |
| 548 | + ) |
| 549 | + check_item( |
| 550 | + mesh.textures.verts_uvs_padded(), |
| 551 | + mesh3.textures.verts_uvs_padded(), |
| 552 | + ) |
| 553 | + check_item( |
| 554 | + mesh.textures.verts_rgb_padded(), |
| 555 | + mesh3.textures.verts_rgb_padded(), |
| 556 | + ) |
| 557 | + |
| 558 | + DATA_DIR = ( |
| 559 | + Path(__file__).resolve().parent.parent / "docs/tutorials/data" |
| 560 | + ) |
| 561 | + obj_filename = DATA_DIR / "cow_mesh/cow.obj" |
| 562 | + |
| 563 | + mesh = load_objs_as_meshes([obj_filename]) |
| 564 | + mesh3 = load_objs_as_meshes([obj_filename, obj_filename, obj_filename]) |
| 565 | + check_triple(mesh, mesh3) |
| 566 | + self.assertTupleEqual( |
| 567 | + mesh.textures.maps_padded().shape, (1, 1024, 1024, 3) |
| 568 | + ) |
| 569 | + |
| 570 | + mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False) |
| 571 | + mesh3_notex = load_objs_as_meshes( |
| 572 | + [obj_filename, obj_filename, obj_filename], load_textures=False |
| 573 | + ) |
| 574 | + check_triple(mesh_notex, mesh3_notex) |
| 575 | + self.assertIsNone(mesh_notex.textures) |
| 576 | + |
| 577 | + verts = torch.randn((4, 3), dtype=torch.float32) |
| 578 | + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) |
| 579 | + vert_tex = torch.tensor( |
| 580 | + [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32 |
| 581 | + ) |
| 582 | + tex = Textures(verts_rgb=vert_tex[None, :]) |
| 583 | + mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex) |
| 584 | + mesh_rgb3 = join_meshes([mesh_rgb, mesh_rgb, mesh_rgb]) |
| 585 | + check_triple(mesh_rgb, mesh_rgb3) |
| 586 | + |
| 587 | + teapot_obj = DATA_DIR / "teapot.obj" |
| 588 | + mesh_teapot = load_objs_as_meshes([teapot_obj]) |
| 589 | + teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0) |
| 590 | + mix_mesh = load_objs_as_meshes( |
| 591 | + [obj_filename, teapot_obj], load_textures=False |
| 592 | + ) |
| 593 | + self.assertEqual(len(mix_mesh), 2) |
| 594 | + self.assertClose(mix_mesh.verts_list()[0], mesh.verts_list()[0]) |
| 595 | + self.assertClose(mix_mesh.faces_list()[0], mesh.faces_list()[0]) |
| 596 | + self.assertClose(mix_mesh.verts_list()[1], teapot_verts) |
| 597 | + self.assertClose(mix_mesh.faces_list()[1], teapot_faces) |
| 598 | + |
| 599 | + cow3_tea = join_meshes([mesh3, mesh_teapot], include_textures=False) |
| 600 | + self.assertEqual(len(cow3_tea), 4) |
| 601 | + check_triple(mesh_notex, cow3_tea[:3]) |
| 602 | + self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0]) |
| 603 | + self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0]) |
| 604 | + |
520 | 605 | @staticmethod
|
521 | 606 | def save_obj_with_init(V: int, F: int):
|
522 | 607 | verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)
|
|
0 commit comments