Commit c79c954

davnov134 authored and facebook-github-bot committed
Rename and move render_flyaround into core implicitron
Summary: Move the flyaround rendering function into core implicitron. This unblocks an example in the facebookresearch/co3d repo.

Reviewed By: bottler

Differential Revision: D39257801

fbshipit-source-id: 6841a88a43d4aa364dd86ba83ca2d4c3cf0435a4
1 parent 438c194 commit c79c954
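
For orientation, the renamed helper can now be imported from core Implicitron. Below is a minimal sketch of calling it directly, using only the import path and keyword arguments that appear in this diff; the wrapper function `export_flyaround_video`, its arguments, and the output layout are hypothetical illustration, not part of this commit.

import os

from pytorch3d.implicitron.models.visualization import render_flyaround


def export_flyaround_video(dataset, model, sequence_name, output_dir):
    # `dataset` is assumed to be an Implicitron dataset and `model` a trained
    # Implicitron model, e.g. obtained from an Experiment as in the script below.
    os.makedirs(output_dir, exist_ok=True)
    render_flyaround(
        dataset=dataset,
        sequence_name=sequence_name,
        model=model,
        output_video_path=os.path.join(output_dir, "video"),
        n_flyaround_poses=40,  # number of cameras on the fly-around trajectory
    )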

File tree

6 files changed: +580 -299 lines changed


projects/implicitron_trainer/visualize_reconstruction.py

+38 -294
@@ -12,311 +12,60 @@
     n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
 """

-import math
 import os
-import random
 import sys
 from typing import Optional, Tuple

 import numpy as np
 import torch
-import torch.nn.functional as Fu
 from omegaconf import OmegaConf
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
-from pytorch3d.implicitron.dataset.utils import is_train_frame
-from pytorch3d.implicitron.models.base_model import EvaluationMode
+from pytorch3d.implicitron.models.visualization import render_flyaround
 from pytorch3d.implicitron.tools.configurable import get_default_args
-from pytorch3d.implicitron.tools.eval_video_trajectory import (
-    generate_eval_video_cameras,
-)
-from pytorch3d.implicitron.tools.video_writer import VideoWriter
-from pytorch3d.implicitron.tools.vis_utils import (
-    get_visdom_connection,
-    make_depth_image,
-)
-from tqdm import tqdm

 from .experiment import Experiment


-def render_sequence(
-    dataset: DatasetBase,
-    sequence_name: str,
-    model: torch.nn.Module,
-    video_path,
-    n_eval_cameras=40,
-    fps=20,
-    max_angle=2 * math.pi,
-    trajectory_type="circular_lsq_fit",
-    trajectory_scale=1.1,
-    scene_center=(0.0, 0.0, 0.0),
-    up=(0.0, -1.0, 0.0),
-    traj_offset=0.0,
-    n_source_views=9,
-    viz_env="debug",
-    visdom_show_preds=False,
-    visdom_server="http://127.0.0.1",
-    visdom_port=8097,
-    num_workers=10,
-    seed=None,
-    video_resize=None,
-):
-    if seed is None:
-        seed = hash(sequence_name)
-
-    if visdom_show_preds:
-        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
-    else:
-        viz = None
-
-    print(f"Loading all data of sequence '{sequence_name}'.")
-    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
-    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
-    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
-    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
-    print(f"Sequence set = {sequence_set_name}.")
-    train_cameras = train_data.camera
-    time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
-    test_cameras = generate_eval_video_cameras(
-        train_cameras,
-        time=time,
-        n_eval_cams=n_eval_cameras,
-        trajectory_type=trajectory_type,
-        trajectory_scale=trajectory_scale,
-        scene_center=scene_center,
-        up=up,
-        focal_length=None,
-        principal_point=torch.zeros(n_eval_cameras, 2),
-        traj_offset_canonical=(0.0, 0.0, traj_offset),
-    )
-
-    # sample the source views reproducibly
-    with torch.random.fork_rng():
-        torch.manual_seed(seed)
-        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
-    # add the first dummy view that will get replaced with the target camera
-    source_views_i = Fu.pad(source_views_i, [1, 0])
-    source_views = [seq_idx[i] for i in source_views_i.tolist()]
-    batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
-    assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
-
-    preds_total = []
-    for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
-        # set the first batch camera to the target camera
-        for k in ("R", "T", "focal_length", "principal_point"):
-            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
-
-        # Move to cuda
-        net_input = batch.cuda()
-        with torch.no_grad():
-            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
-
-        # make sure we dont overwrite something
-        assert all(k not in preds for k in net_input.keys())
-        preds.update(net_input)  # merge everything into one big dict
-
-        # Render the predictions to images
-        rendered_pred = images_from_preds(preds)
-        preds_total.append(rendered_pred)
-
-        # show the preds every 5% of the export iterations
-        if visdom_show_preds and (
-            n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
-        ):
-            show_predictions(
-                preds_total,
-                sequence_name=batch.sequence_name[0],
-                viz=viz,
-                viz_env=viz_env,
-            )
-
-    print(f"Exporting videos for sequence {sequence_name} ...")
-    generate_prediction_videos(
-        preds_total,
-        sequence_name=batch.sequence_name[0],
-        viz=viz,
-        viz_env=viz_env,
-        fps=fps,
-        video_path=video_path,
-        resize=video_resize,
-    )
-
-
-def _load_whole_dataset(dataset, idx, num_workers=10):
-    load_all_dataloader = torch.utils.data.DataLoader(
-        torch.utils.data.Subset(dataset, idx),
-        batch_size=len(idx),
-        num_workers=num_workers,
-        shuffle=False,
-        collate_fn=FrameData.collate,
-    )
-    return next(iter(load_all_dataloader))
-
-
-def images_from_preds(preds):
-    imout = {}
-    for k in (
-        "image_rgb",
-        "images_render",
-        "fg_probability",
-        "masks_render",
-        "depths_render",
-        "depth_map",
-        "_all_source_images",
-    ):
-        if k == "_all_source_images" and "image_rgb" in preds:
-            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
-            v = _stack_images(src_ims, None)[None]
-        else:
-            if k not in preds or preds[k] is None:
-                print(f"cant show {k}")
-                continue
-            v = preds[k].cpu().detach().clone()
-            if k.startswith("depth"):
-                mask_resize = Fu.interpolate(
-                    preds["masks_render"],
-                    size=preds[k].shape[2:],
-                    mode="nearest",
-                )
-                v = make_depth_image(preds[k], mask_resize)
-        if v.shape[1] == 1:
-            v = v.repeat(1, 3, 1, 1)
-        imout[k] = v.detach().cpu()
-
-    return imout
-
-
-def _stack_images(ims, size):
-    ba = ims.shape[0]
-    H = int(np.ceil(np.sqrt(ba)))
-    W = H
-    n_add = H * W - ba
-    if n_add > 0:
-        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
-
-    ims = ims.view(H, W, *ims.shape[1:])
-    cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
-    if size is not None:
-        cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
-    return cated.clamp(0.0, 1.0)
-
-
-def show_predictions(
-    preds,
-    sequence_name,
-    viz,
-    viz_env="visualizer",
-    predicted_keys=(
-        "images_render",
-        "masks_render",
-        "depths_render",
-        "_all_source_images",
-    ),
-    n_samples=10,
-    one_image_width=200,
-):
-    """Given a list of predictions visualize them into a single image using visdom."""
-    assert isinstance(preds, list)
-
-    pred_all = []
-    # Randomly choose a subset of the rendered images, sort by ordr in the sequence
-    n_samples = min(n_samples, len(preds))
-    pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
-    for predi in pred_idx:
-        # Make the concatentation for the same camera vertically
-        pred_all.append(
-            torch.cat(
-                [
-                    torch.nn.functional.interpolate(
-                        preds[predi][k].cpu(),
-                        scale_factor=one_image_width / preds[predi][k].shape[3],
-                        mode="bilinear",
-                    ).clamp(0.0, 1.0)
-                    for k in predicted_keys
-                ],
-                dim=2,
-            )
-        )
-    # Concatenate the images horizontally
-    pred_all_cat = torch.cat(pred_all, dim=3)[0]
-    viz.image(
-        pred_all_cat,
-        win="show_predictions",
-        env=viz_env,
-        opts={"title": f"pred_{sequence_name}"},
-    )
-
-
-def generate_prediction_videos(
-    preds,
-    sequence_name,
-    viz=None,
-    viz_env="visualizer",
-    predicted_keys=(
-        "images_render",
-        "masks_render",
-        "depths_render",
-        "_all_source_images",
-    ),
-    fps=20,
-    video_path="/tmp/video",
-    resize=None,
-):
-    """Given a list of predictions create and visualize rotating videos of the
-    objects using visdom.
-    """
-    assert isinstance(preds, list)
-
-    # make sure the target video directory exists
-    os.makedirs(os.path.dirname(video_path), exist_ok=True)
-
-    # init a video writer for each predicted key
-    vws = {}
-    for k in predicted_keys:
-        vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
-
-    for rendered_pred in tqdm(preds):
-        for k in predicted_keys:
-            vws[k].write_frame(
-                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
-                resize=resize,
-            )
-
-    for k in predicted_keys:
-        vws[k].get_video(quiet=True)
-        print(f"Generated {vws[k].out_path}.")
-        if viz is not None:
-            viz.video(
-                videofile=vws[k].out_path,
-                env=viz_env,
-                win=k,  # we reuse the same window otherwise visdom dies
-                opts={"title": sequence_name + " " + k},
-            )
-
-
-def export_scenes(
+def visualize_reconstruction(
     exp_dir: str = "",
     restrict_sequence_name: Optional[str] = None,
     output_directory: Optional[str] = None,
     render_size: Tuple[int, int] = (512, 512),
     video_size: Optional[Tuple[int, int]] = None,
-    split: str = "train",  # train | val | test
+    split: str = "train",
     n_source_views: int = 9,
     n_eval_cameras: int = 40,
-    visdom_server="http://127.0.0.1",
-    visdom_port=8097,
     visdom_show_preds: bool = False,
+    visdom_server: str = "http://127.0.0.1",
+    visdom_port: int = 8097,
     visdom_env: Optional[str] = None,
-    gpu_idx: int = 0,
 ):
+    """
+    Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
+    of renders of sequences from the dataset used to train and evaluate the trained
+    Implicitron model.
+
+    Args:
+        exp_dir: Implicitron experiment directory.
+        restrict_sequence_name: If set, defines the list of sequences to visualize.
+        output_directory: If set, defines a custom directory to output visualizations to.
+        render_size: The size (HxW) of the generated renders.
+        video_size: The size (HxW) of the output video.
+        split: The dataset split to use for visualization.
+            Can be "train" / "val" / "test".
+        n_source_views: The number of source views added to each rendered batch. These
+            views are required inputs for models such as NeRFormer / NeRF-WCE.
+        n_eval_cameras: The number of cameras in each fly-around trajectory.
+        visdom_show_preds: If `True`, outputs visualizations to visdom.
+        visdom_server: The address of the visdom server.
+        visdom_port: The port of the visdom server.
+        visdom_env: If set, defines a custom name for the visdom environment.
+    """
+
     # In case an output directory is specified use it. If no output_directory
     # is specified create a vis folder inside the experiment directory
     if output_directory is None:
         output_directory = os.path.join(exp_dir, "vis")
-    else:
-        output_directory = output_directory
-    if not os.path.exists(output_directory):
-        os.makedirs(output_directory)
+    os.makedirs(output_directory, exist_ok=True)

     # Set the random seeds
     torch.manual_seed(0)
@@ -325,7 +74,6 @@ def export_scenes(
     # Get the config from the experiment_directory,
     # and overwrite relevant fields
     config = _get_config_from_experiment_directory(exp_dir)
-    config.gpu_idx = gpu_idx
     config.exp_dir = exp_dir
     # important so that the CO3D dataset gets loaded in full
     dataset_args = (
@@ -340,10 +88,6 @@
     if restrict_sequence_name is not None:
         dataset_args.restrict_sequence_name = restrict_sequence_name

-    # Set up the CUDA env for the visualization
-    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
-
     # Load the previously trained model
     experiment = Experiment(config)
     model = experiment.model_factory(force_resume=True)
@@ -360,17 +104,17 @@ def export_scenes(
     # iterate over the sequences in the dataset
     for sequence_name in dataset.sequence_names():
         with torch.no_grad():
-            render_sequence(
-                dataset,
-                sequence_name,
-                model,
-                video_path="{}/video".format(output_directory),
+            render_flyaround(
+                dataset=dataset,
+                sequence_name=sequence_name,
+                model=model,
+                output_video_path=os.path.join(output_directory, "video"),
                 n_source_views=n_source_views,
                 visdom_show_preds=visdom_show_preds,
-                n_eval_cameras=n_eval_cameras,
+                n_flyaround_poses=n_eval_cameras,
                 visdom_server=visdom_server,
                 visdom_port=visdom_port,
-                viz_env=f"visualizer_{config.visdom_env}"
+                visdom_environment=f"visualizer_{config.visdom_env}"
                 if visdom_env is None
                 else visdom_env,
                 video_resize=video_size,
@@ -384,11 +128,11 @@ def _get_config_from_experiment_directory(experiment_directory):


 def main(argv):
-    # automatically parses arguments of export_scenes
-    cfg = OmegaConf.create(get_default_args(export_scenes))
+    # automatically parses arguments of visualize_reconstruction
+    cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
     cfg.update(OmegaConf.from_cli())
     with torch.no_grad():
-        export_scenes(**cfg)
+        visualize_reconstruction(**cfg)


 if __name__ == "__main__":
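
After the rename, `main` still builds its config automatically from the entry point's signature via `get_default_args` and OmegaConf. The sketch below drives `visualize_reconstruction` programmatically in the same way; the module import path and the `exp_dir` value are assumptions for illustration, and the keyword names come from the signature in this diff.

import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.tools.configurable import get_default_args

# Assumed import path; adjust to however the implicitron_trainer package is installed.
from projects.implicitron_trainer.visualize_reconstruction import (
    visualize_reconstruction,
)

cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
cfg.exp_dir = "./exps/checkpoint_dir"  # placeholder experiment directory
cfg.n_eval_cameras = 40
cfg.video_size = [256, 256]
with torch.no_grad():
    visualize_reconstruction(**cfg)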
