Skip to content

Commit dc08c30

Browse files
megluyagaofacebook-github-bot
authored andcommitted
Return R2N2 renderings
Summary: R2N2 returns R2N2's own renderings of ShapeNetCore models. Reviewed By: nikhilaravi Differential Revision: D22266988 fbshipit-source-id: 36e67bd06c6459773e6e5f654259166b579be36a
1 parent 5636eb6 commit dc08c30

File tree

3 files changed

+91
-8
lines changed

3 files changed

+91
-8
lines changed

pytorch3d/datasets/r2n2/r2n2.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import warnings
55
from os import path
66
from pathlib import Path
7-
from typing import Dict
7+
from typing import Dict, List, Optional
88

9+
import numpy as np
10+
import torch
11+
from PIL import Image
912
from pytorch3d.datasets.shapenet_base import ShapeNetBase
1013
from pytorch3d.io import load_obj
1114

@@ -21,18 +24,29 @@ class R2N2(ShapeNetBase):
2124
voxelized models.
2225
"""
2326

24-
def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
27+
def __init__(
28+
self,
29+
split: str,
30+
shapenet_dir,
31+
r2n2_dir,
32+
splits_file,
33+
return_all_views: bool = True,
34+
):
2535
"""
2636
Store each object's synset id and models id the given directories.
37+
2738
Args:
2839
split (str): One of (train, val, test).
2940
shapenet_dir (path): Path to ShapeNet core v1.
3041
r2n2_dir (path): Path to the R2N2 dataset.
3142
splits_file (path): File containing the train/val/test splits.
43+
return_all_views (bool): Indicator of whether or not to return all 24 views. If set
44+
to False, one of the 24 views would be randomly selected and returned.
3245
"""
3346
super().__init__()
3447
self.shapenet_dir = shapenet_dir
3548
self.r2n2_dir = r2n2_dir
49+
self.return_all_views = return_all_views
3650
# Examine if split is valid.
3751
if split not in ["train", "val", "test"]:
3852
raise ValueError("split has to be one of (train, val, test).")
@@ -48,6 +62,16 @@ def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
4862
with open(splits_file) as splits:
4963
split_dict = json.load(splits)[split]
5064

65+
self.return_images = True
66+
# Check if the folder containing R2N2 renderings is included in r2n2_dir.
67+
if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")):
68+
self.return_images = False
69+
msg = (
70+
"ShapeNetRendering not found in %s. R2N2 renderings will "
71+
"be skipped when returning models."
72+
) % (r2n2_dir)
73+
warnings.warn(msg)
74+
5175
synset_set = set()
5276
for synset in split_dict.keys():
5377
# Examine if the given synset is present in the ShapeNetCore dataset
@@ -95,12 +119,15 @@ def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
95119
) % (shapenet_dir, ", ".join(synset_not_present))
96120
warnings.warn(msg)
97121

98-
def __getitem__(self, idx: int) -> Dict:
122+
def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
99123
"""
100124
Read a model by the given index.
101125
102126
Args:
103-
idx: The idx of the model to be retrieved in the dataset.
127+
model_idx: The idx of the model to be retrieved in the dataset.
128+
view_idx: List of indices of the view to be returned. Each index needs to be
129+
between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be
130+
ignored and views will be sampled according to self.return_all_views.
104131
105132
Returns:
106133
dictionary with following keys:
@@ -109,12 +136,51 @@ def __getitem__(self, idx: int) -> Dict:
109136
- synset_id (str): synset id.
110137
- model_id (str): model id.
111138
- label (str): synset label.
139+
- images: FloatTensor of shape (V, H, W, C), where V is number of views
140+
returned. Returns a batch of the renderings of the models from the R2N2 dataset.
112141
"""
113-
model = self._get_item_ids(idx)
142+
if type(model_idx) is tuple:
143+
model_idx, view_idxs = model_idx
144+
model = self._get_item_ids(model_idx)
114145
model_path = path.join(
115146
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
116147
)
117148
model["verts"], faces, _ = load_obj(model_path)
118149
model["faces"] = faces.verts_idx
119150
model["label"] = self.synset_dict[model["synset_id"]]
151+
152+
model["images"] = None
153+
# Retrieve R2N2's renderings if required.
154+
if self.return_images:
155+
ranges = (
156+
range(24) if self.return_all_views else torch.randint(24, (1,)).tolist()
157+
)
158+
if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs):
159+
msg = (
160+
"One of the indicies in view_idxs is out of range. "
161+
"Index needs to be between 0 and 23, inclusive. "
162+
"Now sampling according to self.return_all_views."
163+
)
164+
warnings.warn(msg)
165+
elif view_idxs is not None:
166+
ranges = view_idxs
167+
168+
rendering_path = path.join(
169+
self.r2n2_dir,
170+
"ShapeNetRendering",
171+
model["synset_id"],
172+
model["model_id"],
173+
"rendering",
174+
)
175+
176+
images = []
177+
for i in ranges:
178+
# Read image.
179+
image_path = path.join(rendering_path, "%02d.png" % i)
180+
raw_img = Image.open(image_path)
181+
image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
182+
images.append(image.to(dtype=torch.float32))
183+
184+
model["images"] = torch.stack(images)
185+
120186
return model

pytorch3d/datasets/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Dict, List
44

5+
import torch
56
from pytorch3d.structures import Meshes
67

78

@@ -32,4 +33,10 @@ def collate_batched_meshes(batch: List[Dict]):
3233
verts=collated_dict["verts"], faces=collated_dict["faces"]
3334
)
3435

36+
# If collate_batched_meshes receives R2N2 items, stack the batches of
37+
# views of each model into a new batch of shape (N, V, H, W, 3) where
38+
# V is the number of views.
39+
if "images" in collated_dict:
40+
collated_dict["images"] = torch.stack(collated_dict["images"])
41+
3542
return collated_dict

tests/test_r2n2.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def setUp(self):
5656

5757
def test_load_R2N2(self):
5858
"""
59-
Test the loaded train split of R2N2 return items of the correct shapes and types.
59+
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
60+
check the first image returned is correct.
6061
"""
6162
# Load dataset in the train split.
6263
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
@@ -68,8 +69,9 @@ def test_load_R2N2(self):
6869
self.assertEqual(len(r2n2_dataset), sum(model_nums))
6970

7071
# Randomly retrieve an object from the dataset.
71-
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
72-
# Check that data type and shape of the item returned by __getitem__ are correct.
72+
rand_idx = torch.randint(len(r2n2_dataset), (1,))
73+
rand_obj = r2n2_dataset[rand_idx]
74+
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
7375
verts, faces = rand_obj["verts"], rand_obj["faces"]
7476
self.assertTrue(verts.dtype == torch.float32)
7577
self.assertTrue(faces.dtype == torch.int64)
@@ -78,6 +80,13 @@ def test_load_R2N2(self):
7880
self.assertEqual(faces.ndim, 2)
7981
self.assertEqual(faces.shape[-1], 3)
8082

83+
# Check that image batch returned by __getitem__ has the correct shape.
84+
self.assertEqual(rand_obj["images"].shape[0], 24)
85+
self.assertEqual(rand_obj["images"].shape[1], 137)
86+
self.assertEqual(rand_obj["images"].shape[2], 137)
87+
self.assertEqual(rand_obj["images"].shape[-1], 3)
88+
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
89+
8190
def test_collate_models(self):
8291
"""
8392
Test collate_batched_meshes returns items of the correct shapes and types.
@@ -118,6 +127,7 @@ def test_collate_models(self):
118127
self.assertEqual(len(object_batch["label"]), batch_size)
119128
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
120129
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
130+
self.assertEqual(object_batch["images"].shape[0], batch_size)
121131

122132
def test_catch_render_arg_errors(self):
123133
"""

0 commit comments

Comments
 (0)