|
36 | 36 | import numpy as np
|
37 | 37 | import torch
|
38 | 38 | from common_testing import TestCaseMixin
|
| 39 | +from pytorch3d.renderer.camera_utils import join_cameras_as_batch |
39 | 40 | from pytorch3d.renderer.cameras import (
|
40 | 41 | CamerasBase,
|
41 | 42 | FoVOrthographicCameras,
|
@@ -688,6 +689,99 @@ def test_clone(self, batch_size: int = 10):
|
688 | 689 | else:
|
689 | 690 | self.assertTrue(val == val_clone)
|
690 | 691 |
|
def test_join_cameras_as_batch_errors(self):
    """Joining incompatible cameras must raise ``ValueError``.

    Three mismatches are exercised in turn: different camera classes,
    different devices, and different ``in_ndc`` settings.
    """
    gpu = "cuda:0"
    perspective_gpu = PerspectiveCameras(device=gpu)
    ortho_gpu = OrthographicCameras(device=gpu)

    # Mixing camera classes is rejected.
    type_error = self.assertRaisesRegex(ValueError, "same type")
    with type_error:
        join_cameras_as_batch([perspective_gpu, ortho_gpu])

    # Same class, but one camera lives on the CPU.
    ortho_cpu = OrthographicCameras(device="cpu")
    device_error = self.assertRaisesRegex(ValueError, "same device")
    with device_error:
        join_cameras_as_batch([ortho_gpu, ortho_cpu])

    # Same class and device, but only one camera is built with
    # in_ndc=False, so the coordinate conventions disagree.
    ortho_screen = OrthographicCameras(in_ndc=False, device=gpu)
    ndc_error = self.assertRaisesRegex(
        ValueError, "Attribute _in_ndc is not constant across inputs"
    )
    with ndc_error:
        join_cameras_as_batch([ortho_gpu, ortho_screen])
| 711 | + |
def join_cameras_as_batch_fov(self, camera_cls):
    """Check joining two camera batches built with znear/zfar arguments.

    Args:
        camera_cls: camera class to instantiate; it is called with
            ``znear``, ``zfar``, ``R`` and ``device`` keyword arguments.

    Builds one batch of 6 and one batch of 3 cameras on "cuda:0", joins
    them, and verifies the batch size, the device and the concatenated
    rotation matrices.
    """
    device = "cuda:0"
    rot_first = torch.randn((6, 3, 3))
    rot_second = torch.randn((3, 3, 3))
    first = camera_cls(znear=10.0, zfar=100.0, R=rot_first, device=device)
    second = camera_cls(znear=10.0, zfar=200.0, R=rot_second, device=device)

    joined = join_cameras_as_batch([first, second])

    self.assertEqual(joined._N, first._N + second._N)
    self.assertEqual(joined.device, first.device)
    # The rotations were created on the CPU, so the reference tensor is
    # moved to the cameras' device before comparing.
    expected_rot = torch.cat((rot_first, rot_second), dim=0).to(device=device)
    self.assertClose(joined.R, expected_rot)
| 723 | + |
def join_cameras_as_batch(self, camera_cls):
    """Check joining camera batches built with focal length / principal point.

    Args:
        camera_cls: camera class to instantiate; it is called with ``R``,
            ``focal_length`` and ``principal_point`` keyword arguments.

    Covers three joins: two cameras sharing a scalar focal length, a
    scalar focal length joined with a (3, 2) tensor, and a (3, 2) tensor
    joined with a (3, 1) tensor.
    """
    # The order of the randn calls fixes which random values each tensor
    # receives, so it is kept stable.
    rot_six = torch.randn((6, 3, 3))
    rot_three = torch.randn((3, 3, 3))
    pp_six = torch.randn((6, 2, 1))
    pp_three = torch.randn((3, 2, 1))
    focal_scalar = 5.0
    focal_xy = torch.randn(3, 2)
    focal_single = torch.randn(3, 1)

    # Every 3-camera batch shares the same rotation and principal point;
    # only the focal length differs between them.
    three_kwargs = {"R": rot_three, "principal_point": pp_three}
    cam_six = camera_cls(
        R=rot_six, focal_length=focal_scalar, principal_point=pp_six
    )
    cam_three_scalar = camera_cls(focal_length=focal_scalar, **three_kwargs)
    cam_three_xy = camera_cls(focal_length=focal_xy, **three_kwargs)
    cam_three_single = camera_cls(focal_length=focal_single, **three_kwargs)

    stacked_rot = torch.cat((rot_six, rot_three), dim=0)

    # Case 1: both cameras use the same scalar focal length.
    joined = join_cameras_as_batch([cam_six, cam_three_scalar])
    self.assertEqual(joined._N, cam_six._N + cam_three_scalar._N)
    self.assertEqual(joined.device, cam_six.device)
    self.assertClose(joined.R, stacked_rot)
    self.assertClose(
        joined.principal_point, torch.cat((pp_six, pp_three), dim=0)
    )
    self.assertEqual(joined._in_ndc, cam_six._in_ndc)

    # Case 2: one broadcasted value and one fixed value -- a scalar focal
    # length in one camera and a (3, 2) tensor in the other.
    joined = join_cameras_as_batch([cam_six, cam_three_xy])
    self.assertEqual(joined._N, cam_six._N + cam_three_xy._N)
    self.assertClose(joined.R, stacked_rot)
    scalar_as_rows = torch.tensor([[focal_scalar, focal_scalar]]).expand(6, -1)
    self.assertClose(
        joined.focal_length,
        torch.cat([scalar_as_rows, focal_xy], dim=0),
    )

    # Case 3: focal length as (3, 2) in one camera and (3, 1) in the other.
    joined = join_cameras_as_batch([cam_three_xy, cam_three_single])
    single_as_pairs = focal_single.expand(-1, 2)
    self.assertClose(
        joined.focal_length,
        torch.cat([focal_xy, single_as_pairs], dim=0),
    )
| 776 | + |
def test_join_batch_perspective(self):
    """Run the join-as-batch checks for the two perspective camera classes."""
    # Each camera class is paired with the helper that matches its
    # constructor arguments.
    checks = (
        (self.join_cameras_as_batch_fov, FoVPerspectiveCameras),
        (self.join_cameras_as_batch, PerspectiveCameras),
    )
    for run_check, camera_cls in checks:
        run_check(camera_cls)
| 780 | + |
def test_join_batch_orthographic(self):
    """Run the join-as-batch checks for the two orthographic camera classes."""
    # Each camera class is paired with the helper that matches its
    # constructor arguments.
    checks = (
        (self.join_cameras_as_batch_fov, FoVOrthographicCameras),
        (self.join_cameras_as_batch, OrthographicCameras),
    )
    for run_check, camera_cls in checks:
        run_check(camera_cls)
| 784 | + |
691 | 785 |
|
692 | 786 | ############################################################
|
693 | 787 | # FoVPerspective Camera #
|
@@ -1055,7 +1149,7 @@ def test_getitem(self):
|
1055 | 1149 | index = torch.tensor([1, 3, 5], dtype=torch.int64)
|
1056 | 1150 | c135 = cam[index]
|
1057 | 1151 | self.assertEqual(len(c135), 3)
|
1058 |
| - self.assertClose(c135.focal_length, torch.tensor([5.0] * 3)) |
| 1152 | + self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3)) |
1059 | 1153 | self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
|
1060 | 1154 | self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
|
1061 | 1155 |
|
@@ -1131,7 +1225,7 @@ def test_getitem(self):
|
1131 | 1225 | index = torch.tensor([1, 3, 5], dtype=torch.int64)
|
1132 | 1226 | c135 = cam[index]
|
1133 | 1227 | self.assertEqual(len(c135), 3)
|
1134 |
| - self.assertClose(c135.focal_length, torch.tensor([5.0] * 3)) |
| 1228 | + self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3)) |
1135 | 1229 | self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
|
1136 | 1230 | self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
|
1137 | 1231 |
|
|
0 commit comments