Skip to content

Commit ec1921c

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
PositiveIndexKernel (meta-pytorch#3047)
Summary: PositiveIndexKernel - a MultiTaskGP kernel that enforces positive correlation. Should probably be upstreamed into GPyTorch at some point. Also introduces priors on diagonal and off-diagonals separately, so that priors can be set on task correlation in a more intuititve fashion. Differential Revision: D84878629
1 parent 09502f9 commit ec1921c

File tree

3 files changed

+354
-0
lines changed

3 files changed

+354
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import torch
10+
from gpytorch.constraints import GreaterThan
11+
from gpytorch.kernels import IndexKernel
12+
from gpytorch.priors import Prior
13+
14+
15+
class PositiveIndexKernel(IndexKernel):
16+
r"""
17+
A kernel for discrete indices with strictly positive correlations.
18+
19+
Similar to IndexKernel but ensures all off-diagonal correlations are positive
20+
by using a Cholesky-like parameterization with positive elements.
21+
22+
.. math::
23+
k(i, j) = \frac{(LL^T)_{i,j}}{(LL^T)_{t,t}}
24+
25+
where L is a lower triangular matrix with positive elements and t is the
26+
target_task_index.
27+
"""
28+
29+
def __init__(
30+
self,
31+
num_tasks: int,
32+
rank: Optional[int] = 1,
33+
task_prior: Optional[Prior] = None,
34+
diag_prior: Optional[Prior] = None,
35+
normalize_covar_matrix: bool = False,
36+
target_task_index: int = 0,
37+
unit_scale_for_target: bool = True,
38+
**kwargs,
39+
):
40+
r"""A kernel for discrete indices with strictly positive correlations.
41+
42+
Args:
43+
num_tasks (int): Total number of indices.
44+
rank (int): Rank of the covariance matrix parameterization.
45+
task_prior (Prior, optional): Prior for the covariance matrix.
46+
diag_prior (Prior, optional): Prior for the diagonal elements.
47+
normalize_covar_matrix (bool): Whether to normalize the covariance matrix.
48+
target_task_index (int): Index of the task whose diagonal element should be
49+
normalized to 1. Defaults to 0 (first task).
50+
unit_scale_for_target (bool): Whether to ensure the target task's has unit
51+
outputscale.
52+
**kwargs: Additional arguments passed to IndexKernel.
53+
"""
54+
if rank > num_tasks:
55+
raise RuntimeError(
56+
"Cannot create a task covariance matrix larger than the number of tasks"
57+
)
58+
if not (0 <= target_task_index < num_tasks):
59+
raise ValueError(
60+
f"target_task_index must be between 0 and {num_tasks - 1}, "
61+
f"got {target_task_index}"
62+
)
63+
super().__init__(
64+
num_tasks=num_tasks,
65+
rank=rank,
66+
var_constraint=None,
67+
**kwargs,
68+
)
69+
delattr(self, "covar_factor")
70+
# delete covar factor from parameters
71+
self.normalize_covar_matrix = normalize_covar_matrix
72+
self.num_tasks = num_tasks
73+
self.target_task_index = target_task_index
74+
self.register_parameter(
75+
name="raw_covar_factor",
76+
parameter=torch.nn.Parameter(
77+
torch.rand(*self.batch_shape, num_tasks, rank)
78+
),
79+
)
80+
self.unit_scale_for_target = unit_scale_for_target
81+
if task_prior is not None:
82+
if not isinstance(task_prior, Prior):
83+
raise TypeError(
84+
f"Expected gpytorch.priors.Prior but got "
85+
f"{type(task_prior).__name__}"
86+
)
87+
self.register_prior(
88+
"IndexKernelPrior", task_prior, lambda m: m._lower_triangle
89+
)
90+
if diag_prior is not None:
91+
self.register_prior("ScalePrior", diag_prior, lambda m: m._diagonal)
92+
93+
self.register_constraint("raw_covar_factor", GreaterThan(0.0))
94+
95+
def _covar_factor_params(self, m):
96+
return m.covar_factor
97+
98+
def _covar_factor_closure(self, m, v):
99+
m._set_covar_factor(v)
100+
101+
@property
102+
def covar_factor(self):
103+
return self.raw_covar_factor_constraint.transform(self.raw_covar_factor)
104+
105+
@covar_factor.setter
106+
def covar_factor(self, value):
107+
self._set_covar_factor(value)
108+
109+
def _set_covar_factor(self, value):
110+
# This must be a tensor
111+
self.initialize(
112+
raw_covar_factor=self.raw_covar_factor_constraint.inverse_transform(value)
113+
)
114+
115+
@property
116+
def _lower_triangle(self):
117+
lower_row, lower_col = torch.tril_indices(
118+
self.num_tasks, self.num_tasks, offset=-1
119+
)
120+
covar = self.covar_matrix
121+
norm_factor = covar.diagonal(dim1=-1, dim2=-2).sqrt()
122+
corr = covar / (norm_factor.unsqueeze(-1) * norm_factor.unsqueeze(-2))
123+
low_tri = corr[..., lower_row, lower_col]
124+
125+
return low_tri
126+
127+
@property
128+
def _diagonal(self):
129+
return torch.diagonal(self.covar_matrix, dim1=-2, dim2=-1)
130+
131+
def _eval_covar_matrix(self):
132+
cf = self.covar_factor
133+
covar = cf @ cf.transpose(-1, -2) + self.var * torch.eye(
134+
self.num_tasks, dtype=cf.dtype, device=cf.device
135+
)
136+
# Normalize by the target task's diagonal element
137+
if self.unit_scale_for_target:
138+
norm_factor = covar[..., self.target_task_index, self.target_task_index]
139+
covar = covar / norm_factor.unsqueeze(-1).unsqueeze(-1)
140+
return covar
141+
142+
@property
143+
def covar_matrix(self):
144+
return self._eval_covar_matrix()

sphinx/source/models.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ Kernels
146146
.. automodule:: botorch.models.kernels.orthogonal_additive_kernel
147147
.. autoclass:: OrthogonalAdditiveKernel
148148

149+
.. automodule:: botorch.models.kernels.positive_index
150+
.. autoclass:: PositiveIndexKernel
151+
149152
Likelihoods
150153
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
151154
.. automodule:: botorch.models.likelihoods.pairwise
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from botorch.models.kernels.positive_index import PositiveIndexKernel
9+
from botorch.utils.testing import BotorchTestCase
10+
from gpytorch.priors import NormalPrior
11+
12+
13+
class TestPositiveIndexKernel(BotorchTestCase):
14+
def test_positive_index_kernel(self):
15+
# Test initialization
16+
with self.subTest("basic_initialization"):
17+
num_tasks = 4
18+
rank = 2
19+
kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=rank)
20+
21+
self.assertEqual(kernel.num_tasks, num_tasks)
22+
self.assertEqual(kernel.raw_covar_factor.shape, (num_tasks, rank))
23+
self.assertEqual(kernel.normalize_covar_matrix, False)
24+
25+
# Test initialization with batch shape
26+
with self.subTest("initialization_with_batch_shape"):
27+
num_tasks = 3
28+
rank = 2
29+
batch_shape = torch.Size([2])
30+
kernel = PositiveIndexKernel(
31+
num_tasks=num_tasks, rank=rank, batch_shape=batch_shape
32+
)
33+
34+
self.assertEqual(kernel.raw_covar_factor.shape, (2, num_tasks, rank))
35+
36+
# Test rank validation
37+
with self.subTest("rank_validation"):
38+
num_tasks = 3
39+
rank = 5
40+
with self.assertRaises(RuntimeError):
41+
PositiveIndexKernel(num_tasks=num_tasks, rank=rank)
42+
43+
# Test target_task_index validation
44+
with self.subTest("target_task_index_validation"):
45+
num_tasks = 4
46+
# Test invalid negative index
47+
with self.assertRaises(ValueError):
48+
PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=-1)
49+
# Test invalid index >= num_tasks
50+
with self.assertRaises(ValueError):
51+
PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=4)
52+
# Test valid indices (should not raise)
53+
PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=0)
54+
PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=3)
55+
56+
# Test covar_factor constraint
57+
with self.subTest("positive_correlations"):
58+
kernel = PositiveIndexKernel(num_tasks=5, rank=3)
59+
covar_factor = kernel.covar_factor
60+
61+
# All elements should be positive
62+
self.assertTrue((covar_factor > 0).all())
63+
64+
self.assertTrue((kernel.covar_matrix >= 0).all())
65+
66+
# Test covariance matrix normalization (default target_task_index=0)
67+
with self.subTest("covar_matrix_normalization_default"):
68+
kernel = PositiveIndexKernel(num_tasks=4, rank=2)
69+
covar = kernel.covar_matrix
70+
71+
# First diagonal element should be 1.0 (normalized by default)
72+
self.assertAllClose(covar[0, 0], torch.tensor(1.0), atol=1e-4)
73+
74+
# Test covariance matrix normalization with custom target_task_index
75+
with self.subTest("covar_matrix_normalization_custom_target"):
76+
kernel = PositiveIndexKernel(num_tasks=4, rank=2, target_task_index=2)
77+
covar = kernel.covar_matrix
78+
79+
# Third diagonal element should be 1.0 (target_task_index=2)
80+
self.assertAllClose(covar[2, 2], torch.tensor(1.0), atol=1e-4)
81+
82+
# Other diagonal elements should not be 1.0
83+
self.assertNotEqual(covar[0, 0].item(), 1.0)
84+
85+
# Test forward pass shape
86+
with self.subTest("forward"):
87+
num_tasks = 4
88+
kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=2)
89+
kernel.eval()
90+
91+
i1 = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)
92+
i2 = torch.tensor([[1, 2]], dtype=torch.long)
93+
94+
result = kernel(i1, i2)
95+
self.assertEqual(result.shape, torch.Size([2, 1]))
96+
num_tasks = 3
97+
kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=1)
98+
kernel.eval()
99+
100+
kernel.initialize(raw_covar_factor=torch.ones(num_tasks, 1))
101+
i1 = torch.tensor([[0]], dtype=torch.long)
102+
i2 = torch.tensor([[1]], dtype=torch.long)
103+
104+
result = kernel(i1, i2).to_dense()
105+
covar_matrix = kernel.covar_matrix
106+
expected = covar_matrix[0, 1]
107+
108+
self.assertAllClose(result.squeeze(), expected)
109+
110+
# Test with priors
111+
with self.subTest("with_priors"):
112+
num_tasks = 4
113+
task_prior = NormalPrior(0, 1)
114+
diag_prior = NormalPrior(1, 0.1)
115+
116+
kernel = PositiveIndexKernel(
117+
num_tasks=num_tasks,
118+
rank=2,
119+
task_prior=task_prior,
120+
diag_prior=diag_prior,
121+
initialize_to_mode=False,
122+
)
123+
prior_names = [p[0] for p in kernel.named_priors()]
124+
self.assertIn("IndexKernelPrior", prior_names)
125+
self.assertIn("ScalePrior", prior_names)
126+
127+
# Test batch forward
128+
with self.subTest("batch_forward"):
129+
num_tasks = 3
130+
batch_shape = torch.Size([2])
131+
kernel = PositiveIndexKernel(
132+
num_tasks=num_tasks, rank=2, batch_shape=batch_shape
133+
)
134+
kernel.eval()
135+
136+
i1 = torch.tensor([[[0], [1]]], dtype=torch.long)
137+
i2 = torch.tensor([[[1], [2]]], dtype=torch.long)
138+
139+
result = kernel(i1, i2)
140+
141+
# Check that batch dimensions are preserved
142+
self.assertEqual(result.shape[0], 2)
143+
144+
# Test diagonal property (default target_task_index=0)
145+
with self.subTest("diagonal"):
146+
kernel = PositiveIndexKernel(num_tasks=4, rank=2)
147+
diag = kernel._diagonal
148+
149+
self.assertEqual(diag.shape, torch.Size([4]))
150+
# First diagonal element should be 1.0 (default target_task_index=0)
151+
self.assertAllClose(diag[0], torch.tensor(1.0), atol=1e-4)
152+
153+
# Test diagonal property with custom target_task_index
154+
kernel = PositiveIndexKernel(num_tasks=4, rank=2, target_task_index=1)
155+
diag = kernel._diagonal
156+
157+
self.assertEqual(diag.shape, torch.Size([4]))
158+
# Second diagonal element should be 1.0 (target_task_index=1)
159+
self.assertAllClose(diag[1], torch.tensor(1.0), atol=1e-4)
160+
161+
# Test lower triangle property
162+
with self.subTest("lower_triangle"):
163+
num_tasks = 5
164+
kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=2)
165+
lower_tri = kernel._lower_triangle
166+
167+
# Number of lower triangular elements (excluding diagonal)
168+
expected_size = num_tasks * (num_tasks - 1) // 2
169+
self.assertEqual(lower_tri.shape[-1], expected_size)
170+
self.assertTrue((lower_tri >= 0).all())
171+
172+
# Test invalid prior type
173+
with self.subTest("invalid_prior_type"):
174+
with self.assertRaises(TypeError):
175+
PositiveIndexKernel(num_tasks=4, rank=2, task_prior="not_a_prior")
176+
177+
# Test covariance matrix properties
178+
with self.subTest("covar_matrix"):
179+
kernel = PositiveIndexKernel(num_tasks=5, rank=4)
180+
covar = kernel.covar_matrix
181+
182+
# Should be square
183+
self.assertEqual(covar.shape[-2], covar.shape[-1])
184+
185+
# Should be positive definite (all eigenvalues > 0)
186+
eigvals = torch.linalg.eigvalsh(covar)
187+
self.assertTrue((eigvals > 0).all())
188+
189+
# Should be symmetric
190+
self.assertAllClose(covar, covar.T, atol=1e-5)
191+
192+
# Test covar_factor setter and getter
193+
with self.subTest("covar_factor"):
194+
kernel = PositiveIndexKernel(num_tasks=3, rank=2)
195+
new_covar_factor = torch.ones(3, 2) * 2.0
196+
kernel.covar_factor = new_covar_factor
197+
self.assertAllClose(kernel.covar_factor, new_covar_factor, atol=1e-5)
198+
199+
kernel = PositiveIndexKernel(num_tasks=3, rank=2)
200+
params = kernel._covar_factor_params(kernel)
201+
self.assertEqual(params.shape, torch.Size([3, 2]))
202+
self.assertTrue((params > 0).all())
203+
204+
kernel = PositiveIndexKernel(num_tasks=3, rank=2)
205+
new_value = torch.ones(3, 2) * 3.0
206+
kernel._covar_factor_closure(kernel, new_value)
207+
self.assertAllClose(kernel.covar_factor, new_value, atol=1e-5)

0 commit comments

Comments
 (0)