n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""

- import math
import os
- import random
import sys
from typing import Optional, Tuple

import numpy as np
import torch
- import torch.nn.functional as Fu
from omegaconf import OmegaConf
- from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
- from pytorch3d.implicitron.dataset.utils import is_train_frame
- from pytorch3d.implicitron.models.base_model import EvaluationMode
+ from pytorch3d.implicitron.models.visualization import render_flyaround
from pytorch3d.implicitron.tools.configurable import get_default_args
- from pytorch3d.implicitron.tools.eval_video_trajectory import (
-     generate_eval_video_cameras,
- )
- from pytorch3d.implicitron.tools.video_writer import VideoWriter
- from pytorch3d.implicitron.tools.vis_utils import (
-     get_visdom_connection,
-     make_depth_image,
- )
- from tqdm import tqdm

from .experiment import Experiment


- def render_sequence(
-     dataset: DatasetBase,
-     sequence_name: str,
-     model: torch.nn.Module,
-     video_path,
-     n_eval_cameras=40,
-     fps=20,
-     max_angle=2 * math.pi,
-     trajectory_type="circular_lsq_fit",
-     trajectory_scale=1.1,
-     scene_center=(0.0, 0.0, 0.0),
-     up=(0.0, -1.0, 0.0),
-     traj_offset=0.0,
-     n_source_views=9,
-     viz_env="debug",
-     visdom_show_preds=False,
-     visdom_server="http://127.0.0.1",
-     visdom_port=8097,
-     num_workers=10,
-     seed=None,
-     video_resize=None,
- ):
-     if seed is None:
-         seed = hash(sequence_name)
-
-     if visdom_show_preds:
-         viz = get_visdom_connection(server=visdom_server, port=visdom_port)
-     else:
-         viz = None
-
-     print(f"Loading all data of sequence '{sequence_name}'.")
-     seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
-     train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
-     assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
-     sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
-     print(f"Sequence set = {sequence_set_name}.")
-     train_cameras = train_data.camera
-     time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
-     test_cameras = generate_eval_video_cameras(
-         train_cameras,
-         time=time,
-         n_eval_cams=n_eval_cameras,
-         trajectory_type=trajectory_type,
-         trajectory_scale=trajectory_scale,
-         scene_center=scene_center,
-         up=up,
-         focal_length=None,
-         principal_point=torch.zeros(n_eval_cameras, 2),
-         traj_offset_canonical=(0.0, 0.0, traj_offset),
-     )
-
-     # sample the source views reproducibly
-     with torch.random.fork_rng():
-         torch.manual_seed(seed)
-         source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
-     # add the first dummy view that will get replaced with the target camera
-     source_views_i = Fu.pad(source_views_i, [1, 0])
-     source_views = [seq_idx[i] for i in source_views_i.tolist()]
-     batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
-     assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
-
-     preds_total = []
-     for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
-         # set the first batch camera to the target camera
-         for k in ("R", "T", "focal_length", "principal_point"):
-             getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
-
-         # Move to cuda
-         net_input = batch.cuda()
-         with torch.no_grad():
-             preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
-
-         # make sure we don't overwrite something
-         assert all(k not in preds for k in net_input.keys())
-         preds.update(net_input)  # merge everything into one big dict
-
-         # Render the predictions to images
-         rendered_pred = images_from_preds(preds)
-         preds_total.append(rendered_pred)
-
-         # show the preds every 5% of the export iterations
-         if visdom_show_preds and (
-             n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
-         ):
-             show_predictions(
-                 preds_total,
-                 sequence_name=batch.sequence_name[0],
-                 viz=viz,
-                 viz_env=viz_env,
-             )
-
-     print(f"Exporting videos for sequence {sequence_name} ...")
-     generate_prediction_videos(
-         preds_total,
-         sequence_name=batch.sequence_name[0],
-         viz=viz,
-         viz_env=viz_env,
-         fps=fps,
-         video_path=video_path,
-         resize=video_resize,
-     )
-
-
- def _load_whole_dataset(dataset, idx, num_workers=10):
-     load_all_dataloader = torch.utils.data.DataLoader(
-         torch.utils.data.Subset(dataset, idx),
-         batch_size=len(idx),
-         num_workers=num_workers,
-         shuffle=False,
-         collate_fn=FrameData.collate,
-     )
-     return next(iter(load_all_dataloader))
-
-
- def images_from_preds(preds):
-     imout = {}
-     for k in (
-         "image_rgb",
-         "images_render",
-         "fg_probability",
-         "masks_render",
-         "depths_render",
-         "depth_map",
-         "_all_source_images",
-     ):
-         if k == "_all_source_images" and "image_rgb" in preds:
-             src_ims = preds["image_rgb"][1:].cpu().detach().clone()
-             v = _stack_images(src_ims, None)[None]
-         else:
-             if k not in preds or preds[k] is None:
-                 print(f"cant show {k}")
-                 continue
-             v = preds[k].cpu().detach().clone()
-             if k.startswith("depth"):
-                 mask_resize = Fu.interpolate(
-                     preds["masks_render"],
-                     size=preds[k].shape[2:],
-                     mode="nearest",
-                 )
-                 v = make_depth_image(preds[k], mask_resize)
-         if v.shape[1] == 1:
-             v = v.repeat(1, 3, 1, 1)
-         imout[k] = v.detach().cpu()
-
-     return imout
-
-
- def _stack_images(ims, size):
-     ba = ims.shape[0]
-     H = int(np.ceil(np.sqrt(ba)))
-     W = H
-     n_add = H * W - ba
-     if n_add > 0:
-         ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
-
-     ims = ims.view(H, W, *ims.shape[1:])
-     cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
-     if size is not None:
-         cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
-     return cated.clamp(0.0, 1.0)
-
-
- def show_predictions(
-     preds,
-     sequence_name,
-     viz,
-     viz_env="visualizer",
-     predicted_keys=(
-         "images_render",
-         "masks_render",
-         "depths_render",
-         "_all_source_images",
-     ),
-     n_samples=10,
-     one_image_width=200,
- ):
-     """Given a list of predictions visualize them into a single image using visdom."""
-     assert isinstance(preds, list)
-
-     pred_all = []
-     # Randomly choose a subset of the rendered images, sort by order in the sequence
-     n_samples = min(n_samples, len(preds))
-     pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
-     for predi in pred_idx:
-         # Make the concatenation for the same camera vertically
-         pred_all.append(
-             torch.cat(
-                 [
-                     torch.nn.functional.interpolate(
-                         preds[predi][k].cpu(),
-                         scale_factor=one_image_width / preds[predi][k].shape[3],
-                         mode="bilinear",
-                     ).clamp(0.0, 1.0)
-                     for k in predicted_keys
-                 ],
-                 dim=2,
-             )
-         )
-     # Concatenate the images horizontally
-     pred_all_cat = torch.cat(pred_all, dim=3)[0]
-     viz.image(
-         pred_all_cat,
-         win="show_predictions",
-         env=viz_env,
-         opts={"title": f"pred_{sequence_name}"},
-     )
-
-
- def generate_prediction_videos(
-     preds,
-     sequence_name,
-     viz=None,
-     viz_env="visualizer",
-     predicted_keys=(
-         "images_render",
-         "masks_render",
-         "depths_render",
-         "_all_source_images",
-     ),
-     fps=20,
-     video_path="/tmp/video",
-     resize=None,
- ):
-     """Given a list of predictions create and visualize rotating videos of the
-     objects using visdom.
-     """
-     assert isinstance(preds, list)
-
-     # make sure the target video directory exists
-     os.makedirs(os.path.dirname(video_path), exist_ok=True)
-
-     # init a video writer for each predicted key
-     vws = {}
-     for k in predicted_keys:
-         vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
-
-     for rendered_pred in tqdm(preds):
-         for k in predicted_keys:
-             vws[k].write_frame(
-                 rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
-                 resize=resize,
-             )
-
-     for k in predicted_keys:
-         vws[k].get_video(quiet=True)
-         print(f"Generated {vws[k].out_path}.")
-         if viz is not None:
-             viz.video(
-                 videofile=vws[k].out_path,
-                 env=viz_env,
-                 win=k,  # we reuse the same window otherwise visdom dies
-                 opts={"title": sequence_name + " " + k},
-             )
-
-
- def export_scenes(
+ def visualize_reconstruction(
    exp_dir: str = "",
    restrict_sequence_name: Optional[str] = None,
    output_directory: Optional[str] = None,
    render_size: Tuple[int, int] = (512, 512),
    video_size: Optional[Tuple[int, int]] = None,
-     split: str = "train",  # train | val | test
+     split: str = "train",
    n_source_views: int = 9,
    n_eval_cameras: int = 40,
-     visdom_server="http://127.0.0.1",
-     visdom_port=8097,
    visdom_show_preds: bool = False,
+     visdom_server: str = "http://127.0.0.1",
+     visdom_port: int = 8097,
    visdom_env: Optional[str] = None,
-     gpu_idx: int = 0,
):
+     """
+     Given an `exp_dir` containing a trained Implicitron model, generates videos
+     of renders of sequences from the dataset used to train and evaluate the model.
+
+     Args:
+         exp_dir: Implicitron experiment directory.
+         restrict_sequence_name: If set, defines the list of sequences to visualize.
+         output_directory: If set, defines a custom directory to output visualizations to.
+         render_size: The size (HxW) of the generated renders.
+         video_size: The size (HxW) of the output video.
+         split: The dataset split to use for visualization.
+             Can be "train" / "val" / "test".
+         n_source_views: The number of source views added to each rendered batch. These
+             views are required inputs for models such as NeRFormer / NeRF-WCE.
+         n_eval_cameras: The number of cameras in each fly-around trajectory.
+         visdom_show_preds: If `True`, outputs visualizations to visdom.
+         visdom_server: The address of the visdom server.
+         visdom_port: The port of the visdom server.
+         visdom_env: If set, defines a custom name for the visdom environment.
+     """
+
    # In case an output directory is specified use it. If no output_directory
    # is specified create a vis folder inside the experiment directory
    if output_directory is None:
        output_directory = os.path.join(exp_dir, "vis")
-     else:
-         output_directory = output_directory
-     if not os.path.exists(output_directory):
-         os.makedirs(output_directory)
+     os.makedirs(output_directory, exist_ok=True)

    # Set the random seeds
    torch.manual_seed(0)
@@ -325,7 +74,6 @@ def export_scenes(
    # Get the config from the experiment_directory,
    # and overwrite relevant fields
    config = _get_config_from_experiment_directory(exp_dir)
-     config.gpu_idx = gpu_idx
    config.exp_dir = exp_dir
    # important so that the CO3D dataset gets loaded in full
    dataset_args = (
@@ -340,10 +88,6 @@ def export_scenes(
    if restrict_sequence_name is not None:
        dataset_args.restrict_sequence_name = restrict_sequence_name

-     # Set up the CUDA env for the visualization
-     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-     os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
-
    # Load the previously trained model
    experiment = Experiment(config)
    model = experiment.model_factory(force_resume=True)
@@ -360,17 +104,17 @@ def export_scenes(
    # iterate over the sequences in the dataset
    for sequence_name in dataset.sequence_names():
        with torch.no_grad():
-             render_sequence(
-                 dataset,
-                 sequence_name,
-                 model,
-                 video_path="{}/video".format(output_directory),
+             render_flyaround(
+                 dataset=dataset,
+                 sequence_name=sequence_name,
+                 model=model,
+                 output_video_path=os.path.join(output_directory, "video"),
                n_source_views=n_source_views,
                visdom_show_preds=visdom_show_preds,
-                 n_eval_cameras=n_eval_cameras,
+                 n_flyaround_poses=n_eval_cameras,
                visdom_server=visdom_server,
                visdom_port=visdom_port,
-                 viz_env=f"visualizer_{config.visdom_env}"
+                 visdom_environment=f"visualizer_{config.visdom_env}"
                if visdom_env is None
                else visdom_env,
                video_resize=video_size,
@@ -384,11 +128,11 @@ def _get_config_from_experiment_directory(experiment_directory):


def main(argv):
-     # automatically parses arguments of export_scenes
-     cfg = OmegaConf.create(get_default_args(export_scenes))
+     # automatically parses arguments of visualize_reconstruction
+     cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
    cfg.update(OmegaConf.from_cli())
    with torch.no_grad():
-         export_scenes(**cfg)
+         visualize_reconstruction(**cfg)


if __name__ == "__main__":
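
The refactor above delegates the per-sequence rendering loop to `render_flyaround`, which now lives in `pytorch3d.implicitron.models.visualization`. For readers who want to drive that code path directly rather than through the CLI entry point, below is a minimal sketch mirroring the call site in the diff. It is an illustration, not part of the commit: `my_dataset` and `my_model` are hypothetical placeholders for an Implicitron dataset and a trained model (for example obtained via the `Experiment` machinery shown above), and only keyword arguments that appear in the diff are used.

    import os
    import torch

    from pytorch3d.implicitron.models.visualization import render_flyaround

    output_directory = "/tmp/implicitron_vis"  # hypothetical output location
    os.makedirs(output_directory, exist_ok=True)

    with torch.no_grad():
        # my_dataset / my_model are placeholders, not defined by this commit
        for sequence_name in my_dataset.sequence_names():
            render_flyaround(
                dataset=my_dataset,
                sequence_name=sequence_name,
                model=my_model,
                output_video_path=os.path.join(output_directory, "video"),
                n_source_views=9,            # default from the diff
                n_flyaround_poses=40,        # replaces the old n_eval_cameras
                visdom_show_preds=False,
                video_resize=(256, 256),     # matches the module docstring example
            )

From the command line, the script is instead driven via OmegaConf CLI overrides, as the module docstring shows, e.g. `n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"`.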