Skip to content

Commit 7e0146e

Browse files
d4l3kfacebook-github-bot
authored andcommitted
shader: add SoftDepthShader and HardDepthShader for rendering depth maps (#36)
Summary: X-link: fairinternal/pytorch3d#36 This adds two shaders for rendering depth maps for meshes. This is useful for structure from motion applications that learn depths based off of camera pair disparities. There's two shaders, one hard which just returns the distances and then a second that does a cumsum on the probabilities of the points with a weighted sum. Areas that don't have any z faces are set to the zfar distance. Output from this renderer is `[N, H, W]` since it's just depth no need for channels. I haven't tested this in an ML model yet just in a notebook. hard: ![hardzshader](https://user-images.githubusercontent.com/909104/170190363-ef662c97-0bd2-488c-8675-0557a3c7dd06.png) soft: ![softzshader](https://user-images.githubusercontent.com/909104/170190365-65b08cd7-0c49-4119-803e-d33c1d8c676e.png) Pull Request resolved: #1208 Reviewed By: bottler Differential Revision: D36682194 Pulled By: d4l3k fbshipit-source-id: 5d4e10c6fb0fff5427be4ddd3bd76305a7ccc1e2
1 parent 0e4c53c commit 7e0146e

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

pytorch3d/renderer/mesh/shader.py

+68
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,71 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
353353
)
354354

355355
return images
356+
357+
358+
class HardDepthShader(ShaderBase):
359+
"""
360+
Renders the Z distances of the closest face for each pixel. If no face is
361+
found it returns the zfar value of the camera.
362+
363+
Output from this shader is [N, H, W, 1] since it's only depth.
364+
365+
To use the default values, simply initialize the shader with the desired
366+
device e.g.
367+
368+
.. code-block::
369+
370+
shader = HardDepthShader(device=torch.device("cuda:0"))
371+
"""
372+
373+
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
374+
cameras = super()._get_cameras(**kwargs)
375+
376+
zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
377+
mask = fragments.pix_to_face < 0
378+
379+
zbuf = fragments.zbuf[..., 0].clone()
380+
zbuf[mask] = zfar
381+
return zbuf.unsqueeze(3)
382+
383+
384+
class SoftDepthShader(ShaderBase):
385+
"""
386+
Renders the Z distances using an aggregate of the distances of each face
387+
based off of the point distance. If no face is found it returns the zfar
388+
value of the camera.
389+
390+
Output from this shader is [N, H, W, 1] since it's only depth.
391+
392+
To use the default values, simply initialize the shader with the desired
393+
device e.g.
394+
395+
.. code-block::
396+
397+
shader = SoftDepthShader(device=torch.device("cuda:0"))
398+
"""
399+
400+
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
401+
cameras = super()._get_cameras(**kwargs)
402+
403+
N, H, W, K = fragments.pix_to_face.shape
404+
device = fragments.zbuf.device
405+
mask = fragments.pix_to_face >= 0
406+
407+
zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
408+
409+
# Sigmoid probability map based on the distance of the pixel to the face.
410+
prob_map = torch.sigmoid(-fragments.dists / self.blend_params.sigma) * mask
411+
412+
# append extra face for zfar
413+
dists = torch.cat(
414+
(fragments.zbuf, torch.ones((N, H, W, 1), device=device) * zfar), dim=3
415+
)
416+
probs = torch.cat((prob_map, torch.ones((N, H, W, 1), device=device)), dim=3)
417+
418+
# compute weighting based off of probabilities using cumsum
419+
probs = probs.cumsum(dim=3)
420+
probs = probs.clamp(max=1)
421+
probs = probs.diff(dim=3, prepend=torch.zeros((N, H, W, 1), device=device))
422+
423+
return (probs * dists).sum(dim=3).unsqueeze(3)

tests/test_shader.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
1111
from pytorch3d.renderer.mesh.rasterizer import Fragments
1212
from pytorch3d.renderer.mesh.shader import (
13+
HardDepthShader,
1314
HardFlatShader,
1415
HardGouraudShader,
1516
HardPhongShader,
17+
SoftDepthShader,
1618
SoftPhongShader,
1719
SplatterPhongShader,
1820
)
@@ -24,9 +26,11 @@
2426
class TestShader(TestCaseMixin, unittest.TestCase):
2527
def setUp(self):
2628
self.shader_classes = [
29+
HardDepthShader,
2730
HardFlatShader,
2831
HardGouraudShader,
2932
HardPhongShader,
33+
SoftDepthShader,
3034
SoftPhongShader,
3135
SplatterPhongShader,
3236
]

0 commit comments

Comments
 (0)