Skip to content

Commit d7d740a

Browse files
Nikita Smetaninfacebook-github-bot
Nikita Smetanin
authored andcommitted
Symmetric eigen 3x3 implementation + benchmark & tests
Summary: Symmetric eigenvalues 3x3 implementation from https://github.com/fairinternal/denseposeslim/blob/roman_c3dpo/tools/functions.py#L612 based on https://en.wikipedia.org/wiki/Eigenvalue_algorithm#3.C3.973_matrices and https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf Benchmarks show significant outperformance of symeig3x3 in comparison with torch implementations (torch.symeig and torch.linalg.eigh) on GPU (P100), especially for large batches: 70-280ns per sample vs 3400ns per sample for torch_linalg_eigh_1048576_cpu It's worth mentioning that torch.linalg.eigh is still comparably fast for batches up to 8192 on CPU. Some tests are still failing as the error thresholds need to be adjusted appropriately. Reviewed By: patricklabatut Differential Revision: D29915453 fbshipit-source-id: 7c1b062da631c57c4e22a42dd0027ea5e205f1b5
1 parent 9585a58 commit d7d740a

File tree

5 files changed

+680
-0
lines changed

5 files changed

+680
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .utils import _safe_det_3x3
8+
from .symeig3x3 import symeig3x3
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
from typing import Tuple, Optional
9+
10+
import torch
11+
import torch.nn.functional as F
12+
from torch import nn
13+
14+
15+
class _SymEig3x3(nn.Module):
16+
"""
17+
Optimized implementation of eigenvalues and eigenvectors computation for symmetric 3x3
18+
matrices.
19+
20+
Please see https://en.wikipedia.org/wiki/Eigenvalue_algorithm#3.C3.973_matrices
21+
and https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf
22+
"""
23+
24+
def __init__(self, eps: Optional[float] = None) -> None:
25+
"""
26+
Args:
27+
eps: epsilon to specify, if None then use torch.float eps
28+
"""
29+
super().__init__()
30+
31+
self.register_buffer("_identity", torch.eye(3))
32+
self.register_buffer("_rotation_2d", torch.tensor([[0.0, -1.0], [1.0, 0.0]]))
33+
self.register_buffer(
34+
"_rotations_3d", self._create_rotation_matrices(self._rotation_2d)
35+
)
36+
37+
self._eps = eps or torch.finfo(torch.float).eps
38+
39+
@staticmethod
40+
def _create_rotation_matrices(rotation_2d) -> torch.Tensor:
41+
"""
42+
Compute rotations for later use in U V computation
43+
44+
Args:
45+
rotation_2d: a π/2 rotation matrix.
46+
47+
Returns:
48+
a (3, 3, 3) tensor containing 3 rotation matrices around each of the coordinate axes
49+
by π/2
50+
"""
51+
52+
rotations_3d = torch.zeros((3, 3, 3))
53+
rotation_axes = set(range(3))
54+
for rotation_axis in rotation_axes:
55+
rest = list(rotation_axes - {rotation_axis})
56+
rotations_3d[rotation_axis][rest[0], rest] = rotation_2d[0]
57+
rotations_3d[rotation_axis][rest[1], rest] = rotation_2d[1]
58+
59+
return rotations_3d
60+
61+
def forward(
62+
self, inputs: torch.Tensor, eigenvectors: bool = True
63+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
64+
"""
65+
Compute eigenvalues and (optionally) eigenvectors
66+
67+
Args:
68+
inputs: symmetric matrices with shape of (..., 3, 3)
69+
eigenvectors: whether should we compute only eigenvalues or eigenvectors as well
70+
71+
Returns:
72+
Either a tuple of (eigenvalues, eigenvectors) or eigenvalues only, depending on
73+
given params. Eigenvalues are of shape (..., 3) and eigenvectors (..., 3, 3)
74+
"""
75+
if inputs.shape[-2:] != (3, 3):
76+
raise ValueError("Only inputs of shape (..., 3, 3) are supported.")
77+
78+
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1) # pyre-ignore[16]
79+
inputs_trace = inputs_diag.sum(-1)
80+
q = inputs_trace / 3.0
81+
82+
# Calculate squared sum of elements outside the main diagonal / 2
83+
p1 = ((inputs ** 2).sum(dim=(-1, -2)) - (inputs_diag ** 2).sum(-1)) / 2
84+
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
85+
86+
p = torch.sqrt(p2 / 6.0)
87+
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
88+
89+
r = torch.det(B) / 2.0
90+
# Keep r within (-1.0, 1.0) boundaries with a margin to prevent exploding gradients.
91+
r = r.clamp(-1.0 + self._eps, 1.0 - self._eps)
92+
93+
phi = torch.acos(r) / 3.0
94+
eig1 = q + 2 * p * torch.cos(phi)
95+
eig2 = q + 2 * p * torch.cos(phi + 2 * math.pi / 3)
96+
eig3 = 3 * q - eig1 - eig2
97+
# eigenvals[..., i] is the i-th eigenvalue of the input, α0 ≤ α1 ≤ α2.
98+
eigenvals = torch.stack((eig2, eig3, eig1), dim=-1)
99+
100+
# Soft dispatch between the degenerate case (diagonal A) and general.
101+
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
102+
# We use 6 * eps to take into account the error accumulated during the p1 summation
103+
diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None]
104+
105+
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
106+
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
107+
eigenvals = diag_soft_cond * diag_eigenvals + (1.0 - diag_soft_cond) * eigenvals
108+
109+
if eigenvectors:
110+
eigenvecs = self._construct_eigenvecs_set(inputs, eigenvals)
111+
else:
112+
eigenvecs = None
113+
114+
return eigenvals, eigenvecs
115+
116+
def _construct_eigenvecs_set(
117+
self, inputs: torch.Tensor, eigenvals: torch.Tensor
118+
) -> torch.Tensor:
119+
"""
120+
Construct orthonormal set of eigenvectors by given inputs and pre-computed eigenvalues
121+
122+
Args:
123+
inputs: tensor of symmetric matrices of shape (..., 3, 3)
124+
eigenvals: tensor of pre-computed eigenvalues of of shape (..., 3, 3)
125+
126+
Returns:
127+
Tuple of three eigenvector tensors of shape (..., 3, 3), composing an orthonormal
128+
set
129+
"""
130+
eigenvecs_tuple_for_01 = self._construct_eigenvecs(
131+
inputs, eigenvals[..., 0], eigenvals[..., 1]
132+
)
133+
eigenvecs_for_01 = torch.stack(eigenvecs_tuple_for_01, dim=-1)
134+
135+
eigenvecs_tuple_for_21 = self._construct_eigenvecs(
136+
inputs, eigenvals[..., 2], eigenvals[..., 1]
137+
)
138+
eigenvecs_for_21 = torch.stack(eigenvecs_tuple_for_21[::-1], dim=-1)
139+
140+
# The result will be smooth here even if both parts of comparison
141+
# are close, because eigenvecs_01 and eigenvecs_21 would be mostly equal as well
142+
eigenvecs_cond = (
143+
eigenvals[..., 1] - eigenvals[..., 0]
144+
> eigenvals[..., 2] - eigenvals[..., 1]
145+
).detach()
146+
eigenvecs = torch.where(
147+
eigenvecs_cond[..., None, None], eigenvecs_for_01, eigenvecs_for_21
148+
)
149+
150+
return eigenvecs
151+
152+
def _construct_eigenvecs(
153+
self, inputs: torch.Tensor, alpha0: torch.Tensor, alpha1: torch.Tensor
154+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
155+
"""
156+
Construct an orthonormal set of eigenvectors by given pair of eigenvalues.
157+
158+
Args:
159+
inputs: tensor of symmetric matrices of shape (..., 3, 3)
160+
alpha0: first eigenvalues of shape (..., 3)
161+
alpha1: second eigenvalues of shape (..., 3)
162+
163+
Returns:
164+
Tuple of three eigenvector tensors of shape (..., 3, 3), composing an orthonormal
165+
set
166+
"""
167+
168+
# Find the eigenvector corresponding to alpha0, its eigenvalue is distinct
169+
ev0 = self._get_ev0(inputs - alpha0[..., None, None] * self._identity)
170+
u, v = self._get_uv(ev0)
171+
ev1 = self._get_ev1(inputs - alpha1[..., None, None] * self._identity, u, v)
172+
# Third eigenvector is computed as the cross-product of the other two
173+
ev2 = torch.cross(ev0, ev1, dim=-1)
174+
175+
return ev0, ev1, ev2
176+
177+
def _get_ev0(self, char_poly: torch.Tensor) -> torch.Tensor:
178+
"""
179+
Construct the first normalized eigenvector given a characteristic polynomial
180+
181+
Args:
182+
char_poly: a characteristic polynomials of the input matrices of shape (..., 3, 3)
183+
184+
Returns:
185+
Tensor of first eigenvectors of shape (..., 3)
186+
"""
187+
188+
r01 = torch.cross(char_poly[..., 0, :], char_poly[..., 1, :], dim=-1)
189+
r12 = torch.cross(char_poly[..., 1, :], char_poly[..., 2, :], dim=-1)
190+
r02 = torch.cross(char_poly[..., 0, :], char_poly[..., 2, :], dim=-1)
191+
192+
cross_products = torch.stack((r01, r12, r02), dim=-2)
193+
# Regularize it with + or -eps depending on the sign of the first vector
194+
cross_products += self._eps * self._sign_without_zero(
195+
cross_products[..., :1, :]
196+
)
197+
198+
norms_sq = (cross_products ** 2).sum(dim=-1)
199+
max_norms_index = norms_sq.argmax(dim=-1) # pyre-ignore[16]
200+
201+
# Pick only the cross-product with highest squared norm for each input
202+
max_cross_products = self._gather_by_index(
203+
cross_products, max_norms_index[..., None, None], -2
204+
)
205+
# Pick corresponding squared norms for each cross-product
206+
max_norms_sq = self._gather_by_index(norms_sq, max_norms_index[..., None], -1)
207+
208+
# Normalize cross-product vectors by thier norms
209+
return max_cross_products / torch.sqrt(max_norms_sq[..., None])
210+
211+
def _gather_by_index(
212+
self, source: torch.Tensor, index: torch.Tensor, dim: int
213+
) -> torch.Tensor:
214+
"""
215+
Selects elements from the given source tensor by provided index tensor.
216+
Number of dimensions should be the same for source and index tensors.
217+
218+
Args:
219+
source: input tensor to gather from
220+
index: index tensor with indices to gather from source
221+
dim: dimension to gather across
222+
223+
Returns:
224+
Tensor of shape same as the source with exception of specified dimension.
225+
"""
226+
227+
index_shape = list(source.shape)
228+
index_shape[dim] = 1
229+
230+
return source.gather(dim, index.expand(index_shape)).squeeze( # pyre-ignore[16]
231+
dim
232+
)
233+
234+
def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
235+
"""
236+
Computes unit-length vectors U and V such that {U, V, W} is a right-handed
237+
orthonormal set.
238+
239+
Args:
240+
w: eigenvector tensor of shape (..., 3)
241+
242+
Returns:
243+
Tuple of U and V unit-length vector tensors of shape (..., 3)
244+
"""
245+
246+
min_idx = w.abs().argmin(dim=-1) # pyre-ignore[16]
247+
rotation_2d = self._rotations_3d[min_idx].to(w)
248+
249+
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
250+
v = torch.cross(w, u, dim=-1)
251+
return u, v
252+
253+
def _get_ev1(
254+
self, char_poly: torch.Tensor, u: torch.Tensor, v: torch.Tensor
255+
) -> torch.Tensor:
256+
"""
257+
Computes the second normalized eigenvector given a characteristic polynomial
258+
and U and V vectors
259+
260+
Args:
261+
char_poly: a characteristic polynomials of the input matrices of shape (..., 3, 3)
262+
u: unit-length vectors from _get_uv method
263+
v: unit-length vectors from _get_uv method
264+
265+
Returns:
266+
desc
267+
"""
268+
269+
j = torch.stack((u, v), dim=-1)
270+
m = j.transpose(-1, -2) @ char_poly @ j
271+
272+
# If angle between those vectors is acute, take their sum = m[..., 0, :] + m[..., 1, :],
273+
# otherwise take the difference = m[..., 0, :] - m[..., 1, :]
274+
# m is in theory of rank 1 (or 0), so it snaps only when one of the rows is close to 0
275+
is_acute_sign = self._sign_without_zero(
276+
(m[..., 0, :] * m[..., 1, :]).sum(dim=-1)
277+
).detach()
278+
279+
rowspace = m[..., 0, :] + is_acute_sign[..., None] * m[..., 1, :]
280+
# rowspace will be near zero for second-order eigenvalues
281+
# this regularization guarantees abs(rowspace[0]) >= eps in a smooth'ish way
282+
rowspace += self._eps * self._sign_without_zero(rowspace[..., :1])
283+
284+
return (
285+
j
286+
@ F.normalize(rowspace @ self._rotation_2d.to(rowspace), dim=-1)[..., None]
287+
)[..., 0]
288+
289+
@staticmethod
290+
def _sign_without_zero(tensor):
291+
"""
292+
Args:
293+
tensor: an arbitrary shaped tensor
294+
295+
Returns:
296+
Tensor of the same shape as an input, but with 1.0 if tensor > 0.0 and -1.0
297+
otherwise
298+
"""
299+
return 2.0 * (tensor > 0.0).to(tensor.dtype) - 1.0
300+
301+
302+
def symeig3x3(
303+
inputs: torch.Tensor, eigenvectors: bool = True
304+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
305+
"""
306+
Compute eigenvalues and (optionally) eigenvectors
307+
308+
Args:
309+
inputs: symmetric matrices with shape of (..., 3, 3)
310+
eigenvectors: whether should we compute only eigenvalues or eigenvectors as well
311+
312+
Returns:
313+
Either a tuple of (eigenvalues, eigenvectors) or eigenvalues only, depending on
314+
given params. Eigenvalues are of shape (..., 3) and eigenvectors (..., 3, 3)
315+
"""
316+
return _SymEig3x3().to(inputs.device)(inputs, eigenvectors=eigenvectors)
File renamed without changes.

0 commit comments

Comments
 (0)