diff --git a/captum/influence/_core/arnoldi_influence_function.py b/captum/influence/_core/arnoldi_influence_function.py index f02db89c7d..953f5a8bd1 100644 --- a/captum/influence/_core/arnoldi_influence_function.py +++ b/captum/influence/_core/arnoldi_influence_function.py @@ -157,7 +157,7 @@ def _parameter_distill( k: Optional[int], hessian_reg: float, hessian_inverse_tol: float, -): +) -> Tuple[Tensor, List[Tuple[Tensor, ...]]]: """ This takes the output of `_parameter_arnoldi`, and extracts the top-k eigenvalues / eigenvectors of the matrix that `_parameter_arnoldi` found the Krylov subspace diff --git a/captum/influence/_core/influence_function.py b/captum/influence/_core/influence_function.py index 0115ac1def..dbc709090b 100644 --- a/captum/influence/_core/influence_function.py +++ b/captum/influence/_core/influence_function.py @@ -596,7 +596,7 @@ def _get_dataset_embeddings_intermediate_quantities_influence_function( batch_embeddings_fn: Callable, inputs_dataset: DataLoader, aggregate: bool, -): +) -> Tensor: """ given `batch_embeddings_fn`, which produces the embeddings for a given batch, returns either the embeddings for an entire dataset (if `aggregate` is false), diff --git a/captum/influence/_core/similarity_influence.py b/captum/influence/_core/similarity_influence.py index 3c4046ca1d..d1d0d2b6f8 100644 --- a/captum/influence/_core/similarity_influence.py +++ b/captum/influence/_core/similarity_influence.py @@ -18,7 +18,7 @@ """ -def euclidean_distance(test, train) -> Tensor: +def euclidean_distance(test: Tensor, train: Tensor) -> Tensor: r""" Calculates the pairwise euclidean distance for batches of feature vectors. Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *). @@ -31,7 +31,7 @@ def euclidean_distance(test, train) -> Tensor: return similarity -def cosine_similarity(test, train, replace_nan=0) -> Tensor: +def cosine_similarity(test: Tensor, train: Tensor, replace_nan: int = 0) -> Tensor: r""" Calculates the pairwise cosine similarity for batches of feature vectors. Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *). diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index fb26aa32c8..e383773b1b 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -140,16 +140,9 @@ def __init__( Default: None """ - self.model = model + self.model: Module = model - if isinstance(checkpoints, str): - self.checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*"))) - elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str): - self.checkpoints = AV.sort_files(checkpoints) - else: - self.checkpoints = list(checkpoints) # cast to avoid mypy error - if isinstance(self.checkpoints, List): - assert len(self.checkpoints) > 0, "No checkpoints saved!" + self.checkpoints = checkpoints # type: ignore self.checkpoints_load_func = checkpoints_load_func self.loss_fn = loss_fn @@ -181,6 +174,24 @@ def __init__( "percentage completion of the computation, nor any time estimates." ) + @property + def checkpoints(self) -> List[str]: + return self._checkpoints + + @checkpoints.setter + def checkpoints(self, checkpoints: Union[str, List[str], Iterator]) -> None: + if isinstance(checkpoints, str): + self._checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*"))) + elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str): + self._checkpoints = AV.sort_files(checkpoints) + else: + self._checkpoints = list(checkpoints) # cast to avoid mypy error + + if len(self._checkpoints) <= 0: + raise ValueError( + f"Invalid checkpoints provided for TracIn class: {checkpoints}!" + ) + @abstractmethod def self_influence( self, diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index ccc3bf061f..32b44506ba 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -82,7 +82,7 @@ class TracInCPFast(TracInCPBase): def __init__( self, model: Module, - final_fc_layer: Module, + final_fc_layer: Union[Module, str], train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, @@ -183,7 +183,7 @@ def __init__( self.vectorize = vectorize # TODO: restore prior state - self.final_fc_layer = final_fc_layer + self.final_fc_layer = final_fc_layer # type: ignore for param in self.final_fc_layer.parameters(): param.requires_grad = True @@ -720,7 +720,7 @@ def _basic_computation_tracincp_fast( targets: Tensor, loss_fn: Optional[Union[Module, Callable]] = None, reduction_type: Optional[str] = None, -): +) -> Tuple[Tensor, Tensor]: """ For instances of TracInCPFast and children classes, computation of influence scores or self influence scores repeatedly calls this function for different checkpoints @@ -1363,7 +1363,7 @@ def _set_projections_tracincp_fast_rand_proj( def _process_src_intermediate_quantities_tracincp_fast_rand_proj( self, src_intermediate_quantities: torch.Tensor, - ): + ) -> None: """ Assumes `self._get_intermediate_quantities_tracin_fast_rand_proj` returns vector representations for each example, and that influence between a diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index c214ecbdf1..0df8fd00d9 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -4,6 +4,7 @@ from typing import ( Any, Callable, + Dict, Iterable, List, NamedTuple, @@ -613,7 +614,7 @@ def _influence_batch_intermediate_quantities_influence_function( influence_inst: "IntermediateQuantitiesInfluenceFunction", test_batch: Tuple[Any, ...], train_batch: Tuple[Any, ...], -): +) -> Tensor: """ computes influence of a test batch on a train batch, for implementations of `IntermediateQuantitiesInfluenceFunction` @@ -628,7 +629,7 @@ def _influence_helper_intermediate_quantities_influence_function( influence_inst: "IntermediateQuantitiesInfluenceFunction", inputs_dataset: Union[Tuple[Any, ...], DataLoader], show_progress: bool, -): +) -> Tensor: """ Helper function that computes influence scores for implementations of `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` @@ -666,7 +667,7 @@ def _self_influence_helper_intermediate_quantities_influence_function( influence_inst: "IntermediateQuantitiesInfluenceFunction", inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]], show_progress: bool, -): +) -> Tensor: """ Helper function that computes self-influence scores for implementations of `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` @@ -983,14 +984,14 @@ def _compute_batch_loss_influence_function_base( raise Exception -def _set_attr(obj, names, val): +def _set_attr(obj, names, val) -> None: if len(names) == 1: setattr(obj, names[0], val) else: _set_attr(getattr(obj, names[0]), names[1:], val) -def _del_attr(obj, names): +def _del_attr(obj, names) -> None: if len(names) == 1: delattr(obj, names[0]) else: @@ -1006,7 +1007,7 @@ def _model_make_functional(model, param_names, params): return params -def _model_reinsert_params(model, param_names, params, register=False): +def _model_reinsert_params(model, param_names, params, register: bool = False) -> None: for param_name, param in zip(param_names, params): _set_attr( model, @@ -1024,7 +1025,7 @@ def _custom_functional_call(model, d, features): return out -def _functional_call(model, d, features): +def _functional_call(model: Module, d: Dict[str, Tensor], features): """ Makes a call to `model.forward`, which is treated as a function of the parameters in `d`, a dict from parameter name to parameter, instead of as a function of diff --git a/tests/influence/_core/test_tracin_validation.py b/tests/influence/_core/test_tracin_validation.py index f24e56d7e1..682bff408d 100644 --- a/tests/influence/_core/test_tracin_validation.py +++ b/tests/influence/_core/test_tracin_validation.py @@ -36,7 +36,7 @@ class TestTracinValidator(BaseTest): ) def test_tracin_require_inputs_dataset( self, - reduction, + reduction: str, tracin_constructor: Callable, ) -> None: """ @@ -64,6 +64,10 @@ def test_tracin_require_inputs_dataset( tracin.influence(None, k=None) def test_tracincp_fast_rand_proj_inputs(self) -> None: + """ + This test verifies that TracInCPFast should be initialized + with a valid `final_fc_layer`. + """ with tempfile.TemporaryDirectory() as tmpdir: ( net, @@ -83,3 +87,34 @@ def test_tracincp_fast_rand_proj_inputs(self) -> None: loss_fn=nn.MSELoss(), batch_size=1, ) + + @parameterized.expand( + param_list, + name_func=build_test_name_func(), + ) + def test_tracincp_input_checkpoints( + self, reduction: str, tracin_constructor: Callable + ) -> None: + """ + This test verifies that tracinCP and tracinCPFast + class should be initialized with valid `checkpoints`. + """ + with tempfile.TemporaryDirectory() as invalid_tmpdir: + with tempfile.TemporaryDirectory() as tmpdir: + ( + net, + train_dataset, + test_samples, + test_labels, + ) = get_random_model_and_data(tmpdir, unpack_inputs=False) + + with self.assertRaisesRegex( + ValueError, "Invalid checkpoints provided for TracIn class: " + ): + tracin_constructor( + net, + train_dataset, + invalid_tmpdir, + loss_fn=nn.MSELoss(), + batch_size=1, + )