 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
 from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils
-from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
 from pytorch3d.implicitron.tools.image_utils import mask_background
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth, iou, rgb_l1
 from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
 from pytorch3d.implicitron.tools.vis_utils import make_depth_image
-from pytorch3d.renderer.camera_utils import join_cameras_as_batch
-from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.renderer.cameras import PerspectiveCameras
 from pytorch3d.vis.plotly_vis import plot_scene
 from tabulate import tabulate
@@ -149,7 +147,6 @@ def eval_batch(
     visualize: bool = False,
     visualize_visdom_env: str = "eval_debug",
     break_after_visualising: bool = True,
-    source_cameras: Optional[CamerasBase] = None,
 ) -> Dict[str, Any]:
     """
     Produce performance metrics for a single batch of new-view synthesis
@@ -171,8 +168,6 @@ def eval_batch(
             ground truth.
         lpips_model: A pre-trained model for evaluating the LPIPS metric.
         visualize: If True, visualizes the results to Visdom.
-        source_cameras: A list of all training cameras for evaluating the
-            difficulty of the target views.
 
     Returns:
         results: A dictionary holding evaluation metrics.
@@ -365,16 +360,7 @@ def eval_batch(
     # convert all metrics to floats
     results = {k: float(v) for k, v in results.items()}
 
-    if source_cameras is None:
-        # pyre-fixme[16]: Optional has no attribute __getitem__
-        source_cameras = frame_data.camera[torch.where(is_known)[0]]
-
     results["meta"] = {
-        # calculate the camera difficulties and add to results
-        "camera_difficulty": calculate_camera_difficulties(
-            frame_data.camera[0],
-            source_cameras,
-        )[0].item(),
         # store the size of the batch (corresponds to n_src_views+1)
         "batch_size": int(is_known.numel()),
         # store the type of the target frame
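With `camera_difficulty` dropped from `results["meta"]`, a per-batch `meta` record keeps only the fields shown above. A minimal sketch of the shape downstream code can now rely on (the values here are hypothetical, not from a real run):

```python
# Illustrative results["meta"] after this change; values are made up.
meta = {
    "batch_size": 9,             # corresponds to n_src_views + 1
    "frame_type": "test_known",  # type of the target frame
}
assert "camera_difficulty" not in meta
```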
@@ -406,33 +392,6 @@ def average_per_batch_results(
     }
 
 
-def calculate_camera_difficulties(
-    cameras_target: CamerasBase,
-    cameras_source: CamerasBase,
-) -> torch.Tensor:
-    """
-    Calculate the difficulties of the target cameras, given a set of known
-    cameras `cameras_source`.
-
-    Returns:
-        a tensor of shape (len(cameras_target),)
-    """
-    ious = [
-        volumetric_camera_overlaps(
-            join_cameras_as_batch(
-                # pyre-fixme[6]: Expected `CamerasBase` for 1st param but got
-                # `Optional[pytorch3d.renderer.utils.TensorProperties]`.
-                [cameras_target[cami], cameras_source.to(cameras_target.device)]
-            )
-        )[0, :]
-        for cami in range(cameras_target.R.shape[0])
-    ]
-    camera_difficulties = torch.stack(
-        [_reduce_camera_iou_overlap(iou[1:]) for iou in ious]
-    )
-    return camera_difficulties
-
-
 def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
     """
     Calculate the final camera difficulty by computing the average of the
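The deleted helper fed `_reduce_camera_iou_overlap` with each target view's IoUs against the source views. Judging by the signature and the (truncated) docstring above, the reduction is the mean of the `topk` largest overlaps; a self-contained sketch under that assumption, with the name changed to signal it is hypothetical:

```python
import torch

def reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
    # Mean IoU of the `topk` source cameras that overlap the target most.
    return ious.topk(k=min(topk, len(ious))).values.mean()

# Hypothetical IoUs of one target view against four source views.
print(reduce_camera_iou_overlap(torch.tensor([0.1, 0.7, 0.4, 0.9])))  # tensor(0.8000)
```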
@@ -458,8 +417,7 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
 def summarize_nvs_eval_results(
     per_batch_eval_results: List[Dict[str, Any]],
     is_multisequence: bool,
-    camera_difficulty_bin_breaks: Tuple[float, float],
-):
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Compile the per-batch evaluation results `per_batch_eval_results` into
     a set of aggregate metrics. The produced metrics depend on is_multisequence.
@@ -482,19 +440,12 @@ def summarize_nvs_eval_results(
     batch_sizes = torch.tensor(
         [r["meta"]["batch_size"] for r in per_batch_eval_results]
     ).long()
-    camera_difficulty = torch.tensor(
-        [r["meta"]["camera_difficulty"] for r in per_batch_eval_results]
-    ).float()
+
     is_train = is_train_frame([r["meta"]["frame_type"] for r in per_batch_eval_results])
 
     # init the result database dict
     results = []
 
-    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
-        camera_difficulty_bin_breaks
-    )
-    n_diff_edges = diff_bin_edges.numel()
-
     # add per set averages
     for SET in eval_sets:
         if SET is None:
@@ -504,26 +455,17 @@ def summarize_nvs_eval_results(
         ok_set = is_train == int(SET == "train")
         set_name = SET
 
-        # eval each difficulty bin, including a full average result (diff_bin=None)
-        for diff_bin in [None, *list(range(n_diff_edges - 1))]:
-            if diff_bin is None:
-                # average over all results
-                in_bin = ok_set
-                diff_bin_name = "all"
-            else:
-                b1, b2 = diff_bin_edges[diff_bin : (diff_bin + 2)]
-                in_bin = ok_set & (camera_difficulty > b1) & (camera_difficulty <= b2)
-                diff_bin_name = diff_bin_names[diff_bin]
-            bin_results = average_per_batch_results(
-                per_batch_eval_results, idx=torch.where(in_bin)[0]
-            )
-            results.append(
-                {
-                    "subset": set_name,
-                    "subsubset": f"diff={diff_bin_name}",
-                    "metrics": bin_results,
-                }
-            )
+        # average over all results
+        bin_results = average_per_batch_results(
+            per_batch_eval_results, idx=torch.where(ok_set)[0]
+        )
+        results.append(
+            {
+                "subset": set_name,
+                "subsubset": "diff=all",
+                "metrics": bin_results,
+            }
+        )
 
         if is_multisequence:
             # split based on n_src_views
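With the difficulty bins gone, each evaluation set now contributes a single aggregate record labeled `diff=all`. A sketch of what `results` might hold after the per-set loop (metric names and numbers are purely illustrative):

```python
# Hypothetical contents of `results` after the per-set loop.
results = [
    {"subset": "train", "subsubset": "diff=all", "metrics": {"psnr": 24.1, "iou": 0.83}},
    {"subset": "test", "subsubset": "diff=all", "metrics": {"psnr": 21.7, "iou": 0.79}},
]
```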
@@ -552,7 +494,7 @@ def _get_flat_nvs_metric_key(result, metric_name) -> str:
     return metric_key
 
 
-def flatten_nvs_results(results):
+def flatten_nvs_results(results) -> Dict[str, Any]:
     """
     Takes input `results` list of dicts of the form::
 
@@ -571,7 +513,6 @@ def flatten_nvs_results(results):
         'subset=train/test/...|subsubset=src=1/src=2/...': nvs_eval_metrics,
         ...
     }
-
     """
    results_flat = {}
     for result in results:
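Going by the docstring, flattening the illustrative `results` list sketched earlier would produce keys of the documented `subset=...|subsubset=...` form. The exact key format is an assumption inferred from the docstring, not verified against the implementation:

```python
# Hypothetical flattened form of the illustrative `results` above.
expected = {
    "subset=train|subsubset=diff=all": {"psnr": 24.1, "iou": 0.83},
    "subset=test|subsubset=diff=all": {"psnr": 21.7, "iou": 0.79},
}
```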