6
6
7
7
from __future__ import annotations
8
8
9
- import dataclasses
10
-
11
9
from abc import ABC , abstractmethod
12
10
from dataclasses import dataclass , field
13
11
from enum import Enum
@@ -29,7 +27,6 @@ class RenderSamplingMode(Enum):
29
27
FULL_GRID = "full_grid"
30
28
31
29
32
- @dataclasses .dataclass
33
30
class ImplicitronRayBundle :
34
31
"""
35
32
Parametrizes points along projection rays by storing ray `origins`,
@@ -69,53 +66,58 @@ class ImplicitronRayBundle:
69
66
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
70
67
pixel_radii_2d: An optional tensor of shape `(..., 1)`
71
68
base radii of the conical frustums.
69
+
70
+ Raises:
71
+ ValueError: If either bins or lengths are not provided.
72
+ ValueError: If bins is provided and the last dim is inferior or equal to 1.
72
73
"""
73
74
74
- origins : torch .Tensor
75
- directions : torch .Tensor
76
- lengths : torch .Tensor
77
- xys : torch .Tensor
78
- camera_ids : Optional [torch .LongTensor ] = None
79
- camera_counts : Optional [torch .LongTensor ] = None
80
- bins : Optional [torch .Tensor ] = None
81
- pixel_radii_2d : Optional [torch .Tensor ] = None
82
-
83
- @classmethod
84
- def from_bins (
85
- cls ,
75
+ def __init__ (
76
+ self ,
86
77
origins : torch .Tensor ,
87
78
directions : torch .Tensor ,
88
- bins : torch .Tensor ,
79
+ lengths : Optional [ torch .Tensor ] ,
89
80
xys : torch .Tensor ,
90
- ** kwargs ,
91
- ) -> "ImplicitronRayBundle" :
92
- """
93
- Creates a new instance from bins instead of lengths.
94
-
95
- Attributes:
96
- origins: A tensor of shape `(..., 3)` denoting the
97
- origins of the sampling rays in world coords.
98
- directions: A tensor of shape `(..., 3)` containing the direction
99
- vectors of sampling rays in world coords. They don't have to be normalized;
100
- they define unit vectors in the respective 1D coordinate systems; see
101
- documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
102
- bins: A tensor of shape `(..., num_points_per_ray + 1)`
103
- containing the bins at which the rays are sampled. In this case
104
- lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
105
- xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
106
- kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
107
- Returns:
108
- An instance of ImplicitronRayBundle.
109
- """
110
-
111
- if bins .shape [- 1 ] <= 1 :
81
+ camera_ids : Optional [torch .LongTensor ] = None ,
82
+ camera_counts : Optional [torch .LongTensor ] = None ,
83
+ bins : Optional [torch .Tensor ] = None ,
84
+ pixel_radii_2d : Optional [torch .Tensor ] = None ,
85
+ ):
86
+ if bins is not None and bins .shape [- 1 ] <= 1 :
112
87
raise ValueError (
113
88
"The last dim of bins must be at least superior or equal to 2."
114
89
)
115
- # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
116
- lengths = torch .lerp (bins [..., 1 :], bins [..., :- 1 ], 0.5 )
117
90
118
- return cls (origins , directions , lengths , xys , bins = bins , ** kwargs )
91
+ if bins is None and lengths is None :
92
+ raise ValueError (
93
+ "Please set either bins or lengths to initialize an ImplicitronRayBundle."
94
+ )
95
+
96
+ self .origins = origins
97
+ self .directions = directions
98
+ self ._lengths = lengths if bins is None else None
99
+ self .xys = xys
100
+ self .bins = bins
101
+ self .pixel_radii_2d = pixel_radii_2d
102
+ self .camera_ids = camera_ids
103
+ self .camera_counts = camera_counts
104
+
105
+ @property
106
+ def lengths (self ) -> torch .Tensor :
107
+ if self .bins is not None :
108
+ # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
109
+ # pyre-ignore
110
+ return torch .lerp (self .bins [..., :- 1 ], self .bins [..., 1 :], 0.5 )
111
+ return self ._lengths
112
+
113
+ @lengths .setter
114
+ def lengths (self , value ):
115
+ if self .bins is not None :
116
+ raise ValueError (
117
+ "If the bins attribute is not None you cannot set the lengths attribute."
118
+ )
119
+ else :
120
+ self ._lengths = value
119
121
120
122
def is_packed (self ) -> bool :
121
123
"""
0 commit comments