-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
Copy pathtextures.py
1937 lines (1724 loc) · 78.7 KB
/
textures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import itertools
import warnings
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
import torch.nn.functional as F
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
from torch.nn.functional import interpolate
from .utils import pack_unique_rectangles, PackedRectangle, Rectangle
# This file contains classes and helper functions for texturing.
# There are three types of textures: TexturesVertex, TexturesAtlas
# and TexturesUV which inherit from a base textures class TexturesBase.
#
# Each texture class has a method 'sample_textures' to sample a
# value given barycentric coordinates.
#
# All the textures accept either list or padded inputs. The values
# are stored as either per face values (TexturesAtlas, TexturesUV),
# or per vertex features (TexturesVertex).
def _list_to_padded_wrapper(
x: List[torch.Tensor],
pad_size: Union[list, tuple, None] = None,
pad_value: float = 0.0,
) -> torch.Tensor:
r"""
This is a wrapper function for
pytorch3d.structures.utils.list_to_padded function which only accepts
3-dimensional inputs.
For this use case, the input x is of shape (F, 3, ...) where only F
is different for each element in the list
Transforms a list of N tensors each of shape (Mi, ...) into a single tensor
of shape (N, pad_size, ...), or (N, max(Mi), ...)
if pad_size is None.
Args:
x: list of Tensors
pad_size: int specifying the size of the first dimension
of the padded tensor
pad_value: float value to be used to fill the padded tensor
Returns:
x_padded: tensor consisting of padded input tensors
"""
N = len(x)
dims = x[0].ndim
reshape_dims = x[0].shape[1:]
D = torch.prod(torch.tensor(reshape_dims)).item()
x_reshaped = []
for y in x:
if y.ndim != dims and y.shape[1:] != reshape_dims:
msg = (
"list_to_padded requires tensors to have the same number of dimensions"
)
raise ValueError(msg)
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool, float, int]`.
x_reshaped.append(y.reshape(-1, D))
x_padded = list_to_padded(x_reshaped, pad_size=pad_size, pad_value=pad_value)
# pyre-fixme[58]: `+` is not supported for operand types `Tuple[int, int]` and
# `Size`.
return x_padded.reshape((N, -1) + reshape_dims)
def _padded_to_list_wrapper(
    x: torch.Tensor, split_size: Union[list, tuple, None] = None
) -> List[torch.Tensor]:
    r"""
    Split a padded tensor into a list of per-element tensors.

    This wraps pytorch3d.structures.utils.padded_to_list, which does not
    accept arbitrary-rank inputs: the trailing dims of x are flattened to a
    single dimension D, the tensor is split, and each piece is reshaped back.
    In this file x is e.g. (N, F, ...) per-face data where the number of
    faces differs between batch elements.

    Args:
        x: padded tensor of shape (N, M, *rest). A 2-D (N, M) tensor is also
            accepted.
        split_size: list of ints giving the number of leading items to keep
            for each of the N elements; forwarded to padded_to_list. When
            None, each element keeps all M items.

    Returns:
        x_list: a list of N tensors of shape (Mi, *rest), where Mi is
            split_size[i], or M if split_size is None.
    """
    N, M = x.shape[:2]
    reshape_dims = x.shape[2:]
    # torch.Size.numel() returns an int product of the trailing dims (1 when
    # x is 2-D). torch.prod of an empty tensor returns the float 1.0, which
    # reshape() rejects.
    D = reshape_dims.numel()
    x_reshaped = x.reshape(N, M, D)
    x_list = padded_to_list(x_reshaped, split_size=split_size)
    return [xl.reshape((xl.shape[0], *reshape_dims)) for xl in x_list]
def _pad_texture_maps(
images: Union[Tuple[torch.Tensor], List[torch.Tensor]], align_corners: bool
) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.
Args:
images: list of N tensors of shape (H_i, W_i, C)
align_corners: used for interpolation
Returns:
tex_maps: Tensor of shape (N, max_H, max_W, C)
"""
tex_maps = []
max_H = 0
max_W = 0
for im in images:
h, w, _C = im.shape
if h > max_H:
max_H = h
if w > max_W:
max_W = w
tex_maps.append(im)
max_shape = (max_H, max_W)
for i, image in enumerate(tex_maps):
if image.shape[:2] != max_shape:
image_BCHW = image.permute(2, 0, 1)[None]
new_image_BCHW = interpolate(
image_BCHW,
size=max_shape,
mode="bilinear",
align_corners=align_corners,
)
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, C)
return tex_maps
def _pad_texture_multiple_maps(
multiple_texture_maps: Union[Tuple[torch.Tensor], List[torch.Tensor]],
align_corners: bool,
) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.
Args:
images: list of N tensors of shape (M_i, H_i, W_i, C)
M_i : Number of texture maps:w
align_corners: used for interpolation
Returns:
tex_maps: Tensor of shape (N, max_M, max_H, max_W, C)
"""
tex_maps = []
max_M = 0
max_H = 0
max_W = 0
C = 0
for im in multiple_texture_maps:
m, h, w, C = im.shape
if m > max_M:
max_M = m
if h > max_H:
max_H = h
if w > max_W:
max_W = w
tex_maps.append(im)
max_shape = (max_M, max_H, max_W, C)
max_im_shape = (max_H, max_W)
for i, tms in enumerate(tex_maps):
new_tex_maps = torch.zeros(max_shape)
for j in range(tms.shape[0]):
im = tms[j]
if im.shape[:2] != max_im_shape:
image_BCHW = im.permute(2, 0, 1)[None]
new_image_BCHW = interpolate(
image_BCHW,
size=max_im_shape,
mode="bilinear",
align_corners=align_corners,
)
new_tex_maps[j] = new_image_BCHW[0].permute(1, 2, 0)
else:
new_tex_maps[j] = im
tex_maps[i] = new_tex_maps
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, C)
return tex_maps
# A base class for defining a batch of textures
# with helper methods.
# This is also useful to have so that inside `Meshes`
# we can allow the input textures to be any texture
# type which is an instance of the base class.
class TexturesBase:
    """
    Common base for batched texture representations.

    Provides helpers shared by the concrete texture classes (moving tensors
    between devices, repeating batch elements, indexing into the batch) and
    declares the methods each subclass is expected to implement.
    """

    def isempty(self):
        """
        Report whether the batch holds no usable textures.

        Returns False if `_N` or `valid` is None. Otherwise returns
        `_N == 0 or valid.eq(False).all()`; the second operand is a bool
        tensor.
        """
        if self._N is None or self.valid is None:
            return False
        return self._N == 0 or self.valid.eq(False).all()

    def to(self, device):
        """
        Move tensor attributes to `device`, record it as `self.device` and
        return self.

        Lists/tuples made up only of tensors (including empty ones) are
        rebuilt as lists of moved tensors. Standalone tensors are replaced
        only when their device differs from `device`.
        """
        for attr_name in dir(self):
            attr_value = getattr(self, attr_name)
            holds_only_tensors = isinstance(attr_value, (list, tuple)) and all(
                torch.is_tensor(item) for item in attr_value
            )
            if holds_only_tensors:
                setattr(self, attr_name, [item.to(device) for item in attr_value])
            elif torch.is_tensor(attr_value) and attr_value.device != device:
                setattr(self, attr_name, attr_value.to(device))
        self.device = device
        return self

    def _extend(self, N: int, props: List[str]) -> Dict[str, Union[torch.Tensor, List]]:
        """
        Repeat each batch element of the named properties N times.

        Args:
            N: number of consecutive copies of every batch element.
            props: names of attributes or zero-argument methods (which are
                called) returning a tensor, a list of scalars, or None.

        Returns:
            Dict keyed by the names in `props`. Tensors are repeated along
            dim 0 with repeat_interleave; lists have each item repeated N
            times in place; None stays None.

        Raises:
            ValueError: if N is not a positive int, if a list holds
                non-scalars, or if a property has any other type.
        """
        if not isinstance(N, int):
            raise ValueError("N must be an integer.")
        if N <= 0:
            raise ValueError("N must be > 0.")

        extended = {}
        for prop_name in props:
            value = getattr(self, prop_name)
            if callable(value):
                value = value()  # bound method: use what it returns
            if value is None:
                extended[prop_name] = None
            elif isinstance(value, list):
                if not all(isinstance(item, (int, float)) for item in value):
                    raise ValueError("Extend only supports lists of scalars")
                extended[prop_name] = [item for item in value for _ in range(N)]
            elif torch.is_tensor(value):
                extended[prop_name] = value.repeat_interleave(N, dim=0)
            else:
                raise ValueError(
                    f"Property {prop_name} has unsupported type {type(value)}."
                    "Only tensors and lists are supported."
                )
        return extended

    def _getitem(self, index: Union[int, slice], props: List[str]):
        """
        Select batch elements of the named properties; helper for
        __getitem__.

        An int or slice is applied directly to each property. A list or a
        tensor (bool masks are converted to positions) produces a Python list
        of the selected items. Properties that are None stay None, and
        properties naming a method are evaluated first. Any other index type
        yields an empty dict.
        """

        def resolve(prop_name):
            value = getattr(self, prop_name)
            return value() if callable(value) else value

        selected = {}
        if isinstance(index, (int, slice)):
            for prop_name in props:
                value = resolve(prop_name)
                selected[prop_name] = None if value is None else value[index]
            return selected

        if isinstance(index, list):
            index = torch.tensor(index)
        if isinstance(index, torch.Tensor):
            if index.dtype == torch.bool:
                index = index.nonzero()
                if index.numel() > 0:
                    index = index.squeeze(1)
            positions = index.tolist()
            for prop_name in props:
                value = resolve(prop_name)
                selected[prop_name] = (
                    None if value is None else [value[i] for i in positions]
                )
        return selected

    def sample_textures(self) -> torch.Tensor:
        """
        Sample texture values for rasterized pixels.

        Subclasses override this to take the rasterizer `fragments` and use
        `fragments.pix_to_face` and `fragments.bary_coords` to return one
        texture value per pixel. For example, vertex textures interpolate
        per-vertex values across the face with the barycentric coordinates.
        The base implementation only raises.
        """
        raise NotImplementedError()

    def submeshes(
        self,
        vertex_ids_list: List[List[torch.LongTensor]],
        faces_ids_list: List[List[torch.LongTensor]],
    ) -> "TexturesBase":
        """
        Build the sub-textures matching a submeshing operation; subclasses
        that support submeshes override this.
        """
        raise NotImplementedError(f"{self.__class__} does not support submeshes")

    def faces_verts_textures_packed(self) -> torch.Tensor:
        """
        Return the texture of each vertex of each face as a packed tensor of
        shape sum(Fi) x 3 x C (Fi = number of faces of mesh i, C = feature
        size, e.g. 3 for RGB). The helpers in structures.utils convert the
        packed form to list or padded form. Implemented by subclasses.
        """
        raise NotImplementedError()

    def clone(self) -> "TexturesBase":
        """
        Return a copy with all internal tensors cloned; implemented by
        subclasses.
        """
        raise NotImplementedError()

    def detach(self) -> "TexturesBase":
        """
        Return a copy with all internal tensors detached; implemented by
        subclasses.
        """
        raise NotImplementedError()

    def __getitem__(self, index) -> "TexturesBase":
        """
        Return the textures of the selected batch elements; implemented by
        subclasses. A subclass typically gathers its properties with
        `self._getitem(index, props)` and builds a new instance from them.
        """
        raise NotImplementedError()
def Textures(
    maps: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
    faces_uvs: Optional[torch.Tensor] = None,
    verts_uvs: Optional[torch.Tensor] = None,
    verts_rgb: Optional[torch.Tensor] = None,
) -> TexturesBase:
    """
    Deprecated factory kept for backwards compatibility with the former
    Textures class. Always emits a PendingDeprecationWarning.

    Args:
        maps: texture map per mesh, either a list of maps [(H, W, C)] or a
            padded tensor of shape (N, H, W, C).
        faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
            vertex in the face. Padding value is assumed to be -1.
        verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
        verts_rgb: (N, V, C) tensor giving the color per vertex. Padding
            value is assumed to be -1. (C=3 for RGB.)

    Returns:
        A TexturesUV when maps, faces_uvs and verts_uvs are all given (this
        takes precedence), otherwise a TexturesVertex built from verts_rgb.

    Raises:
        ValueError: if neither the full uv triple nor verts_rgb is given.
    """
    warnings.warn(
        "Textures class is deprecated,\n"
        "use TexturesUV, TexturesAtlas, TexturesVertex instead.\n"
        "Textures class will be removed in future releases.",
        PendingDeprecationWarning,
    )
    uv_inputs_complete = all(
        arg is not None for arg in (faces_uvs, verts_uvs, maps)
    )
    if uv_inputs_complete:
        return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
    if verts_rgb is None:
        raise ValueError(
            "Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb"
        )
    return TexturesVertex(verts_features=verts_rgb)
class TexturesAtlas(TexturesBase):
    def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]) -> None:
        """
        A texture representation where each face has a square texture map.
        This is based on the implementation from SoftRasterizer [1].

        Args:
            atlas: (N, F, R, R, C) tensor giving the per face texture map, or
                a list of N tensors of shape (F_i, R, R, C) sharing the same R.
                The atlas can be created during obj loading with the
                pytorch3d.io.load_obj function - in the input arguments
                set `create_texture_atlas=True`. The atlas will be
                returned in aux.texture_atlas.

        The padded and list representations of the textures are cached once
        computed; the packed representation is computed on the fly and not
        cached.

        Raises:
            ValueError: if atlas is neither a 5-D tensor nor a list of 4-D
                tensors of shape (F_i, R, R, C) with a common R.

        [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
            3D Reasoning', ICCV 2019
            See also https://github.com/ShichenLiu/SoftRas/issues/21
        """
        if isinstance(atlas, (list, tuple)):
            correct_format = all(
                (
                    torch.is_tensor(elem)
                    and elem.ndim == 4
                    and elem.shape[1] == elem.shape[2]
                    and elem.shape[1] == atlas[0].shape[1]
                )
                for elem in atlas
            )
            if not correct_format:
                msg = (
                    "Expected atlas to be a list of tensors of shape (F, R, R, C) "
                    "with the same value of R."
                )
                raise ValueError(msg)
            self._atlas_list = atlas
            self._atlas_padded = None
            self.device = torch.device("cpu")

            # These values may be overridden when textures is
            # passed into the Meshes constructor. For more details
            # refer to the __init__ of Meshes.
            self._N = len(atlas)
            self._num_faces_per_mesh = [len(a) for a in atlas]

            if self._N > 0:
                self.device = atlas[0].device

        elif torch.is_tensor(atlas):
            if atlas.ndim != 5:
                msg = "Expected atlas to be of shape (N, F, R, R, C); got %r"
                raise ValueError(msg % repr(atlas.ndim))
            self._atlas_padded = atlas
            self._atlas_list = None
            self.device = atlas.device

            # These values may be overridden when textures is
            # passed into the Meshes constructor. For more details
            # refer to the __init__ of Meshes.
            self._N = len(atlas)
            max_F = atlas.shape[1]
            self._num_faces_per_mesh = [max_F] * self._N
        else:
            raise ValueError("Expected atlas to be a tensor or list")

        # The num_faces_per_mesh, N and valid
        # are reset inside the Meshes object when textures is
        # passed into the Meshes constructor. For more details
        # refer to the __init__ of Meshes.
        self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)

    def clone(self) -> "TexturesAtlas":
        """Return a copy with the atlas tensors, face counts and validity mask cloned."""
        tex = self.__class__(atlas=self.atlas_padded().clone())
        if self._atlas_list is not None:
            tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
        num_faces = (
            self._num_faces_per_mesh.clone()
            if torch.is_tensor(self._num_faces_per_mesh)
            else self._num_faces_per_mesh
        )
        tex.valid = self.valid.clone()
        tex._num_faces_per_mesh = num_faces
        return tex

    def detach(self) -> "TexturesAtlas":
        """Return a copy whose atlas tensors, face counts and validity mask are detached."""
        tex = self.__class__(atlas=self.atlas_padded().detach())
        if self._atlas_list is not None:
            tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
        num_faces = (
            self._num_faces_per_mesh.detach()
            if torch.is_tensor(self._num_faces_per_mesh)
            else self._num_faces_per_mesh
        )
        tex.valid = self.valid.detach()
        tex._num_faces_per_mesh = num_faces
        return tex

    def __getitem__(self, index) -> "TexturesAtlas":
        """
        Return a new TexturesAtlas holding the selected batch element(s).

        An int index yields a batch of size 1. Other index types (slice, list,
        tensor) yield a batch of the selected elements. In both cases
        `_num_faces_per_mesh` of the result is a list with one entry per
        batch element.
        """
        props = ["atlas_list", "_num_faces_per_mesh"]
        new_props = self._getitem(index, props=props)
        atlas = new_props["atlas_list"]
        num_faces = new_props["_num_faces_per_mesh"]
        if isinstance(atlas, list):
            # multiple batch elements
            new_tex = self.__class__(atlas=atlas)
        elif torch.is_tensor(atlas):
            # Single element: wrap both the atlas and its face count so the
            # result is a batch of size 1 whose face counts form a list
            # (join_batch relies on list.copy() / list concatenation).
            new_tex = self.__class__(atlas=[atlas])
            num_faces = [int(num_faces)]
        else:
            raise ValueError("Not all values are provided in the correct format")
        new_tex._num_faces_per_mesh = num_faces
        return new_tex

    def atlas_padded(self) -> torch.Tensor:
        """
        Return (and cache) the padded atlas of shape (N, max(F_i), R, R, C).

        For an empty batch this is a zero-sized (N, 0, 0, 0, 3) float32
        tensor.
        """
        if self._atlas_padded is None:
            if self.isempty():
                self._atlas_padded = torch.zeros(
                    (self._N, 0, 0, 0, 3), dtype=torch.float32, device=self.device
                )
            else:
                self._atlas_padded = _list_to_padded_wrapper(
                    self._atlas_list, pad_value=0.0
                )
        return self._atlas_padded

    def atlas_list(self) -> List[torch.Tensor]:
        """
        Return (and cache) the atlas as a list of N tensors (F_i, R, R, C).

        For an empty batch this is a list of N empty (0, 0, 0, 3) float32
        tensors.
        """
        if self._atlas_list is None:
            if self.isempty():
                # Build the empty per-element list directly. There is no
                # padded data to split, and _padded_to_list_wrapper needs a
                # tensor.
                self._atlas_list = [
                    torch.empty((0, 0, 0, 3), dtype=torch.float32, device=self.device)
                ] * self._N
            else:
                self._atlas_list = _padded_to_list_wrapper(
                    self._atlas_padded, split_size=self._num_faces_per_mesh
                )
        return self._atlas_list

    def atlas_packed(self) -> torch.Tensor:
        """
        Return the atlases of all meshes packed along the face dimension
        (first output of list_to_packed on atlas_list()). Not cached.

        For an empty batch a zero-sized (N, 0, 0, 3) float32 tensor is
        returned.
        """
        if self.isempty():
            return torch.zeros(
                (self._N, 0, 0, 3), dtype=torch.float32, device=self.device
            )
        atlas_list = self.atlas_list()
        return list_to_packed(atlas_list)[0]

    def extend(self, N: int) -> "TexturesAtlas":
        """Return a new TexturesAtlas with every batch element repeated N times."""
        new_props = self._extend(N, ["atlas_padded", "_num_faces_per_mesh"])
        new_tex = self.__class__(atlas=new_props["atlas_padded"])
        new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
        return new_tex

    # pyre-fixme[14]: `sample_textures` overrides method defined in `TexturesBase`
    #  inconsistently.
    def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
        """
        This is similar to a nearest neighbor sampling and involves a
        discretization step. The barycentric coordinates from
        rasterization are used to find the nearest grid cell in the texture
        atlas and the RGB is returned as the color.
        This means that this step is differentiable with respect to the RGB
        values of the texture atlas but not differentiable with respect to the
        barycentric coordinates.

        TODO: Add a different sampling mode which interpolates the barycentric
        coordinates to sample the texture and will be differentiable w.r.t
        the barycentric coordinates.

        Args:
            fragments:
                The outputs of rasterization. From this we use

                - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
                  of the faces (in the packed representation) which
                  overlap each pixel in the image. Negative values mark
                  pixels with no face.
                - bary_coords: FloatTensor of shape (N, H, W, K, 3) specifying
                  the barycentric coordinates of each pixel
                  relative to the faces (in the packed
                  representation) which overlap the pixel.

        Returns:
            texels: (N, H, W, K, C); zero where pix_to_face < 0.
        """
        N, H, W, K = fragments.pix_to_face.shape
        atlas_packed = self.atlas_packed()
        R = atlas_packed.shape[1]
        bary = fragments.bary_coords
        pix_to_face = fragments.pix_to_face

        bary_w01 = bary[..., :2]
        # pyre-fixme[16]: `bool` has no attribute `__getitem__`.
        mask = (pix_to_face < 0)[..., None]
        bary_w01 = torch.where(mask, torch.zeros_like(bary_w01), bary_w01)
        # If barycentric coordinates are > 1.0 (in the case of
        # blur_radius > 0.0), wxy might be > R. We need to clamp this
        # index to R-1 to index into the texture atlas.
        w_xy = (bary_w01 * R).to(torch.int64).clamp(max=R - 1)  # (N, H, W, K, 2)

        below_diag = (
            bary_w01.sum(dim=-1) * R - w_xy.float().sum(dim=-1)
        ) <= 1.0  # (N, H, W, K)
        w_x, w_y = w_xy.unbind(-1)
        # Cells above the diagonal are mirrored back into the lower triangle.
        w_x = torch.where(below_diag, w_x, (R - 1 - w_x))
        w_y = torch.where(below_diag, w_y, (R - 1 - w_y))

        texels = atlas_packed[pix_to_face, w_y, w_x]
        # Zero out pixels not covered by any face.
        texels = texels * (pix_to_face >= 0)[..., None].float()

        return texels

    def submeshes(
        self,
        vertex_ids_list: List[List[torch.LongTensor]],
        faces_ids_list: List[List[torch.LongTensor]],
    ) -> "TexturesAtlas":
        """
        Extract a sub-texture for use in a submesh.

        If the meshes batch corresponding to this TexturesAtlas contains
        `n = len(faces_ids_list)` meshes, then self.atlas_list()
        will be of length n. After submeshing, we obtain a batch of
        `k = sum(len(f) for f in faces_ids_list)` submeshes (see
        Meshes.submeshes). This function creates a corresponding TexturesAtlas
        object with `atlas_list` of length `k`, where each entry is the mesh's
        atlas indexed by that submesh's face ids.

        Raises:
            IndexError: if faces_ids_list and atlas_list() differ in length.
        """
        if len(faces_ids_list) != len(self.atlas_list()):
            raise IndexError(
                "faces_ids_list must be of " "the same length as atlas_list."
            )

        sub_features = []
        for atlas, faces_ids in zip(self.atlas_list(), faces_ids_list):
            for faces_ids_submesh in faces_ids:
                sub_features.append(atlas[faces_ids_submesh])

        return self.__class__(sub_features)

    def faces_verts_textures_packed(self) -> torch.Tensor:
        """
        Samples texture from each vertex for each face in the mesh.
        For N meshes with {Fi} number of faces, it returns a
        tensor of shape sum(Fi)x3xC (C = 3 for RGB).
        You can use the utils function in structures.utils to convert the
        packed representation to a list or padded.
        """
        atlas_packed = self.atlas_packed()
        # assume each face consists of (v0, v1, v2).
        # to sample from the atlas we only need the first two barycentric coordinates.
        # for details on how this texture sample works refer to the sample_textures function.
        t0 = atlas_packed[:, 0, -1]  # corresponding to v0 with bary = (1, 0)
        t1 = atlas_packed[:, -1, 0]  # corresponding to v1 with bary = (0, 1)
        t2 = atlas_packed[:, 0, 0]  # corresponding to v2 with bary = (0, 0)
        return torch.stack((t0, t1, t2), dim=1)

    def join_batch(self, textures: List["TexturesAtlas"]) -> "TexturesAtlas":
        """
        Join the list of textures given by `textures` to
        self to create a batch of textures. Return a new
        TexturesAtlas object with the combined textures.

        Args:
            textures: List of TexturesAtlas objects

        Returns:
            new_tex: TexturesAtlas object with the combined
            textures from self and the list `textures`.

        Raises:
            ValueError: if any element of `textures` is not a TexturesAtlas.
        """
        tex_types_same = all(isinstance(tex, TexturesAtlas) for tex in textures)
        if not tex_types_same:
            raise ValueError("All textures must be of type TexturesAtlas.")

        atlas_list = []
        atlas_list += self.atlas_list()
        num_faces_per_mesh = self._num_faces_per_mesh.copy()
        for tex in textures:
            atlas_list += tex.atlas_list()
            num_faces_per_mesh += tex._num_faces_per_mesh
        new_tex = self.__class__(atlas=atlas_list)
        new_tex._num_faces_per_mesh = num_faces_per_mesh
        return new_tex

    def join_scene(self) -> "TexturesAtlas":
        """
        Return a new TexturesAtlas amalgamating the batch: a single element
        whose atlas is the concatenation of all per-mesh atlases.
        """
        return self.__class__(atlas=[torch.cat(self.atlas_list())])

    def check_shapes(
        self, batch_size: int, max_num_verts: int, max_num_faces: int
    ) -> bool:
        """
        Check if the dimensions of the atlas match that of the mesh faces
        (the (N, F) leading dims of the padded atlas); max_num_verts is unused.
        """
        # (N, F) should be the same
        return self.atlas_padded().shape[0:2] == (batch_size, max_num_faces)
class TexturesUV(TexturesBase):
def __init__(
    self,
    maps: Union[torch.Tensor, List[torch.Tensor]],
    faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
    verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
    *,
    maps_ids: Optional[
        Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
    ] = None,
    padding_mode: str = "border",
    align_corners: bool = True,
    sampling_mode: str = "bilinear",
) -> None:
    """
    Textures are represented as per mesh texture map(s) and uv coordinates
    for each vertex in each face. By default each mesh has a single texture
    map. When `maps_ids` is given, each mesh has a set of M maps and
    `maps_ids` selects which map each face uses.

    Args:
        maps: Either (1) a texture map per mesh. This can either be a list of maps
            [(H, W, C)] or a padded tensor of shape (N, H, W, C).
            For RGB, C = 3. In this case maps_ids must be None.
            Or (2) a set of M texture maps per mesh. This can either be a list of sets
            [(M, H, W, C)] or a padded tensor of shape (N, M, H, W, C).
            For RGB, C = 3. In this case maps_ids must be provided to
            identify which is relevant to each face.
        faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
            for each face, or a list of N tensors of shape (F_i, 3).
        verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
            (a FloatTensor with values between 0 and 1), or a list of N
            tensors of shape (V_i, 2).
        maps_ids: Used if there are to be multiple maps per mesh.
            This can be either a list of map_ids [(F,)]
            or a long tensor of shape (N, F) giving the id of the texture map
            for each face. If maps_ids is present, the maps has an extra dimension M
            (so maps_padded is (N, M, H, W, C) and maps_list has elements of
            shape (M, H, W, C)).
            Specifically, the color
            of a vertex V is given by an average of
            maps_padded[i, maps_ids[i, f], u, v, :]
            over u and v integers adjacent to
            _verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] .
        padding_mode: padding mode for outside grid values
            ("zeros", "border" or "reflection").
        align_corners: If true, the extreme values 0 and 1 for verts_uvs
            indicate the centers of the edge pixels in the maps.
        sampling_mode: type of interpolation used to sample the texture.
            Corresponds to the mode parameter in PyTorch's
            grid_sample ("nearest" or "bilinear").

    Raises:
        ValueError: if an input has an unexpected type or shape, if the
            batch sizes of faces_uvs / verts_uvs / maps_ids / maps disagree,
            or if they are not all on the same device.

    The align_corners and padding_mode arguments correspond to the arguments
    of the `grid_sample` torch function. There is an informative illustration of
    the two align_corners options at
    https://discuss.pytorch.org/t/22663/9 .

    An example of how the indexing into the maps, with align_corners=True,
    works is as follows.
    If maps[i] has shape [1001, 101] and the value of verts_uvs[i][j]
    is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
    whose color is given by maps[i][700, 40]. padding_mode affects what
    happens if a value in verts_uvs is less than 0 or greater than 1.
    Note that increasing a value in verts_uvs[..., 0] increases an index
    in maps, whereas increasing a value in verts_uvs[..., 1] _decreases_
    an _earlier_ index in maps.

    If align_corners=False, an example would be as follows.
    If maps[i] has shape [1000, 100] and the value of verts_uvs[i][j]
    is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
    whose color is given by maps[i][700, 40].
    When align_corners=False, padding_mode even matters for values in
    verts_uvs slightly above 0 or slightly below 1. In this case, the
    padding_mode matters if the first value (width axis, 100 pixels) is
    outside the interval [0.005, 0.995] or if the second (height axis,
    1000 pixels) is outside the interval [0.0005, 0.9995].
    """
    self.padding_mode = padding_mode
    self.align_corners = align_corners
    self.sampling_mode = sampling_mode
    if isinstance(faces_uvs, (list, tuple)):
        for fv in faces_uvs:
            if fv.ndim != 2 or fv.shape[-1] != 3:
                msg = "Expected faces_uvs to be of shape (F, 3); got %r"
                raise ValueError(msg % repr(fv.shape))
        self._faces_uvs_list = faces_uvs
        self._faces_uvs_padded = None
        self.device = torch.device("cpu")

        # These values may be overridden when textures is
        # passed into the Meshes constructor. For more details
        # refer to the __init__ of Meshes.
        self._N = len(faces_uvs)
        self._num_faces_per_mesh = [len(fv) for fv in faces_uvs]

        if self._N > 0:
            self.device = faces_uvs[0].device

    elif torch.is_tensor(faces_uvs):
        if faces_uvs.ndim != 3 or faces_uvs.shape[-1] != 3:
            msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
            raise ValueError(msg % repr(faces_uvs.shape))
        self._faces_uvs_padded = faces_uvs
        self._faces_uvs_list = None
        self.device = faces_uvs.device

        # These values may be overridden when textures is
        # passed into the Meshes constructor. For more details
        # refer to the __init__ of Meshes.
        self._N = len(faces_uvs)
        max_F = faces_uvs.shape[1]
        self._num_faces_per_mesh = [max_F] * self._N
    else:
        raise ValueError("Expected faces_uvs to be a tensor or list")

    if isinstance(verts_uvs, (list, tuple)):
        for vu in verts_uvs:
            if vu.ndim != 2 or vu.shape[-1] != 2:
                msg = "Expected verts_uvs to be of shape (V, 2); got %r"
                raise ValueError(msg % repr(vu.shape))
        self._verts_uvs_list = verts_uvs
        self._verts_uvs_padded = None

        if len(verts_uvs) != self._N:
            raise ValueError(
                "verts_uvs and faces_uvs must have the same batch dimension"
            )
        if not all(v.device == self.device for v in verts_uvs):
            raise ValueError("verts_uvs and faces_uvs must be on the same device")

    elif torch.is_tensor(verts_uvs):
        if (
            verts_uvs.ndim != 3
            or verts_uvs.shape[-1] != 2
            or verts_uvs.shape[0] != self._N
        ):
            msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
            raise ValueError(msg % repr(verts_uvs.shape))
        self._verts_uvs_padded = verts_uvs
        self._verts_uvs_list = None

        if verts_uvs.device != self.device:
            raise ValueError("verts_uvs and faces_uvs must be on the same device")
    else:
        raise ValueError("Expected verts_uvs to be a tensor or list")

    # maps_ids must be processed before maps: whether it is None decides the
    # expected rank of maps in _format_maps_padded.
    self._maps_ids_padded, self._maps_ids_list = self._format_maps_ids(maps_ids)

    if isinstance(maps, (list, tuple)):
        self._maps_list = maps
    else:
        self._maps_list = None
    self._maps_padded = self._format_maps_padded(maps)

    if self._maps_padded.device != self.device:
        raise ValueError("maps must be on the same device as verts/faces uvs.")

    self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def _format_maps_ids(
    self,
    maps_ids: Optional[
        Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
    ],
) -> Tuple[
    Optional[torch.Tensor], Optional[Union[List[torch.Tensor], Tuple[torch.Tensor]]]
]:
    """
    Validate `maps_ids` against the batch described by faces_uvs
    (self._N, self._num_faces_per_mesh, self.device) and return its padded
    and list forms.

    Args:
        maps_ids: None, a list/tuple of N tensors of shape (F_i,), or a
            tensor of shape (N, max(F_i)).

    Returns:
        (maps_ids_padded, maps_ids_list):
        - (None, None) when maps_ids is None;
        - for a list: the list padded with 0 via list_to_padded, and the
          list itself. An empty batch gives an (N, 0) int64 tensor so that
          the padded form stays 2-D and can be passed back to the
          constructor (e.g. by clone()).
        - for a tensor: the tensor itself and None.

    Raises:
        ValueError: on a wrong type, rank, batch size, per-mesh face count
            or device.
    """
    if maps_ids is None:
        return None, None

    if isinstance(maps_ids, (list, tuple)):
        for mid in maps_ids:
            if mid.ndim != 1:
                msg = "Expected maps_ids to be of shape (F,); got %r"
                raise ValueError(msg % repr(mid.shape))
        if len(maps_ids) != self._N:
            raise ValueError(
                "map_ids, faces_uvs and verts_uvs must have the same batch dimension"
            )
        if not all(mid.device == self.device for mid in maps_ids):
            raise ValueError(
                "maps_ids and verts/faces uvs must be on the same device"
            )
        if not all(
            mid.shape[0] == nfm
            for mid, nfm in zip(maps_ids, self._num_faces_per_mesh)
        ):
            raise ValueError(
                "map_ids and faces_uvs must have the same number of faces per mesh"
            )
        if not self._num_faces_per_mesh:
            # Empty batch: nothing to pad. Return a non-None 2-D tensor so
            # _format_maps_padded (which tests `_maps_ids_padded is None`)
            # still treats this as the multi-map case.
            return (
                torch.zeros((self._N, 0), dtype=torch.int64, device=self.device),
                maps_ids,
            )
        return list_to_padded(maps_ids, pad_value=0), maps_ids

    if isinstance(maps_ids, torch.Tensor):
        if maps_ids.ndim != 2 or maps_ids.shape[0] != self._N:
            msg = "Expected maps_ids to be of shape (N, F); got %r"
            raise ValueError(msg % repr(maps_ids.shape))
        # For an empty batch there are no face counts to compare against,
        # and max() of an empty list would raise an unrelated error.
        if self._num_faces_per_mesh and maps_ids.shape[1] != max(
            self._num_faces_per_mesh
        ):
            raise ValueError(
                "map_ids and faces_uvs must have the same number of faces per mesh"
            )
        if maps_ids.device != self.device:
            raise ValueError(
                "maps_ids and verts/faces uvs must be on the same device"
            )
        return maps_ids, None

    raise ValueError("Expected maps_ids to be a tensor or list")
def _format_maps_padded(
    self, maps: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
    """
    Validate `maps` and return it as a single padded tensor.

    The expected rank depends on whether per-face map ids were supplied
    (`self._maps_ids_padded` is not None). With ids, maps are sets of maps:
    (N, M, H, W, C) padded or [(M, H, W, C)] list. Without ids, there is one
    map per mesh: (N, H, W, C) padded or [(H, W, C)] list.

    A padded tensor is returned unchanged after its shape is checked. A list
    is resized/padded to a common size. An empty batch (N == 0) yields an
    empty float32 tensor with C = 3 on self.device.

    Raises:
        ValueError: on a wrong type, rank or batch size, or inconsistent
            channel counts among listed maps.
    """
    maps_ids_none = self._maps_ids_padded is None
    if isinstance(maps, torch.Tensor):
        if not maps_ids_none:
            if maps.ndim != 5 or maps.shape[0] != self._N:
                msg = "Expected maps to be of shape (N, M, H, W, C); got %r"
                raise ValueError(msg % repr(maps.shape))
        elif maps.ndim != 4 or maps.shape[0] != self._N:
            msg = "Expected maps to be of shape (N, H, W, C); got %r"
            raise ValueError(msg % repr(maps.shape))
        return maps

    if isinstance(maps, (list, tuple)):
        if len(maps) != self._N:
            raise ValueError("Expected one texture map per mesh in the batch.")
        if self._N > 0:
            ndim = 3 if maps_ids_none else 4
            # `tex_map` rather than `map` to avoid shadowing the builtin.
            if not all(tex_map.ndim == ndim for tex_map in maps):
                raise ValueError("Invalid number of dimensions in texture maps")
            num_channels = maps[0].shape[-1]
            if not all(tex_map.shape[-1] == num_channels for tex_map in maps):
                raise ValueError("Inconsistent number of channels in maps")
            if maps_ids_none:
                return _pad_texture_maps(maps, align_corners=self.align_corners)
            return _pad_texture_multiple_maps(maps, align_corners=self.align_corners)
        empty_shape = (self._N, 0, 0, 3) if maps_ids_none else (self._N, 0, 0, 0, 3)
        return torch.empty(empty_shape, dtype=torch.float32, device=self.device)

    raise ValueError("Expected maps to be a tensor or list of tensors.")
def clone(self) -> "TexturesUV":
    """
    Return a copy of this TexturesUV in which every stored tensor is cloned:
    the padded maps/uvs/maps_ids passed to the constructor, any cached list
    representations, a tensor-valued face count, and the validity mask.
    The padding, alignment and sampling options are carried over unchanged.
    """
    maps_copy = self.maps_padded().clone()
    faces_uvs_copy = self.faces_uvs_padded().clone()
    verts_uvs_copy = self.verts_uvs_padded().clone()
    maps_ids_copy = self._maps_ids_padded
    if maps_ids_copy is not None:
        maps_ids_copy = maps_ids_copy.clone()

    tex = self.__class__(
        maps_copy,
        faces_uvs_copy,
        verts_uvs_copy,
        maps_ids=maps_ids_copy,
        align_corners=self.align_corners,
        padding_mode=self.padding_mode,
        sampling_mode=self.sampling_mode,
    )

    # Carry over list representations only when they already exist here.
    for cached_name in (
        "_maps_list",
        "_verts_uvs_list",
        "_faces_uvs_list",
        "_maps_ids_list",
    ):
        cached = getattr(self, cached_name)
        if cached is not None:
            setattr(tex, cached_name, [item.clone() for item in cached])

    faces_per_mesh = self._num_faces_per_mesh
    if torch.is_tensor(faces_per_mesh):
        faces_per_mesh = faces_per_mesh.clone()
    tex._num_faces_per_mesh = faces_per_mesh
    tex.valid = self.valid.clone()
    return tex
def detach(self) -> "TexturesUV":
tex = self.__class__(
self.maps_padded().detach(),
self.faces_uvs_padded().detach(),
self.verts_uvs_padded().detach(),
maps_ids=(
self._maps_ids_padded.detach()
if self._maps_ids_padded is not None
else None
),
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
)
if self._maps_list is not None:
tex._maps_list = [m.detach() for m in self._maps_list]
if self._verts_uvs_list is not None:
tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list]