Skip to content

Commit c21ba14

Browse files
Krzysztof Chalupkafacebook-github-bot
Krzysztof Chalupka
authored andcommitted
Add Fragments.detach()
Summary: Add a capability to detach all detachable tensors in Fragments. Reviewed By: bottler Differential Revision: D35918133 fbshipit-source-id: 03b5d4491a3a6791b0a7bc9119f26c1a7aa43196
1 parent d737a05 commit c21ba14

File tree

1 file changed

+46
-3
lines changed

1 file changed

+46
-3
lines changed

pytorch3d/renderer/mesh/rasterizer.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,64 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from dataclasses import dataclass
8-
from typing import NamedTuple, Optional, Tuple, Union
8+
from typing import Optional, Tuple, Union
99

1010
import torch
1111
import torch.nn as nn
1212

1313
from .rasterize_meshes import rasterize_meshes
1414

1515

16-
# Class to store the outputs of mesh rasterization
17-
class Fragments(NamedTuple):
16+
@dataclass(frozen=True)
17+
class Fragments:
18+
"""
19+
Members:
20+
pix_to_face:
21+
LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving
22+
the indices of the nearest faces at each pixel, sorted in ascending
23+
z-order. Concretely ``pix_to_face[n, y, x, k] = f`` means that
24+
``faces_verts[f]`` is the kth closest face (in the z-direction) to pixel
25+
(y, x). Pixels that are hit by fewer than faces_per_pixel are padded with
26+
-1.
27+
28+
zbuf:
29+
FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving
30+
the NDC z-coordinates of the nearest faces at each pixel, sorted in
31+
ascending z-order. Concretely, if ``pix_to_face[n, y, x, k] = f`` then
32+
``zbuf[n, y, x, k] = face_verts[f, 2]``. Pixels hit by fewer than
33+
faces_per_pixel are padded with -1.
34+
35+
bary_coords:
36+
FloatTensor of shape (N, image_size, image_size, faces_per_pixel, 3)
37+
giving the barycentric coordinates in NDC units of the nearest faces at
38+
each pixel, sorted in ascending z-order. Concretely, if ``pix_to_face[n,
39+
y, x, k] = f`` then ``[w0, w1, w2] = barycentric[n, y, x, k]`` gives the
40+
barycentric coords for pixel (y, x) relative to the face defined by
41+
``face_verts[f]``. Pixels hit by fewer than faces_per_pixel are padded
42+
with -1.
43+
44+
dists:
45+
FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving
46+
the signed Euclidean distance (in NDC units) in the x/y plane of each
47+
point closest to the pixel. Concretely if ``pix_to_face[n, y, x, k] = f``
48+
then ``pix_dists[n, y, x, k]`` is the squared distance between the pixel
49+
(y, x) and the face given by vertices ``face_verts[f]``. Pixels hit with
50+
fewer than ``faces_per_pixel`` are padded with -1.
51+
"""
52+
1853
pix_to_face: torch.Tensor
1954
zbuf: torch.Tensor
2055
bary_coords: torch.Tensor
2156
dists: torch.Tensor
2257

58+
def detach(self) -> "Fragments":
59+
return Fragments(
60+
pix_to_face=self.pix_to_face,
61+
zbuf=self.zbuf.detach(),
62+
bary_coords=self.bary_coords.detach(),
63+
dists=self.dists.detach(),
64+
)
65+
2366

2467
@dataclass
2568
class RasterizationSettings:

0 commit comments

Comments
 (0)