Skip to content

Commit e27a216

Browse files
Carl Hvarfner and meta-codesync[bot]
authored and committed
PositiveIndexKernel (#3047)
Summary: Pull Request resolved: #3047. PositiveIndexKernel is a MultiTaskGP kernel that enforces positive correlation between tasks. It should probably be upstreamed into GPyTorch at some point. The change also introduces separate priors on the diagonal and off-diagonal elements, so that priors can be set on task correlation in a more intuitive fashion. Reviewed By: Balandat. Differential Revision: D84878629. fbshipit-source-id: e0cceb10ea9d148f16ae122f8d1c603fc562a280
1 parent 1625a23 commit e27a216

File tree

3 files changed

+374
-0
lines changed

3 files changed

+374
-0
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from gpytorch.constraints import GreaterThan, Interval, Positive
9+
from gpytorch.kernels import IndexKernel, Kernel
10+
from gpytorch.priors import Prior
11+
12+
13+
class PositiveIndexKernel(IndexKernel):
    r"""
    A kernel for discrete indices with strictly positive correlations.

    Similar to IndexKernel, but the low-rank covariance factor is constrained
    to be element-wise positive, so every off-diagonal entry of the task
    covariance matrix (and hence every task correlation) is positive.

    .. math::
        k(i, j) = \frac{(LL^T + \mathrm{diag}(v))_{i,j}}
                       {(LL^T + \mathrm{diag}(v))_{t,t}}

    where L is a ``num_tasks x rank`` factor with positive elements, v is the
    vector of per-task variances and t is the target_task_index. The division
    by the target task's variance is only applied when
    ``unit_scale_for_target=True``.
    """

    def __init__(
        self,
        num_tasks: int,
        rank: int = 1,
        task_prior: Prior | None = None,
        diag_prior: Prior | None = None,
        normalize_covar_matrix: bool = False,
        var_constraint: Interval | None = None,
        target_task_index: int = 0,
        unit_scale_for_target: bool = True,
        **kwargs,
    ):
        r"""A kernel for discrete indices with strictly positive correlations.

        Args:
            num_tasks (int): Total number of indices.
            rank (int): Rank of the covariance matrix parameterization.
            task_prior (Prior, optional): Prior placed on the strictly lower
                triangular entries of the task correlation matrix.
            diag_prior (Prior, optional): Prior placed on the diagonal elements
                of the task covariance matrix.
            normalize_covar_matrix (bool): Stored as an attribute on the kernel;
                it is not otherwise used by the code in this class.
            var_constraint (Interval, optional): Constraint for the per-task
                variance parameter. Defaults to ``Positive()``.
            target_task_index (int): Index of the task whose diagonal element
                should be normalized to 1. Defaults to 0 (first task).
            unit_scale_for_target (bool): Whether to rescale the covariance
                matrix so that the target task has unit outputscale.
            **kwargs: Additional arguments passed to ``Kernel.__init__``.

        Raises:
            RuntimeError: If ``rank`` exceeds ``num_tasks``.
            ValueError: If ``target_task_index`` is outside
                ``[0, num_tasks - 1]``.
            TypeError: If ``task_prior`` or ``diag_prior`` is given but is not
                a ``gpytorch.priors.Prior``.
        """
        if rank > num_tasks:
            raise RuntimeError(
                "Cannot create a task covariance matrix larger than the number of tasks"
            )
        if not (0 <= target_task_index < num_tasks):
            raise ValueError(
                f"target_task_index must be between 0 and {num_tasks - 1}, "
                f"got {target_task_index}"
            )
        # IndexKernel.__init__ is deliberately bypassed: the parameters,
        # constraints and priors are registered below with this kernel's own
        # parameterization.
        Kernel.__init__(self, **kwargs)

        if var_constraint is None:
            var_constraint = Positive()

        self.register_parameter(
            name="raw_var",
            parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks)),
        )
        self.register_constraint("raw_var", var_constraint)

        self.normalize_covar_matrix = normalize_covar_matrix
        self.num_tasks = num_tasks
        self.target_task_index = target_task_index
        self.register_parameter(
            name="raw_covar_factor",
            parameter=torch.nn.Parameter(
                torch.rand(*self.batch_shape, num_tasks, rank)
            ),
        )
        self.unit_scale_for_target = unit_scale_for_target

        if task_prior is not None:
            if not isinstance(task_prior, Prior):
                raise TypeError(
                    f"Expected gpytorch.priors.Prior but got "
                    f"{type(task_prior).__name__}"
                )
            self.register_prior(
                "IndexKernelPrior", task_prior, lambda m: m._lower_triangle_corr
            )
        if diag_prior is not None:
            # Same validation as for task_prior, so both priors fail the same way.
            if not isinstance(diag_prior, Prior):
                raise TypeError(
                    f"Expected gpytorch.priors.Prior but got "
                    f"{type(diag_prior).__name__}"
                )
            self.register_prior("ScalePrior", diag_prior, lambda m: m._diagonal)

        # A positive factor makes every entry of L L^T positive.
        self.register_constraint("raw_covar_factor", GreaterThan(0.0))

    def _covar_factor_params(self, m):
        """Return the constrained covariance factor of module ``m``."""
        return m.covar_factor

    def _covar_factor_closure(self, m, v):
        """Set the constrained covariance factor of module ``m`` to ``v``."""
        m._set_covar_factor(v)

    @property
    def covar_factor(self):
        """The covariance factor after applying its constraint transform."""
        return self.raw_covar_factor_constraint.transform(self.raw_covar_factor)

    @covar_factor.setter
    def covar_factor(self, value):
        self._set_covar_factor(value)

    def _set_covar_factor(self, value):
        """Initialize the raw factor so that ``covar_factor`` equals ``value``."""
        # inverse_transform needs a tensor; coerce scalars / sequences to the
        # dtype and device of the existing parameter.
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_covar_factor)
        self.initialize(
            raw_covar_factor=self.raw_covar_factor_constraint.inverse_transform(value)
        )

    @property
    def _lower_triangle_corr(self):
        """Strictly lower triangular entries of the task correlation matrix.

        The last dimension of the result has
        ``num_tasks * (num_tasks - 1) / 2`` entries.
        """
        covar = self.covar_matrix
        # Build the indices on the covariance's device to avoid mixing devices.
        lower_row, lower_col = torch.tril_indices(
            self.num_tasks, self.num_tasks, offset=-1, device=covar.device
        )
        norm_factor = covar.diagonal(dim1=-1, dim2=-2).sqrt()
        # corr_ij = covar_ij / (sqrt(covar_ii) * sqrt(covar_jj))
        corr = covar / (norm_factor.unsqueeze(-1) * norm_factor.unsqueeze(-2))
        return corr[..., lower_row, lower_col]

    @property
    def _diagonal(self):
        """Diagonal of the task covariance matrix."""
        return torch.diagonal(self.covar_matrix, dim1=-2, dim2=-1)

    def _eval_covar_matrix(self):
        """Evaluate the task covariance matrix ``L L^T + diag(var)``.

        If ``unit_scale_for_target`` is set, the matrix is divided by the
        target task's diagonal element so that entry equals one.
        """
        cf = self.covar_factor
        # diag_embed places var on the diagonal of the last two dims, which
        # also works when var carries leading batch dimensions (multiplying
        # var by an identity matrix does not broadcast correctly there).
        covar = cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
        # Normalize by the target task's diagonal element
        if self.unit_scale_for_target:
            norm_factor = covar[..., self.target_task_index, self.target_task_index]
            covar = covar / norm_factor.unsqueeze(-1).unsqueeze(-1)
        return covar

    @property
    def covar_matrix(self):
        """The (optionally normalized) task covariance matrix."""
        return self._eval_covar_matrix()

sphinx/source/models.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ Kernels
146146
.. automodule:: botorch.models.kernels.orthogonal_additive_kernel
147147
.. autoclass:: OrthogonalAdditiveKernel
148148

149+
.. automodule:: botorch.models.kernels.positive_index
150+
.. autoclass:: PositiveIndexKernel
151+
149152
Likelihoods
150153
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
151154
.. automodule:: botorch.models.likelihoods.pairwise
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from botorch.models.kernels.positive_index import PositiveIndexKernel
9+
from botorch.utils.testing import BotorchTestCase
10+
from gpytorch.priors import NormalPrior
11+
12+
13+
class TestPositiveIndexKernel(BotorchTestCase):
    """Unit tests for ``PositiveIndexKernel``."""

    def test_positive_index_kernel(self):
        """Check construction, validation, constraints, priors and forward."""
        for dtype in (torch.float32, torch.float64):
            # Construction without a batch shape.
            with self.subTest("basic_initialization", dtype=dtype):
                n_tasks, n_rank = 4, 2
                kernel = PositiveIndexKernel(num_tasks=n_tasks, rank=n_rank)
                kernel = kernel.to(dtype=dtype)

                self.assertEqual(kernel.num_tasks, n_tasks)
                self.assertEqual(kernel.raw_covar_factor.shape, (n_tasks, n_rank))
                self.assertEqual(kernel.normalize_covar_matrix, False)

            # Construction with a batch shape.
            with self.subTest("initialization_with_batch_shape", dtype=dtype):
                n_tasks, n_rank = 3, 2
                kernel = PositiveIndexKernel(
                    num_tasks=n_tasks, rank=n_rank, batch_shape=torch.Size([2])
                ).to(dtype=dtype)

                self.assertEqual(
                    kernel.raw_covar_factor.shape, (2, n_tasks, n_rank)
                )

            # A rank above the number of tasks is rejected.
            with self.subTest("rank_validation", dtype=dtype):
                with self.assertRaises(RuntimeError):
                    PositiveIndexKernel(num_tasks=3, rank=5)

            # target_task_index must lie in [0, num_tasks - 1].
            with self.subTest("target_task_index_validation", dtype=dtype):
                n_tasks = 4
                # Negative index, then index >= num_tasks.
                for bad_index in (-1, 4):
                    with self.assertRaises(ValueError):
                        PositiveIndexKernel(
                            num_tasks=n_tasks, rank=2, target_task_index=bad_index
                        )
                # Boundary indices are valid and must not raise.
                for good_index in (0, 3):
                    PositiveIndexKernel(
                        num_tasks=n_tasks, rank=2, target_task_index=good_index
                    )

            # The constrained factor and the resulting covariance are positive.
            with self.subTest("positive_correlations", dtype=dtype):
                kernel = PositiveIndexKernel(num_tasks=5, rank=3).to(dtype=dtype)
                factor = kernel.covar_factor

                self.assertTrue((factor > 0).all())
                self.assertTrue((kernel.covar_matrix >= 0).all())

            # Default target (index 0) is normalized to unit variance.
            with self.subTest("covar_matrix_normalization_default", dtype=dtype):
                kernel = PositiveIndexKernel(num_tasks=4, rank=2).to(dtype=dtype)
                task_covar = kernel.covar_matrix

                self.assertAllClose(
                    task_covar[0, 0], torch.tensor(1.0, dtype=dtype), atol=1e-4
                )

            # A custom target (index 2) is normalized to unit variance.
            with self.subTest("covar_matrix_normalization_custom_target", dtype=dtype):
                kernel = PositiveIndexKernel(
                    num_tasks=4, rank=2, target_task_index=2
                ).to(dtype=dtype)
                task_covar = kernel.covar_matrix

                self.assertAllClose(
                    task_covar[2, 2], torch.tensor(1.0, dtype=dtype), atol=1e-4
                )

            # Forward pass: output shape and agreement with covar_matrix.
            with self.subTest("forward", dtype=dtype):
                kernel = PositiveIndexKernel(num_tasks=4, rank=2).to(dtype=dtype)

                rows = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)
                cols = torch.tensor([[1, 2]], dtype=torch.long)

                out = kernel(rows, cols)
                self.assertEqual(out.shape, torch.Size([2, 1]))

                n_tasks = 3
                kernel = PositiveIndexKernel(num_tasks=n_tasks, rank=1).to(
                    dtype=dtype
                )
                kernel.initialize(
                    raw_covar_factor=torch.ones(n_tasks, 1, dtype=dtype)
                )
                rows = torch.tensor([[0]], dtype=torch.long)
                cols = torch.tensor([[1]], dtype=torch.long)

                out = kernel(rows, cols).to_dense()
                expected = kernel.covar_matrix[0, 1]

                self.assertAllClose(out.squeeze(), expected)

            # Both priors are registered under their expected names.
            with self.subTest("with_priors", dtype=dtype):
                corr_prior = NormalPrior(0, 1)
                scale_prior = NormalPrior(1, 0.1)

                kernel = PositiveIndexKernel(
                    num_tasks=4,
                    rank=2,
                    task_prior=corr_prior,
                    diag_prior=scale_prior,
                    initialize_to_mode=False,
                ).to(dtype=dtype)
                registered = [entry[0] for entry in kernel.named_priors()]
                self.assertIn("IndexKernelPrior", registered)
                self.assertIn("ScalePrior", registered)

            # Batched kernels keep their batch dimension in the output.
            with self.subTest("batch_forward", dtype=dtype):
                kernel = PositiveIndexKernel(
                    num_tasks=3, rank=2, batch_shape=torch.Size([2])
                ).to(dtype=dtype)

                rows = torch.tensor([[[0], [1]]], dtype=torch.long)
                cols = torch.tensor([[[1], [2]]], dtype=torch.long)

                out = kernel(rows, cols)

                self.assertEqual(out.shape[0], 2)

            # _diagonal has one entry per task; the target's entry is 1.
            with self.subTest("diagonal", dtype=dtype):
                # First the default target (index 0), then a custom one.
                for extra_kwargs, target in (
                    ({}, 0),
                    ({"target_task_index": 1}, 1),
                ):
                    kernel = PositiveIndexKernel(
                        num_tasks=4, rank=2, **extra_kwargs
                    ).to(dtype=dtype)
                    diag = kernel._diagonal

                    self.assertEqual(diag.shape, torch.Size([4]))
                    self.assertAllClose(
                        diag[target], torch.tensor(1.0, dtype=dtype), atol=1e-4
                    )

            # _lower_triangle_corr returns the strictly lower triangle.
            with self.subTest("lower_triangle", dtype=dtype):
                n_tasks = 5
                kernel = PositiveIndexKernel(num_tasks=n_tasks, rank=2).to(
                    dtype=dtype
                )
                corr_entries = kernel._lower_triangle_corr

                # Count of entries strictly below the diagonal.
                n_expected = n_tasks * (n_tasks - 1) // 2
                self.assertEqual(corr_entries.shape[-1], n_expected)
                self.assertTrue((corr_entries >= 0).all())

            # A task prior that is not a Prior is rejected.
            with self.subTest("invalid_prior_type", dtype=dtype):
                with self.assertRaises(TypeError):
                    PositiveIndexKernel(num_tasks=4, rank=2, task_prior="not_a_prior")

            # The covariance matrix is square, positive definite and symmetric.
            with self.subTest("covar_matrix", dtype=dtype):
                kernel = PositiveIndexKernel(num_tasks=5, rank=4).to(dtype=dtype)
                task_covar = kernel.covar_matrix

                self.assertEqual(task_covar.shape[-2], task_covar.shape[-1])

                spectrum = torch.linalg.eigvalsh(task_covar)
                self.assertTrue((spectrum > 0).all())

                self.assertAllClose(task_covar, task_covar.T, atol=1e-5)

            # Setter, getter and the prior helper closures for covar_factor.
            with self.subTest("covar_factor", dtype=dtype):
                kernel = PositiveIndexKernel(num_tasks=3, rank=2).to(dtype=dtype)
                twos = torch.ones(3, 2, dtype=dtype) * 2.0
                kernel.covar_factor = twos
                self.assertAllClose(kernel.covar_factor, twos, atol=1e-5)

                kernel = PositiveIndexKernel(num_tasks=3, rank=2).to(dtype=dtype)
                factor = kernel._covar_factor_params(kernel)
                self.assertEqual(factor.shape, torch.Size([3, 2]))
                self.assertTrue((factor > 0).all())

                kernel = PositiveIndexKernel(num_tasks=3, rank=2).to(dtype=dtype)
                threes = torch.ones(3, 2, dtype=dtype) * 3.0
                kernel._covar_factor_closure(kernel, threes)
                self.assertAllClose(kernel.covar_factor, threes, atol=1e-5)

0 commit comments

Comments
 (0)