|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 | 7 | from dataclasses import dataclass
|
8 |
| -from typing import NamedTuple, Optional, Tuple, Union |
| 8 | +from typing import Optional, Tuple, Union |
9 | 9 |
|
10 | 10 | import torch
|
11 | 11 | import torch.nn as nn
|
12 | 12 |
|
13 | 13 | from .rasterize_meshes import rasterize_meshes
|
14 | 14 |
|
15 | 15 |
|
16 |
| -# Class to store the outputs of mesh rasterization |
17 |
| -class Fragments(NamedTuple): |
| 16 | +@dataclass(frozen=True) |
| 17 | +class Fragments: |
| 18 | + """ |
| 19 | + Members: |
| 20 | + pix_to_face: |
| 21 | + LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving |
| 22 | + the indices of the nearest faces at each pixel, sorted in ascending |
| 23 | + z-order. Concretely ``pix_to_face[n, y, x, k] = f`` means that |
| 24 | + ``faces_verts[f]`` is the kth closest face (in the z-direction) to pixel |
| 25 | + (y, x). Pixels that are hit by fewer than faces_per_pixel are padded with |
| 26 | + -1. |
| 27 | +
|
| 28 | + zbuf: |
| 29 | + FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving |
| 30 | + the NDC z-coordinates of the nearest faces at each pixel, sorted in |
| 31 | + ascending z-order. Concretely, if ``pix_to_face[n, y, x, k] = f`` then |
| 32 | + ``zbuf[n, y, x, k] = face_verts[f, 2]``. Pixels hit by fewer than |
| 33 | + faces_per_pixel are padded with -1. |
| 34 | +
|
| 35 | + bary_coords: |
| 36 | + FloatTensor of shape (N, image_size, image_size, faces_per_pixel, 3) |
| 37 | + giving the barycentric coordinates in NDC units of the nearest faces at |
| 38 | + each pixel, sorted in ascending z-order. Concretely, if ``pix_to_face[n, |
| 39 | + y, x, k] = f`` then ``[w0, w1, w2] = barycentric[n, y, x, k]`` gives the |
| 40 | + barycentric coords for pixel (y, x) relative to the face defined by |
| 41 | + ``face_verts[f]``. Pixels hit by fewer than faces_per_pixel are padded |
| 42 | + with -1. |
| 43 | +
|
| 44 | + dists: |
| 45 | + FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving |
| 46 | + the signed Euclidean distance (in NDC units) in the x/y plane of each |
| 47 | + point closest to the pixel. Concretely if ``pix_to_face[n, y, x, k] = f`` |
| 48 | + then ``pix_dists[n, y, x, k]`` is the squared distance between the pixel |
| 49 | + (y, x) and the face given by vertices ``face_verts[f]``. Pixels hit with |
| 50 | + fewer than ``faces_per_pixel`` are padded with -1. |
| 51 | + """ |
| 52 | + |
18 | 53 | pix_to_face: torch.Tensor
|
19 | 54 | zbuf: torch.Tensor
|
20 | 55 | bary_coords: torch.Tensor
|
21 | 56 | dists: torch.Tensor
|
22 | 57 |
|
| 58 | + def detach(self) -> "Fragments": |
| 59 | + return Fragments( |
| 60 | + pix_to_face=self.pix_to_face, |
| 61 | + zbuf=self.zbuf.detach(), |
| 62 | + bary_coords=self.bary_coords.detach(), |
| 63 | + dists=self.dists.detach(), |
| 64 | + ) |
| 65 | + |
23 | 66 |
|
24 | 67 | @dataclass
|
25 | 68 | class RasterizationSettings:
|
|
0 commit comments