8
8
from typing import Tuple
9
9
10
10
import torch
11
+ from pytorch3d .transforms import rotation_conversions
11
12
12
13
from ..transforms import acos_linear_extrapolation
13
14
@@ -160,19 +161,10 @@ def _so3_exp_map(
160
161
nrms = (log_rot * log_rot ).sum (1 )
161
162
# phis ... rotation angles
162
163
rot_angles = torch .clamp (nrms , eps ).sqrt ()
163
- # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
164
- rot_angles_inv = 1.0 / rot_angles
165
- fac1 = rot_angles_inv * rot_angles .sin ()
166
- fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles .cos ())
167
164
skews = hat (log_rot )
168
165
skews_square = torch .bmm (skews , skews )
169
166
170
- R = (
171
- fac1 [:, None , None ] * skews
172
- # pyre-fixme[16]: `float` has no attribute `__getitem__`.
173
- + fac2 [:, None , None ] * skews_square
174
- + torch .eye (3 , dtype = log_rot .dtype , device = log_rot .device )[None ]
175
- )
167
+ R = rotation_conversions .axis_angle_to_matrix (log_rot )
176
168
177
169
return R , rot_angles , skews , skews_square
178
170
@@ -183,49 +175,23 @@ def so3_log_map(
183
175
"""
184
176
Convert a batch of 3x3 rotation matrices `R`
185
177
to a batch of 3-dimensional matrix logarithms of rotation matrices
186
- The conversion has a singularity around `(R=I)` which is handled
187
- by clamping controlled with the `eps` and `cos_bound` arguments.
178
+ The conversion has a singularity around `(R=I)`.
188
179
189
180
Args:
190
181
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
191
- eps: A float constant handling the conversion singularity.
192
- cos_bound: Clamps the cosine of the rotation angle to
193
- [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
194
- of the `acos` call when computing `so3_rotation_angle`.
195
- Note that the non-finite outputs/gradients are returned when
196
- the rotation angle is close to 0 or π.
182
+ eps: (unused, for backward compatibility)
183
+ cos_bound: (unused, for backward compatibility)
197
184
198
185
Returns:
199
186
Batch of logarithms of input rotation matrices
200
187
of shape `(minibatch, 3)`.
201
-
202
- Raises:
203
- ValueError if `R` is of incorrect shape.
204
- ValueError if `R` has an unexpected trace.
205
188
"""
206
189
207
190
N , dim1 , dim2 = R .shape
208
191
if dim1 != 3 or dim2 != 3 :
209
192
raise ValueError ("Input has to be a batch of 3x3 Tensors." )
210
193
211
- phi = so3_rotation_angle (R , cos_bound = cos_bound , eps = eps )
212
-
213
- phi_sin = torch .sin (phi )
214
-
215
- # We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
216
- # Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
217
- # 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
218
- phi_factor = torch .empty_like (phi )
219
- ok_denom = phi_sin .abs () > (0.5 * eps )
220
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
221
- phi_factor [~ ok_denom ] = 0.5 + (phi [~ ok_denom ] ** 2 ) * (1.0 / 12 )
222
- phi_factor [ok_denom ] = phi [ok_denom ] / (2.0 * phi_sin [ok_denom ])
223
-
224
- log_rot_hat = phi_factor [:, None , None ] * (R - R .permute (0 , 2 , 1 ))
225
-
226
- log_rot = hat_inv (log_rot_hat )
227
-
228
- return log_rot
194
+ return rotation_conversions .matrix_to_axis_angle (R )
229
195
230
196
231
197
def hat_inv (h : torch .Tensor ) -> torch .Tensor :
0 commit comments