diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index b9ecabec1c..9ea81f3de4 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -727,7 +727,7 @@ def __init__( interpretable_model (optional, Model): Model object to train interpretable model. - This argument is optional and defaults to SkLearnLasso(alpha=1.0), + This argument is optional and defaults to SkLearnLasso(alpha=0.01), which is a wrapper around the Lasso linear model in SkLearn. This requires having sklearn version >= 0.23 available. @@ -805,7 +805,7 @@ def __init__( """ if interpretable_model is None: - interpretable_model = SkLearnLasso(alpha=1.0) + interpretable_model = SkLearnLasso(alpha=0.01) if similarity_func is None: similarity_func = get_exp_kernel_similarity_function() diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 593b8f5cc8..ab570e82bb 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -464,6 +464,7 @@ def _lime_test_assert( lime = Lime( model, similarity_func=get_exp_kernel_similarity_function("cosine", 10.0), + interpretable_model=SkLearnLasso(alpha=1.0), ) with self.assertWarns(DeprecationWarning): attributions = lime.attribute(