|
12 | 12 | import torch
|
13 | 13 |
|
14 | 14 |
|
15 |
| -def _get_rotation_to_best_fit_xy( |
16 |
| - points: torch.Tensor, centroid: torch.Tensor |
| 15 | +def get_rotation_to_best_fit_xy( |
| 16 | + points: torch.Tensor, centroid: Optional[torch.Tensor] = None |
17 | 17 | ) -> torch.Tensor:
|
18 | 18 | """
|
19 |
| - Returns a rotation r such that points @ r has a best fit plane |
| 19 | + Returns a rotation R such that `points @ R` has a best fit plane |
20 | 20 | parallel to the xy plane
|
21 | 21 |
|
22 | 22 | Args:
|
23 |
| - points: (N, 3) tensor of points in 3D |
24 |
| - centroid: (3,) their centroid |
| 23 | + points: (*, N, 3) tensor of points in 3D |
| 24 | + centroid: (*, 1, 3), (3,) or scalar: their centroid |
25 | 25 |
|
26 | 26 | Returns:
|
27 |
| - (3,3) tensor rotation matrix |
| 27 | + (*, 3, 3) tensor rotation matrix |
28 | 28 | """
|
29 |
| - points_centered = points - centroid[None] |
30 |
| - return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]] |
| 29 | + if centroid is None: |
| 30 | + centroid = points.mean(dim=-2, keepdim=True) |
| 31 | + |
| 32 | + points_centered = points - centroid |
| 33 | + _, evec = torch.linalg.eigh(points_centered.transpose(-1, -2) @ points_centered) |
| 34 | + # in general, evec can form either right- or left-handed basis, |
| 35 | + # but we need the former to have a proper rotation (not reflection) |
| 36 | + return torch.cat( |
| 37 | + (evec[..., 1:], torch.cross(evec[..., 1], evec[..., 2])[..., None]), dim=-1 |
| 38 | + ) |
31 | 39 |
|
32 | 40 |
|
33 | 41 | def _signed_area(path: torch.Tensor) -> torch.Tensor:
|
@@ -191,7 +199,7 @@ def fit_circle_in_3d(
|
191 | 199 | Circle3D object
|
192 | 200 | """
|
193 | 201 | centroid = points.mean(0)
|
194 |
| - r = _get_rotation_to_best_fit_xy(points, centroid) |
| 202 | + r = get_rotation_to_best_fit_xy(points, centroid) |
195 | 203 | normal = r[:, 2]
|
196 | 204 | rotated_points = (points - centroid) @ r
|
197 | 205 | result_2d = fit_circle_in_2d(
|
|
0 commit comments