Skip to content

Commit bee31c4

Browse files
una-dinosauriafacebook-github-bot
una-dinosauria
authored andcommitted
Make some matrix conversion jittable (#898)
Summary: Make sure the functions from `rotation_conversion` are jittable, and add some type hints. Add tests to verify this is the case. Pull Request resolved: #898 Reviewed By: patricklabatut Differential Revision: D31926103 Pulled By: bottler fbshipit-source-id: bff6013c5ca2d452e37e631bd902f0674d5ca091
1 parent 29417d1 commit bee31c4

File tree

2 files changed

+61
-30
lines changed

2 files changed

+61
-30
lines changed

pytorch3d/transforms/rotation_conversions.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import functools
87
from typing import Optional
98

109
import torch
@@ -39,7 +38,7 @@
3938
"""
4039

4140

42-
def quaternion_to_matrix(quaternions):
41+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
4342
"""
4443
Convert rotations given as quaternions to rotation matrices.
4544
@@ -70,7 +69,7 @@ def quaternion_to_matrix(quaternions):
7069
return o.reshape(quaternions.shape[:-1] + (3, 3))
7170

7271

73-
def _copysign(a, b):
72+
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
7473
"""
7574
Return a tensor where each element has the absolute value taken from the,
7675
corresponding element of a, with sign taken from the corresponding
@@ -114,7 +113,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
114113

115114
batch_dim = matrix.shape[:-2]
116115
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
117-
matrix.reshape(*batch_dim, 9), dim=-1
116+
matrix.reshape(batch_dim + (9,)), dim=-1
118117
)
119118

120119
q_abs = _sqrt_positive_part(
@@ -142,17 +141,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
142141

143142
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
144143
# the candidate won't be picked.
145-
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1)))
144+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
145+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
146146

147147
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
148148
# forall i; we pick the best-conditioned one (with the largest denominator)
149149

150150
return quat_candidates[
151151
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
152-
].reshape(*batch_dim, 4)
152+
].reshape(batch_dim + (4,))
153153

154154

155-
def _axis_angle_rotation(axis: str, angle):
155+
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
156156
"""
157157
Return the rotation matrices for one of the rotations about an axis
158158
of which Euler angles describe, for each value of the angle given.
@@ -172,15 +172,17 @@ def _axis_angle_rotation(axis: str, angle):
172172

173173
if axis == "X":
174174
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
175-
if axis == "Y":
175+
elif axis == "Y":
176176
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
177-
if axis == "Z":
177+
elif axis == "Z":
178178
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
179+
else:
180+
raise ValueError("letter must be either X, Y or Z.")
179181

180182
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
181183

182184

183-
def euler_angles_to_matrix(euler_angles, convention: str):
185+
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
184186
"""
185187
Convert rotations given as Euler angles in radians to rotation matrices.
186188
@@ -201,13 +203,17 @@ def euler_angles_to_matrix(euler_angles, convention: str):
201203
for letter in convention:
202204
if letter not in ("X", "Y", "Z"):
203205
raise ValueError(f"Invalid letter {letter} in convention string.")
204-
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
205-
return functools.reduce(torch.matmul, matrices)
206+
matrices = [
207+
_axis_angle_rotation(c, e)
208+
for c, e in zip(convention, torch.unbind(euler_angles, -1))
209+
]
210+
# return functools.reduce(torch.matmul, matrices)
211+
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
206212

207213

208214
def _angle_from_tan(
209215
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
210-
):
216+
) -> torch.Tensor:
211217
"""
212218
Extract the first or third Euler angle from the two members of
213219
the matrix which are positive constant times its sine and cosine.
@@ -238,16 +244,17 @@ def _angle_from_tan(
238244
return torch.atan2(data[..., i2], -data[..., i1])
239245

240246

241-
def _index_from_letter(letter: str):
247+
def _index_from_letter(letter: str) -> int:
242248
if letter == "X":
243249
return 0
244250
if letter == "Y":
245251
return 1
246252
if letter == "Z":
247253
return 2
254+
raise ValueError("letter must be either X, Y or Z.")
248255

249256

250-
def matrix_to_euler_angles(matrix, convention: str):
257+
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
251258
"""
252259
Convert rotations given as rotation matrices to Euler angles in radians.
253260
@@ -291,7 +298,7 @@ def matrix_to_euler_angles(matrix, convention: str):
291298

292299
def random_quaternions(
293300
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
294-
):
301+
) -> torch.Tensor:
295302
"""
296303
Generate random quaternions representing rotations,
297304
i.e. versors with nonnegative real part.
@@ -305,6 +312,8 @@ def random_quaternions(
305312
Returns:
306313
Quaternions as tensor of shape (N, 4).
307314
"""
315+
if isinstance(device, str):
316+
device = torch.device(device)
308317
o = torch.randn((n, 4), dtype=dtype, device=device)
309318
s = (o * o).sum(1)
310319
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
@@ -313,7 +322,7 @@ def random_quaternions(
313322

314323
def random_rotations(
315324
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
316-
):
325+
) -> torch.Tensor:
317326
"""
318327
Generate random rotations as 3x3 rotation matrices.
319328
@@ -332,7 +341,7 @@ def random_rotations(
332341

333342
def random_rotation(
334343
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
335-
):
344+
) -> torch.Tensor:
336345
"""
337346
Generate a single random 3x3 rotation matrix.
338347
@@ -347,7 +356,7 @@ def random_rotation(
347356
return random_rotations(1, dtype, device)[0]
348357

349358

350-
def standardize_quaternion(quaternions):
359+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
351360
"""
352361
Convert a unit quaternion to a standard form: one in which the real
353362
part is non negative.
@@ -362,7 +371,7 @@ def standardize_quaternion(quaternions):
362371
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
363372

364373

365-
def quaternion_raw_multiply(a, b):
374+
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
366375
"""
367376
Multiply two quaternions.
368377
Usual torch rules for broadcasting apply.
@@ -383,7 +392,7 @@ def quaternion_raw_multiply(a, b):
383392
return torch.stack((ow, ox, oy, oz), -1)
384393

385394

386-
def quaternion_multiply(a, b):
395+
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
387396
"""
388397
Multiply two quaternions representing rotations, returning the quaternion
389398
representing their composition, i.e. the versor with nonnegative real part.
@@ -400,7 +409,7 @@ def quaternion_multiply(a, b):
400409
return standardize_quaternion(ab)
401410

402411

403-
def quaternion_invert(quaternion):
412+
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
404413
"""
405414
Given a quaternion representing rotation, get the quaternion representing
406415
its inverse.
@@ -413,10 +422,11 @@ def quaternion_invert(quaternion):
413422
The inverse, a tensor of quaternions of shape (..., 4).
414423
"""
415424

416-
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
425+
scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
426+
return quaternion * scaling
417427

418428

419-
def quaternion_apply(quaternion, point):
429+
def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
420430
"""
421431
Apply the rotation given by a quaternion to a 3D point.
422432
Usual torch rules for broadcasting apply.
@@ -439,7 +449,7 @@ def quaternion_apply(quaternion, point):
439449
return out[..., 1:]
440450

441451

442-
def axis_angle_to_matrix(axis_angle):
452+
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
443453
"""
444454
Convert rotations given as axis/angle to rotation matrices.
445455
@@ -455,7 +465,7 @@ def axis_angle_to_matrix(axis_angle):
455465
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
456466

457467

458-
def matrix_to_axis_angle(matrix):
468+
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
459469
"""
460470
Convert rotations given as rotation matrices to axis/angle.
461471
@@ -471,7 +481,7 @@ def matrix_to_axis_angle(matrix):
471481
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
472482

473483

474-
def axis_angle_to_quaternion(axis_angle):
484+
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
475485
"""
476486
Convert rotations given as axis/angle to quaternions.
477487
@@ -485,7 +495,7 @@ def axis_angle_to_quaternion(axis_angle):
485495
quaternions with real part first, as tensor of shape (..., 4).
486496
"""
487497
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
488-
half_angles = 0.5 * angles
498+
half_angles = angles * 0.5
489499
eps = 1e-6
490500
small_angles = angles.abs() < eps
491501
sin_half_angles_over_angles = torch.empty_like(angles)
@@ -503,7 +513,7 @@ def axis_angle_to_quaternion(axis_angle):
503513
return quaternions
504514

505515

506-
def quaternion_to_axis_angle(quaternions):
516+
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
507517
"""
508518
Convert rotations given as quaternions to axis/angle.
509519
@@ -573,4 +583,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
573583
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
574584
Retrieved from http://arxiv.org/abs/1812.07035
575585
"""
576-
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
586+
batch_dim = matrix.size()[:-2]
587+
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))

tests/test_rotation_conversions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import itertools
99
import math
1010
import unittest
11+
from distutils.version import LooseVersion
1112
from typing import Optional, Union
1213

1314
import numpy as np
@@ -264,6 +265,25 @@ def test_6d(self):
264265
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
265266
)
266267

268+
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
269+
def test_scriptable(self):
270+
torch.jit.script(axis_angle_to_matrix)
271+
torch.jit.script(axis_angle_to_quaternion)
272+
torch.jit.script(euler_angles_to_matrix)
273+
torch.jit.script(matrix_to_axis_angle)
274+
torch.jit.script(matrix_to_euler_angles)
275+
torch.jit.script(matrix_to_quaternion)
276+
torch.jit.script(matrix_to_rotation_6d)
277+
torch.jit.script(quaternion_apply)
278+
torch.jit.script(quaternion_multiply)
279+
torch.jit.script(quaternion_to_matrix)
280+
torch.jit.script(quaternion_to_axis_angle)
281+
torch.jit.script(random_quaternions)
282+
torch.jit.script(random_rotation)
283+
torch.jit.script(random_rotations)
284+
torch.jit.script(random_quaternions)
285+
torch.jit.script(rotation_6d_to_matrix)
286+
267287
def _assert_quaternions_close(
268288
self,
269289
input: Union[torch.Tensor, np.ndarray],

0 commit comments

Comments
 (0)