Skip to content

Commit 292acc7

Browse files
Update so3 operations for numerical stability
Summary: Replace implementations of `so3_exp_map` and `so3_log_map` in so3.py with existing more-stable implementations. Reviewed By: bottler Differential Revision: D52513319 fbshipit-source-id: fbfc039643fef284d8baa11bab61651964077afe
1 parent 3621a36 commit 292acc7

File tree

2 files changed

+6
-54
lines changed

2 files changed

+6
-54
lines changed

pytorch3d/transforms/so3.py

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Tuple
99

1010
import torch
11+
from pytorch3d.transforms import rotation_conversions
1112

1213
from ..transforms import acos_linear_extrapolation
1314

@@ -160,19 +161,10 @@ def _so3_exp_map(
160161
nrms = (log_rot * log_rot).sum(1)
161162
# phis ... rotation angles
162163
rot_angles = torch.clamp(nrms, eps).sqrt()
163-
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
164-
rot_angles_inv = 1.0 / rot_angles
165-
fac1 = rot_angles_inv * rot_angles.sin()
166-
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
167164
skews = hat(log_rot)
168165
skews_square = torch.bmm(skews, skews)
169166

170-
R = (
171-
fac1[:, None, None] * skews
172-
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
173-
+ fac2[:, None, None] * skews_square
174-
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
175-
)
167+
R = rotation_conversions.axis_angle_to_matrix(log_rot)
176168

177169
return R, rot_angles, skews, skews_square
178170

@@ -183,49 +175,23 @@ def so3_log_map(
183175
"""
184176
Convert a batch of 3x3 rotation matrices `R`
185177
to a batch of 3-dimensional matrix logarithms of rotation matrices
186-
The conversion has a singularity around `(R=I)` which is handled
187-
by clamping controlled with the `eps` and `cos_bound` arguments.
178+
The conversion has a singularity around `(R=I)`.
188179
189180
Args:
190181
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
191-
eps: A float constant handling the conversion singularity.
192-
cos_bound: Clamps the cosine of the rotation angle to
193-
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
194-
of the `acos` call when computing `so3_rotation_angle`.
195-
Note that the non-finite outputs/gradients are returned when
196-
the rotation angle is close to 0 or π.
182+
eps: (unused, for backward compatibility)
183+
cos_bound: (unused, for backward compatibility)
197184
198185
Returns:
199186
Batch of logarithms of input rotation matrices
200187
of shape `(minibatch, 3)`.
201-
202-
Raises:
203-
ValueError if `R` is of incorrect shape.
204-
ValueError if `R` has an unexpected trace.
205188
"""
206189

207190
N, dim1, dim2 = R.shape
208191
if dim1 != 3 or dim2 != 3:
209192
raise ValueError("Input has to be a batch of 3x3 Tensors.")
210193

211-
phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)
212-
213-
phi_sin = torch.sin(phi)
214-
215-
# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
216-
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
217-
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
218-
phi_factor = torch.empty_like(phi)
219-
ok_denom = phi_sin.abs() > (0.5 * eps)
220-
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
221-
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
222-
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
223-
224-
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
225-
226-
log_rot = hat_inv(log_rot_hat)
227-
228-
return log_rot
194+
return rotation_conversions.matrix_to_axis_angle(R)
229195

230196

231197
def hat_inv(h: torch.Tensor) -> torch.Tensor:

tests/test_so3.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -97,20 +97,6 @@ def test_bad_so3_input_value_err(self):
9797
so3_log_map(rot)
9898
self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception))
9999

100-
# trace of rot definitely bigger than 3 or smaller than -1
101-
rot = torch.cat(
102-
(
103-
torch.rand(size=[5, 3, 3], device=device) + 4.0,
104-
torch.rand(size=[5, 3, 3], device=device) - 3.0,
105-
)
106-
)
107-
with self.assertRaises(ValueError) as err:
108-
so3_log_map(rot)
109-
self.assertTrue(
110-
"A matrix has trace outside valid range [-1-eps,3+eps]."
111-
in str(err.exception)
112-
)
113-
114100
def test_so3_exp_singularity(self, batch_size: int = 100):
115101
"""
116102
Tests whether the `so3_exp_map` is robust to the input vectors

0 commit comments

Comments
 (0)