Skip to content

Commit 2bbca5f

Browse files
bottlerfacebook-github-bot
authored andcommitted
Allow setting verts_normals on Meshes
Summary: Add ability to set the vertex normals when creating a Meshes, so that the pluggable loaders can return them from a file. Reviewed By: nikhilaravi Differential Revision: D27765258 fbshipit-source-id: b5ddaa00de3707f636f94d9f74d1da12ecce0608
1 parent 502f15a commit 2bbca5f

File tree

3 files changed

+68
-5
lines changed

3 files changed

+68
-5
lines changed

pytorch3d/structures/meshes.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,14 @@ class Meshes(object):
207207
"equisized",
208208
]
209209

210-
def __init__(self, verts=None, faces=None, textures=None):
210+
def __init__(
211+
self,
212+
verts=None,
213+
faces=None,
214+
textures=None,
215+
*,
216+
verts_normals=None,
217+
):
211218
"""
212219
Args:
213220
verts:
@@ -229,6 +236,17 @@ def __init__(self, verts=None, faces=None, textures=None):
229236
the same number of faces.
230237
textures: Optional instance of the Textures class with mesh
231238
texture properties.
239+
verts_normals:
240+
Optional. Can be either
241+
242+
- List where each element is a tensor of shape (num_verts, 3)
243+
containing the normals of each vertex.
244+
- Padded float tensor with shape (num_meshes, max_num_verts, 3).
245+
They should be padded with fill value of 0 so they all have
246+
the same number of vertices.
247+
Note that modifying the mesh later, e.g. with offset_verts_,
248+
can cause these normals to be forgotten and normals to be recalculated
249+
based on the new vertex positions.
232250
233251
Refer to comments above for descriptions of List and Padded representations.
234252
"""
@@ -354,8 +372,8 @@ def __init__(self, verts=None, faces=None, textures=None):
354372
self.equisized = True
355373

356374
elif torch.is_tensor(verts) and torch.is_tensor(faces):
357-
if verts.size(2) != 3 and faces.size(2) != 3:
358-
raise ValueError("Verts and Faces tensors have incorrect dimensions.")
375+
if verts.size(2) != 3 or faces.size(2) != 3:
376+
raise ValueError("Verts or Faces tensors have incorrect dimensions.")
359377
self._verts_padded = verts
360378
self._faces_padded = faces.to(torch.int64)
361379
self._N = self._verts_padded.shape[0]
@@ -412,6 +430,36 @@ def __init__(self, verts=None, faces=None, textures=None):
412430
self.textures._N = self._N
413431
self.textures.valid = self.valid
414432

433+
if verts_normals is not None:
434+
self._set_verts_normals(verts_normals)
435+
436+
def _set_verts_normals(self, verts_normals) -> None:
437+
if isinstance(verts_normals, list):
438+
if len(verts_normals) != self._N:
439+
raise ValueError("Invalid verts_normals input")
440+
441+
for item, n_verts in zip(verts_normals, self._num_verts_per_mesh):
442+
if (
443+
not isinstance(item, torch.Tensor)
444+
or item.ndim != 2
445+
or item.shape[1] != 3
446+
or item.shape[0] != n_verts
447+
):
448+
raise ValueError("Invalid verts_normals input")
449+
self._verts_normals_packed = torch.cat(verts_normals, 0)
450+
elif torch.is_tensor(verts_normals):
451+
if (
452+
verts_normals.ndim != 3
453+
or verts_normals.size(2) != 3
454+
or verts_normals.size(0) != self._N
455+
):
456+
raise ValueError("Vertex normals tensor has incorrect dimensions.")
457+
self._verts_normals_packed = struct_utils.padded_to_packed(
458+
verts_normals, split_size=self._num_verts_per_mesh.tolist()
459+
)
460+
else:
461+
raise ValueError("verts_normals must be a list or tensor")
462+
415463
def __len__(self):
416464
return self._N
417465

@@ -1253,6 +1301,7 @@ def split(self, split_sizes: list):
12531301
def offset_verts_(self, vert_offsets_packed):
12541302
"""
12551303
Add an offset to the vertices of this Meshes. In place operation.
1304+
If normals are present they may be recalculated.
12561305
12571306
Args:
12581307
vert_offsets_packed: A Tensor of shape (3,) or the same shape as
@@ -1286,7 +1335,7 @@ def offset_verts_(self, vert_offsets_packed):
12861335
self._verts_padded[i, : verts.shape[0], :] = verts
12871336

12881337
# update face areas and normals and vertex normals
1289-
# only if the original attributes are computed
1338+
# only if the original attributes are present
12901339
if update_normals and any(
12911340
v is not None
12921341
for v in [self._faces_areas_packed, self._faces_normals_packed]

tests/test_mesh_normal_consistency.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_no_intersection(self):
223223
Test Mesh Normal Consistency for a mesh known to have no
224224
intersecting faces.
225225
"""
226-
verts = torch.rand(1, 6, 2)
226+
verts = torch.rand(1, 6, 3)
227227
faces = torch.arange(6).reshape(1, 2, 3)
228228
meshes = Meshes(verts=verts, faces=faces)
229229
out = mesh_normal_consistency(meshes)

tests/test_meshes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,20 @@ def test_compute_normals(self):
11381138
self.assertEqual(meshes.faces_normals_padded().shape[0], 0)
11391139
self.assertEqual(meshes.faces_normals_list(), [])
11401140

1141+
def test_assigned_normals(self):
1142+
verts = torch.rand(2, 6, 3)
1143+
faces = torch.randint(6, size=(2, 4, 3))
1144+
1145+
for verts_normals in [list(verts.unbind(0)), verts]:
1146+
yes_normals = Meshes(
1147+
verts=verts.clone(), faces=faces, verts_normals=verts_normals
1148+
)
1149+
self.assertClose(yes_normals.verts_normals_padded(), verts)
1150+
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]))
1151+
self.assertClose(yes_normals.verts_normals_padded(), verts)
1152+
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]).expand(12, 3))
1153+
self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts))
1154+
11411155
def test_compute_faces_areas_cpu_cuda(self):
11421156
num_meshes = 10
11431157
max_v = 100

0 commit comments

Comments
 (0)