-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathbaseline.py
43 lines (35 loc) · 1.78 KB
/
baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import *
import click
import torch
class MGEBaselineInterface:
    """
    Abstract class for model wrapper to uniformize the interface of loading and inference across different models.

    Subclasses are expected to override `load` (a `click` command that builds an instance
    from command-line arguments) and `infer`. `infer_for_evaluation` may additionally be
    overridden when the model has a dedicated evaluation mode.
    """
    # Device the wrapped model runs on. Only declared here; presumably assigned by each
    # subclass when it builds its model — TODO confirm against the implementations.
    device: torch.device

    # `staticmethod` is the outermost decorator so `click.command()` receives a plain
    # function. The reverse order relies on staticmethod objects having `__name__` and
    # being callable, which is only true from Python 3.10 on. Accessing `cls.load` or
    # `self.load` yields the click command object in either order.
    @staticmethod
    @click.command()
    def load(*args, **kwargs) -> "MGEBaselineInterface":
        """
        Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()`.

        Subclasses must provide their own implementation; the base version always raises.

        ### Raises
        `NotImplementedError`: always, for the base class.
        """
        # A static method has no `self`/`cls` in scope, so the message names the base class
        # directly (referencing `self` here would raise NameError instead).
        raise NotImplementedError("MGEBaselineInterface subclasses must implement the load method.")

    def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Run the wrapped model on an image. Must be overridden by subclasses.

        ### Parameters
        `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1]
        `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional.

        ### Returns
        A dictionary containing:
        - `points_*`. point map output in OpenCV identity camera space.
            Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`.
        - `depth_*`. depth map output
            Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`.
        - `disparity_affine_invariant`. affine disparity map output

        ### Raises
        `NotImplementedError`: when the subclass does not override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.")

    def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        If the model has a special evaluation mode, override this method to provide the evaluation mode inference.
        By default, this method simply calls `infer()` with the same arguments and returns its result.
        """
        return self.infer(image, intrinsics)