Skip to content

Commit d281f8e

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Filtering outlier input cameras in trajectory estimation
Summary: Useful for visualising colmap output where some frames are not correctly registered. Reviewed By: bottler Differential Revision: D38743191 fbshipit-source-id: e823df2997870dc41d76784e112d4349f904d311
1 parent b7c826b commit d281f8e

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

pytorch3d/implicitron/tools/eval_video_trajectory.py

+27
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,21 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
import math
89
from typing import Optional, Tuple
910

1011
import torch
1112
from pytorch3d.common.compat import eigh
13+
from pytorch3d.implicitron.tools import utils
1214
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
1315
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
1416
from pytorch3d.transforms import Scale
1517

1618

19+
logger = logging.getLogger(__name__)
20+
21+
1722
def generate_eval_video_cameras(
1823
train_cameras,
1924
n_eval_cams: int = 100,
@@ -27,6 +32,7 @@ def generate_eval_video_cameras(
2732
infer_up_as_plane_normal: bool = True,
2833
traj_offset: Optional[Tuple[float, float, float]] = None,
2934
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
35+
remove_outliers_rate: float = 0.0,
3036
) -> PerspectiveCameras:
3137
"""
3238
Generate a camera trajectory rendering a scene from multiple viewpoints.
@@ -50,9 +56,16 @@ def generate_eval_video_cameras(
5056
Active for the `trajectory_type="circular"`.
5157
scene_center: The center of the scene in world coordinates which all
5258
the cameras from the generated trajectory look at.
59+
remove_outliers_rate: the number between 0 and 1; if > 0,
60+
some outlier train_cameras will be removed from trajectory estimation;
61+
the filtering is based on camera center coordinates; top and
62+
bottom `remove_outliers_rate` cameras on each dimension are removed.
5363
Returns:
5464
Dictionary of camera instances which can be used as the test dataset
5565
"""
66+
if remove_outliers_rate > 0.0:
67+
train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
68+
5669
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
5770
cam_centers = train_cameras.get_camera_center()
5871
# get the nearest camera center to the mean of centers
@@ -167,6 +180,20 @@ def generate_eval_video_cameras(
167180
return test_cameras
168181

169182

183+
def _remove_outlier_cameras(
184+
cameras: PerspectiveCameras, outlier_rate: float
185+
) -> PerspectiveCameras:
186+
keep_indices = utils.get_inlier_indicators(
187+
cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
188+
)
189+
clean_cameras = cameras[keep_indices]
190+
logger.info(
191+
"Filtered outlier cameras when estimating the trajectory: "
192+
f"{len(cameras)}{len(clean_cameras)}"
193+
)
194+
return clean_cameras
195+
196+
170197
def _disambiguate_normal(normal, up):
171198
up_t = torch.tensor(up).to(normal)
172199
flip = (up_t * normal).sum().sign()

pytorch3d/implicitron/tools/utils.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import dataclasses
1010
import time
1111
from contextlib import contextmanager
12-
from typing import Any, Callable, Dict
12+
from typing import Any, Callable, Dict, Iterable, Iterator
1313

1414
import torch
1515

@@ -157,6 +157,26 @@ def cat_dataclass(batch, tensor_collator: Callable):
157157
return type(elem)(**collated)
158158

159159

160+
def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
161+
for x in it:
162+
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
163+
yield from recursive_visitor(x)
164+
else:
165+
yield x
166+
167+
168+
def get_inlier_indicators(
169+
tensor: torch.Tensor, dim: int, outlier_rate: float
170+
) -> torch.Tensor:
171+
remove_elements = int(min(outlier_rate, 1.0) * tensor.shape[dim] / 2)
172+
hi = torch.topk(tensor, remove_elements, dim=dim).indices.tolist()
173+
lo = torch.topk(-tensor, remove_elements, dim=dim).indices.tolist()
174+
remove_indices = set(recursive_visitor([hi, lo]))
175+
keep_indices = tensor.new_ones(tensor.shape[dim : dim + 1], dtype=torch.bool)
176+
keep_indices[list(remove_indices)] = False
177+
return keep_indices
178+
179+
160180
class Timer:
161181
"""
162182
A simple class for timing execution.

0 commit comments

Comments
 (0)