4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- import functools
8
7
from typing import Optional
9
8
10
9
import torch
39
38
"""
40
39
41
40
42
- def quaternion_to_matrix (quaternions ) :
41
+ def quaternion_to_matrix (quaternions : torch . Tensor ) -> torch . Tensor :
43
42
"""
44
43
Convert rotations given as quaternions to rotation matrices.
45
44
@@ -70,7 +69,7 @@ def quaternion_to_matrix(quaternions):
70
69
return o .reshape (quaternions .shape [:- 1 ] + (3 , 3 ))
71
70
72
71
73
- def _copysign (a , b ) :
72
+ def _copysign (a : torch . Tensor , b : torch . Tensor ) -> torch . Tensor :
74
73
"""
75
74
Return a tensor where each element has the absolute value taken from the,
76
75
corresponding element of a, with sign taken from the corresponding
@@ -114,7 +113,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
114
113
115
114
batch_dim = matrix .shape [:- 2 ]
116
115
m00 , m01 , m02 , m10 , m11 , m12 , m20 , m21 , m22 = torch .unbind (
117
- matrix .reshape (* batch_dim , 9 ), dim = - 1
116
+ matrix .reshape (batch_dim + ( 9 ,) ), dim = - 1
118
117
)
119
118
120
119
q_abs = _sqrt_positive_part (
@@ -142,17 +141,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
142
141
143
142
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
144
143
# the candidate won't be picked.
145
- quat_candidates = quat_by_rijk / (2.0 * q_abs [..., None ].max (q_abs .new_tensor (0.1 )))
144
+ flr = torch .tensor (0.1 ).to (dtype = q_abs .dtype , device = q_abs .device )
145
+ quat_candidates = quat_by_rijk / (2.0 * q_abs [..., None ].max (flr ))
146
146
147
147
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
148
148
# forall i; we pick the best-conditioned one (with the largest denominator)
149
149
150
150
return quat_candidates [
151
151
F .one_hot (q_abs .argmax (dim = - 1 ), num_classes = 4 ) > 0.5 , : # pyre-ignore[16]
152
- ].reshape (* batch_dim , 4 )
152
+ ].reshape (batch_dim + ( 4 ,) )
153
153
154
154
155
- def _axis_angle_rotation (axis : str , angle ) :
155
+ def _axis_angle_rotation (axis : str , angle : torch . Tensor ) -> torch . Tensor :
156
156
"""
157
157
Return the rotation matrices for one of the rotations about an axis
158
158
of which Euler angles describe, for each value of the angle given.
@@ -172,15 +172,17 @@ def _axis_angle_rotation(axis: str, angle):
172
172
173
173
if axis == "X" :
174
174
R_flat = (one , zero , zero , zero , cos , - sin , zero , sin , cos )
175
- if axis == "Y" :
175
+ elif axis == "Y" :
176
176
R_flat = (cos , zero , sin , zero , one , zero , - sin , zero , cos )
177
- if axis == "Z" :
177
+ elif axis == "Z" :
178
178
R_flat = (cos , - sin , zero , sin , cos , zero , zero , zero , one )
179
+ else :
180
+ raise ValueError ("letter must be either X, Y or Z." )
179
181
180
182
return torch .stack (R_flat , - 1 ).reshape (angle .shape + (3 , 3 ))
181
183
182
184
183
- def euler_angles_to_matrix (euler_angles , convention : str ):
185
+ def euler_angles_to_matrix (euler_angles : torch . Tensor , convention : str ) -> torch . Tensor :
184
186
"""
185
187
Convert rotations given as Euler angles in radians to rotation matrices.
186
188
@@ -201,13 +203,17 @@ def euler_angles_to_matrix(euler_angles, convention: str):
201
203
for letter in convention :
202
204
if letter not in ("X" , "Y" , "Z" ):
203
205
raise ValueError (f"Invalid letter { letter } in convention string." )
204
- matrices = map (_axis_angle_rotation , convention , torch .unbind (euler_angles , - 1 ))
205
- return functools .reduce (torch .matmul , matrices )
206
+ matrices = [
207
+ _axis_angle_rotation (c , e )
208
+ for c , e in zip (convention , torch .unbind (euler_angles , - 1 ))
209
+ ]
210
+ # return functools.reduce(torch.matmul, matrices)
211
+ return torch .matmul (torch .matmul (matrices [0 ], matrices [1 ]), matrices [2 ])
206
212
207
213
208
214
def _angle_from_tan (
209
215
axis : str , other_axis : str , data , horizontal : bool , tait_bryan : bool
210
- ):
216
+ ) -> torch . Tensor :
211
217
"""
212
218
Extract the first or third Euler angle from the two members of
213
219
the matrix which are positive constant times its sine and cosine.
@@ -238,16 +244,17 @@ def _angle_from_tan(
238
244
return torch .atan2 (data [..., i2 ], - data [..., i1 ])
239
245
240
246
241
- def _index_from_letter (letter : str ):
247
+ def _index_from_letter (letter : str ) -> int :
242
248
if letter == "X" :
243
249
return 0
244
250
if letter == "Y" :
245
251
return 1
246
252
if letter == "Z" :
247
253
return 2
254
+ raise ValueError ("letter must be either X, Y or Z." )
248
255
249
256
250
- def matrix_to_euler_angles (matrix , convention : str ):
257
+ def matrix_to_euler_angles (matrix : torch . Tensor , convention : str ) -> torch . Tensor :
251
258
"""
252
259
Convert rotations given as rotation matrices to Euler angles in radians.
253
260
@@ -291,7 +298,7 @@ def matrix_to_euler_angles(matrix, convention: str):
291
298
292
299
def random_quaternions (
293
300
n : int , dtype : Optional [torch .dtype ] = None , device : Optional [Device ] = None
294
- ):
301
+ ) -> torch . Tensor :
295
302
"""
296
303
Generate random quaternions representing rotations,
297
304
i.e. versors with nonnegative real part.
@@ -305,6 +312,8 @@ def random_quaternions(
305
312
Returns:
306
313
Quaternions as tensor of shape (N, 4).
307
314
"""
315
+ if isinstance (device , str ):
316
+ device = torch .device (device )
308
317
o = torch .randn ((n , 4 ), dtype = dtype , device = device )
309
318
s = (o * o ).sum (1 )
310
319
o = o / _copysign (torch .sqrt (s ), o [:, 0 ])[:, None ]
@@ -313,7 +322,7 @@ def random_quaternions(
313
322
314
323
def random_rotations (
315
324
n : int , dtype : Optional [torch .dtype ] = None , device : Optional [Device ] = None
316
- ):
325
+ ) -> torch . Tensor :
317
326
"""
318
327
Generate random rotations as 3x3 rotation matrices.
319
328
@@ -332,7 +341,7 @@ def random_rotations(
332
341
333
342
def random_rotation (
334
343
dtype : Optional [torch .dtype ] = None , device : Optional [Device ] = None
335
- ):
344
+ ) -> torch . Tensor :
336
345
"""
337
346
Generate a single random 3x3 rotation matrix.
338
347
@@ -347,7 +356,7 @@ def random_rotation(
347
356
return random_rotations (1 , dtype , device )[0 ]
348
357
349
358
350
- def standardize_quaternion (quaternions ) :
359
+ def standardize_quaternion (quaternions : torch . Tensor ) -> torch . Tensor :
351
360
"""
352
361
Convert a unit quaternion to a standard form: one in which the real
353
362
part is non negative.
@@ -362,7 +371,7 @@ def standardize_quaternion(quaternions):
362
371
return torch .where (quaternions [..., 0 :1 ] < 0 , - quaternions , quaternions )
363
372
364
373
365
- def quaternion_raw_multiply (a , b ) :
374
+ def quaternion_raw_multiply (a : torch . Tensor , b : torch . Tensor ) -> torch . Tensor :
366
375
"""
367
376
Multiply two quaternions.
368
377
Usual torch rules for broadcasting apply.
@@ -383,7 +392,7 @@ def quaternion_raw_multiply(a, b):
383
392
return torch .stack ((ow , ox , oy , oz ), - 1 )
384
393
385
394
386
- def quaternion_multiply (a , b ) :
395
+ def quaternion_multiply (a : torch . Tensor , b : torch . Tensor ) -> torch . Tensor :
387
396
"""
388
397
Multiply two quaternions representing rotations, returning the quaternion
389
398
representing their composition, i.e. the versor with nonnegative real part.
@@ -400,7 +409,7 @@ def quaternion_multiply(a, b):
400
409
return standardize_quaternion (ab )
401
410
402
411
403
- def quaternion_invert (quaternion ) :
412
+ def quaternion_invert (quaternion : torch . Tensor ) -> torch . Tensor :
404
413
"""
405
414
Given a quaternion representing rotation, get the quaternion representing
406
415
its inverse.
@@ -413,10 +422,11 @@ def quaternion_invert(quaternion):
413
422
The inverse, a tensor of quaternions of shape (..., 4).
414
423
"""
415
424
416
- return quaternion * quaternion .new_tensor ([1 , - 1 , - 1 , - 1 ])
425
+ scaling = torch .tensor ([1 , - 1 , - 1 , - 1 ], device = quaternion .device )
426
+ return quaternion * scaling
417
427
418
428
419
- def quaternion_apply (quaternion , point ) :
429
+ def quaternion_apply (quaternion : torch . Tensor , point : torch . Tensor ) -> torch . Tensor :
420
430
"""
421
431
Apply the rotation given by a quaternion to a 3D point.
422
432
Usual torch rules for broadcasting apply.
@@ -439,7 +449,7 @@ def quaternion_apply(quaternion, point):
439
449
return out [..., 1 :]
440
450
441
451
442
- def axis_angle_to_matrix (axis_angle ) :
452
+ def axis_angle_to_matrix (axis_angle : torch . Tensor ) -> torch . Tensor :
443
453
"""
444
454
Convert rotations given as axis/angle to rotation matrices.
445
455
@@ -455,7 +465,7 @@ def axis_angle_to_matrix(axis_angle):
455
465
return quaternion_to_matrix (axis_angle_to_quaternion (axis_angle ))
456
466
457
467
458
- def matrix_to_axis_angle (matrix ) :
468
+ def matrix_to_axis_angle (matrix : torch . Tensor ) -> torch . Tensor :
459
469
"""
460
470
Convert rotations given as rotation matrices to axis/angle.
461
471
@@ -471,7 +481,7 @@ def matrix_to_axis_angle(matrix):
471
481
return quaternion_to_axis_angle (matrix_to_quaternion (matrix ))
472
482
473
483
474
- def axis_angle_to_quaternion (axis_angle ) :
484
+ def axis_angle_to_quaternion (axis_angle : torch . Tensor ) -> torch . Tensor :
475
485
"""
476
486
Convert rotations given as axis/angle to quaternions.
477
487
@@ -485,7 +495,7 @@ def axis_angle_to_quaternion(axis_angle):
485
495
quaternions with real part first, as tensor of shape (..., 4).
486
496
"""
487
497
angles = torch .norm (axis_angle , p = 2 , dim = - 1 , keepdim = True )
488
- half_angles = 0.5 * angles
498
+ half_angles = angles * 0.5
489
499
eps = 1e-6
490
500
small_angles = angles .abs () < eps
491
501
sin_half_angles_over_angles = torch .empty_like (angles )
@@ -503,7 +513,7 @@ def axis_angle_to_quaternion(axis_angle):
503
513
return quaternions
504
514
505
515
506
- def quaternion_to_axis_angle (quaternions ) :
516
+ def quaternion_to_axis_angle (quaternions : torch . Tensor ) -> torch . Tensor :
507
517
"""
508
518
Convert rotations given as quaternions to axis/angle.
509
519
@@ -573,4 +583,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
573
583
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
574
584
Retrieved from http://arxiv.org/abs/1812.07035
575
585
"""
576
- return matrix [..., :2 , :].clone ().reshape (* matrix .size ()[:- 2 ], 6 )
586
+ batch_dim = matrix .size ()[:- 2 ]
587
+ return matrix [..., :2 , :].clone ().reshape (batch_dim + (6 ,))
0 commit comments