Skip to content

Commit 3d011a9

Browse files
EmGarrfacebook-github-bot
authored andcommitted
Adapt RayPointRefiner and RayMarcher to support bins.
Summary: ## Context Bins are used in mipnerf to allow to manipulate easily intervals. For example, by doing the following, `bins[..., :-1]` you will obtain all the left coordinates of your intervals, while doing `bins[..., 1:]` is equals to the right coordinates of your intervals. We introduce here the support of bins like in MipNerf implementation. ## RayPointRefiner Small changes have been made to modify RayPointRefiner. - If bins is None ``` mids = torch.lerp(ray_bundle.lengths[..., 1:], ray_bundle.lengths[…, :-1], 0.5) z_samples = sample_pdf( mids, # [..., npt] weights[..., 1:-1], # [..., npt - 1] …. ) ``` - If bins is not None In the MipNerf implementation the sampling is done on all the bins. It allows us to use the full weights tensor without slashing it. ``` z_samples = sample_pdf( ray_bundle.bins, # [..., npt + 1] weights, # [..., npt] ... ) ``` ## RayMarcher Add a ray_deltas optional argument. If None, keep the same deltas computation from ray_lengths. Reviewed By: shapovalov Differential Revision: D46389092 fbshipit-source-id: d4f1963310065bd31c1c7fac1adfe11cbeaba606
1 parent 5910d81 commit 3d011a9

File tree

5 files changed

+107
-18
lines changed

5 files changed

+107
-18
lines changed

pytorch3d/implicitron/models/renderer/multipass_ea.py

+4
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,13 @@ def _run_raymarcher(
157157
else 0.0
158158
)
159159

160+
ray_deltas = (
161+
None if ray_bundle.bins is None else torch.diff(ray_bundle.bins, dim=-1)
162+
)
160163
output = self.raymarcher(
161164
*implicit_functions[0](ray_bundle=ray_bundle),
162165
ray_lengths=ray_bundle.lengths,
166+
ray_deltas=ray_deltas,
163167
density_noise_std=density_noise_std,
164168
)
165169
output.prev_stage = prev_stage

pytorch3d/implicitron/models/renderer/ray_point_refiner.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -78,29 +78,42 @@ def forward(
7878
7979
"""
8080

81-
z_vals = input_ray_bundle.lengths
8281
with torch.no_grad():
8382
if self.blurpool_weights:
8483
ray_weights = apply_blurpool_on_weights(ray_weights)
8584

86-
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
85+
n_pts_per_ray = self.n_pts_per_ray
86+
ray_weights = ray_weights.view(-1, ray_weights.shape[-1])
87+
if input_ray_bundle.bins is None:
88+
z_vals: torch.Tensor = input_ray_bundle.lengths
89+
ray_weights = ray_weights[..., 1:-1]
90+
bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
91+
else:
92+
z_vals = input_ray_bundle.bins
93+
n_pts_per_ray += 1
94+
bins = z_vals
8795
z_samples = sample_pdf(
88-
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
89-
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
90-
self.n_pts_per_ray,
96+
bins.view(-1, bins.shape[-1]),
97+
ray_weights,
98+
n_pts_per_ray,
9199
det=not self.random_sampling,
92100
eps=self.sample_pdf_eps,
93-
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
101+
).view(*z_vals.shape[:-1], n_pts_per_ray)
102+
94103
if self.add_input_samples:
95104
z_vals = torch.cat((z_vals, z_samples), dim=-1)
96105
else:
97106
z_vals = z_samples
98107
# Resort by depth.
99108
z_vals, _ = torch.sort(z_vals, dim=-1)
100109

101-
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
102-
new_bundle.lengths = z_vals
103-
return new_bundle
110+
kwargs_ray = dict(vars(input_ray_bundle))
111+
if input_ray_bundle.bins is None:
112+
kwargs_ray["lengths"] = z_vals
113+
return ImplicitronRayBundle(**kwargs_ray)
114+
kwargs_ray["bins"] = z_vals
115+
del kwargs_ray["lengths"]
116+
return ImplicitronRayBundle.from_bins(**kwargs_ray)
104117

105118

106119
def apply_blurpool_on_weights(weights) -> torch.Tensor:

pytorch3d/implicitron/models/renderer/raymarcher.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, Callable, Dict, Tuple
7+
from typing import Any, Callable, Dict, Optional, Tuple
88

99
import torch
1010
from pytorch3d.implicitron.models.renderer.base import RendererOutput
@@ -119,6 +119,7 @@ def forward(
119119
rays_features: torch.Tensor,
120120
aux: Dict[str, Any],
121121
ray_lengths: torch.Tensor,
122+
ray_deltas: Optional[torch.Tensor] = None,
122123
density_noise_std: float = 0.0,
123124
**kwargs,
124125
) -> RendererOutput:
@@ -131,6 +132,9 @@ def forward(
131132
aux: a dictionary with extra information.
132133
ray_lengths: Per-ray depth values represented with a tensor
133134
of shape `(..., n_points_per_ray, feature_dim)`.
135+
ray_deltas: Optional differences between consecutive elements along the ray bundle
136+
represented with a tensor of shape `(..., n_points_per_ray)`. If None,
137+
these differences are computed from ray_lengths.
134138
density_noise_std: the magnitude of the noise added to densities.
135139
136140
Returns:
@@ -152,14 +156,17 @@ def forward(
152156
density_1d=True,
153157
)
154158

155-
ray_lengths_diffs = ray_lengths[..., 1:] - ray_lengths[..., :-1]
156-
if self.replicate_last_interval:
157-
last_interval = ray_lengths_diffs[..., -1:]
159+
if ray_deltas is None:
160+
ray_lengths_diffs = torch.diff(ray_lengths, dim=-1)
161+
if self.replicate_last_interval:
162+
last_interval = ray_lengths_diffs[..., -1:]
163+
else:
164+
last_interval = torch.full_like(
165+
ray_lengths[..., :1], self.background_opacity
166+
)
167+
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
158168
else:
159-
last_interval = torch.full_like(
160-
ray_lengths[..., :1], self.background_opacity
161-
)
162-
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
169+
deltas = ray_deltas
163170

164171
rays_densities = rays_densities[..., 0]
165172

pytorch3d/renderer/implicit/harmonic_embedding.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(
2424
and the integrated position encoding in
2525
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
2626
27-
During, the inference you can provide the extra argument `diag_cov`.
27+
During the inference you can provide the extra argument `diag_cov`.
2828
2929
If `diag_cov is None`, it converts
3030
rays parametrized with a `ray_bundle` to 3D points by

tests/implicitron/test_ray_point_refiner.py

+65
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,71 @@ def test_simple(self):
7070
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
7171
)
7272

73+
def test_simple_use_bins(self):
74+
"""
75+
Same spirit than test_simple but use bins in the ImplicitronRayBunle.
76+
It has been duplicated to avoid cognitive overload while reading the
77+
test (lot of if else).
78+
"""
79+
length = 15
80+
n_pts_per_ray = 10
81+
82+
for add_input_samples, use_blurpool in product([False, True], [False, True]):
83+
ray_point_refiner = RayPointRefiner(
84+
n_pts_per_ray=n_pts_per_ray,
85+
random_sampling=False,
86+
add_input_samples=add_input_samples,
87+
)
88+
89+
bundle = ImplicitronRayBundle(
90+
lengths=None,
91+
bins=torch.arange(length + 1, dtype=torch.float32).expand(
92+
3, 25, length + 1
93+
),
94+
origins=None,
95+
directions=None,
96+
xys=None,
97+
camera_ids=None,
98+
camera_counts=None,
99+
)
100+
weights = torch.ones(3, 25, length)
101+
refined = ray_point_refiner(bundle, weights, blurpool_weights=use_blurpool)
102+
103+
self.assertIsNone(refined.directions)
104+
self.assertIsNone(refined.origins)
105+
self.assertIsNone(refined.xys)
106+
expected_bins = torch.linspace(0, length, n_pts_per_ray + 1)
107+
expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1)
108+
if add_input_samples:
109+
expected_bins = torch.cat((bundle.bins, expected_bins), dim=-1).sort()[
110+
0
111+
]
112+
full_expected = torch.lerp(
113+
expected_bins[..., :-1], expected_bins[..., 1:], 0.5
114+
)
115+
116+
self.assertClose(refined.lengths, full_expected)
117+
118+
ray_point_refiner_random = RayPointRefiner(
119+
n_pts_per_ray=n_pts_per_ray,
120+
random_sampling=True,
121+
add_input_samples=add_input_samples,
122+
)
123+
124+
refined_random = ray_point_refiner_random(
125+
bundle, weights, blurpool_weights=use_blurpool
126+
)
127+
lengths_random = refined_random.lengths
128+
self.assertEqual(lengths_random.shape, full_expected.shape)
129+
if not add_input_samples:
130+
self.assertGreater(lengths_random.min().item(), 0)
131+
self.assertLess(lengths_random.max().item(), length)
132+
133+
# Check sorted
134+
self.assertTrue(
135+
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
136+
)
137+
73138
def test_apply_blurpool_on_weights(self):
74139
weights = torch.tensor(
75140
[

0 commit comments

Comments
 (0)