1515
1616import torch
1717from botorch .acquisition .acquisition import AcquisitionFunction
18- from botorch .acquisition .analytic import AnalyticAcquisitionFunction
1918from botorch .acquisition .objective import GenericMCObjective
20- from botorch .exceptions import UnsupportedError
19+ from botorch .acquisition . wrapper import AbstractAcquisitionFunctionWrapper
2120from torch import Tensor
2221
2322
@@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
139138 return regularization_term
140139
141140
142- class PenalizedAcquisitionFunction (AcquisitionFunction ):
141+ class PenalizedAcquisitionFunction (AbstractAcquisitionFunctionWrapper ):
143142 r"""Single-outcome acquisition function regularized by the given penalty.
144143
145144 The usage is similar to:
@@ -161,29 +160,16 @@ def __init__(
161160 penalty_func: The regularization function.
162161 regularization_parameter: Regularization parameter used in optimization.
163162 """
164- super () .__init__ (model = raw_acqf .model )
165- self . raw_acqf = raw_acqf
163+ AcquisitionFunction .__init__ (self , model = raw_acqf .model )
164+ AbstractAcquisitionFunctionWrapper . __init__ ( self , acq_function = raw_acqf )
166165 self .penalty_func = penalty_func
167166 self .regularization_parameter = regularization_parameter
168167
169168 def forward (self , X : Tensor ) -> Tensor :
170- raw_value = self .raw_acqf (X = X )
169+ raw_value = self .acq_func (X = X )
171170 penalty_term = self .penalty_func (X )
172171 return raw_value - self .regularization_parameter * penalty_term
173172
174- @property
175- def X_pending (self ) -> Optional [Tensor ]:
176- return self .raw_acqf .X_pending
177-
178- def set_X_pending (self , X_pending : Optional [Tensor ] = None ) -> None :
179- if not isinstance (self .raw_acqf , AnalyticAcquisitionFunction ):
180- self .raw_acqf .set_X_pending (X_pending = X_pending )
181- else :
182- raise UnsupportedError (
183- "The raw acquisition function is Analytic and does not account "
184- "for X_pending yet."
185- )
186-
187173
188174def group_lasso_regularizer (X : Tensor , groups : List [List [int ]]) -> Tensor :
189175 r"""Computes the group lasso regularization function for the given point.
0 commit comments