4
4
import warnings
5
5
from os import path
6
6
from pathlib import Path
7
- from typing import Dict
7
+ from typing import Dict , List , Optional
8
8
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
9
12
from pytorch3d .datasets .shapenet_base import ShapeNetBase
10
13
from pytorch3d .io import load_obj
11
14
@@ -21,18 +24,29 @@ class R2N2(ShapeNetBase):
21
24
voxelized models.
22
25
"""
23
26
24
- def __init__ (self , split , shapenet_dir , r2n2_dir , splits_file ):
27
+ def __init__ (
28
+ self ,
29
+ split : str ,
30
+ shapenet_dir ,
31
+ r2n2_dir ,
32
+ splits_file ,
33
+ return_all_views : bool = True ,
34
+ ):
25
35
"""
26
36
Store each object's synset id and models id the given directories.
37
+
27
38
Args:
28
39
split (str): One of (train, val, test).
29
40
shapenet_dir (path): Path to ShapeNet core v1.
30
41
r2n2_dir (path): Path to the R2N2 dataset.
31
42
splits_file (path): File containing the train/val/test splits.
43
+ return_all_views (bool): Indicator of whether or not to return all 24 views. If set
44
+ to False, one of the 24 views would be randomly selected and returned.
32
45
"""
33
46
super ().__init__ ()
34
47
self .shapenet_dir = shapenet_dir
35
48
self .r2n2_dir = r2n2_dir
49
+ self .return_all_views = return_all_views
36
50
# Examine if split is valid.
37
51
if split not in ["train" , "val" , "test" ]:
38
52
raise ValueError ("split has to be one of (train, val, test)." )
@@ -48,6 +62,16 @@ def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
48
62
with open (splits_file ) as splits :
49
63
split_dict = json .load (splits )[split ]
50
64
65
+ self .return_images = True
66
+ # Check if the folder containing R2N2 renderings is included in r2n2_dir.
67
+ if not path .isdir (path .join (r2n2_dir , "ShapeNetRendering" )):
68
+ self .return_images = False
69
+ msg = (
70
+ "ShapeNetRendering not found in %s. R2N2 renderings will "
71
+ "be skipped when returning models."
72
+ ) % (r2n2_dir )
73
+ warnings .warn (msg )
74
+
51
75
synset_set = set ()
52
76
for synset in split_dict .keys ():
53
77
# Examine if the given synset is present in the ShapeNetCore dataset
@@ -95,12 +119,15 @@ def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
95
119
) % (shapenet_dir , ", " .join (synset_not_present ))
96
120
warnings .warn (msg )
97
121
98
- def __getitem__ (self , idx : int ) -> Dict :
122
+ def __getitem__ (self , model_idx , view_idxs : Optional [ List [ int ]] = None ) -> Dict :
99
123
"""
100
124
Read a model by the given index.
101
125
102
126
Args:
103
- idx: The idx of the model to be retrieved in the dataset.
127
+ model_idx: The idx of the model to be retrieved in the dataset.
128
+ view_idx: List of indices of the view to be returned. Each index needs to be
129
+ between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be
130
+ ignored and views will be sampled according to self.return_all_views.
104
131
105
132
Returns:
106
133
dictionary with following keys:
@@ -109,12 +136,51 @@ def __getitem__(self, idx: int) -> Dict:
109
136
- synset_id (str): synset id.
110
137
- model_id (str): model id.
111
138
- label (str): synset label.
139
+ - images: FloatTensor of shape (V, H, W, C), where V is number of views
140
+ returned. Returns a batch of the renderings of the models from the R2N2 dataset.
112
141
"""
113
- model = self ._get_item_ids (idx )
142
+ if type (model_idx ) is tuple :
143
+ model_idx , view_idxs = model_idx
144
+ model = self ._get_item_ids (model_idx )
114
145
model_path = path .join (
115
146
self .shapenet_dir , model ["synset_id" ], model ["model_id" ], "model.obj"
116
147
)
117
148
model ["verts" ], faces , _ = load_obj (model_path )
118
149
model ["faces" ] = faces .verts_idx
119
150
model ["label" ] = self .synset_dict [model ["synset_id" ]]
151
+
152
+ model ["images" ] = None
153
+ # Retrieve R2N2's renderings if required.
154
+ if self .return_images :
155
+ ranges = (
156
+ range (24 ) if self .return_all_views else torch .randint (24 , (1 ,)).tolist ()
157
+ )
158
+ if view_idxs is not None and any (idx < 0 or idx > 23 for idx in view_idxs ):
159
+ msg = (
160
+ "One of the indicies in view_idxs is out of range. "
161
+ "Index needs to be between 0 and 23, inclusive. "
162
+ "Now sampling according to self.return_all_views."
163
+ )
164
+ warnings .warn (msg )
165
+ elif view_idxs is not None :
166
+ ranges = view_idxs
167
+
168
+ rendering_path = path .join (
169
+ self .r2n2_dir ,
170
+ "ShapeNetRendering" ,
171
+ model ["synset_id" ],
172
+ model ["model_id" ],
173
+ "rendering" ,
174
+ )
175
+
176
+ images = []
177
+ for i in ranges :
178
+ # Read image.
179
+ image_path = path .join (rendering_path , "%02d.png" % i )
180
+ raw_img = Image .open (image_path )
181
+ image = torch .from_numpy (np .array (raw_img ) / 255.0 )[..., :3 ]
182
+ images .append (image .to (dtype = torch .float32 ))
183
+
184
+ model ["images" ] = torch .stack (images )
185
+
120
186
return model
0 commit comments