From 3b2949feaacecc3bf465e2acea9104cc348587e0 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Tue, 2 Jul 2024 10:48:03 -0400 Subject: [PATCH 1/2] Adding a token getter to get service account tokens --- internal/authentication/tokengetter.go | 111 ++++++++++++++++++++ internal/authentication/tokengetter_test.go | 88 ++++++++++++++++ 2 files changed, 199 insertions(+) create mode 100644 internal/authentication/tokengetter.go create mode 100644 internal/authentication/tokengetter_test.go diff --git a/internal/authentication/tokengetter.go b/internal/authentication/tokengetter.go new file mode 100644 index 000000000..aed477560 --- /dev/null +++ b/internal/authentication/tokengetter.go @@ -0,0 +1,111 @@ +package authentication + +import ( + "context" + "sync" + "time" + + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/utils/ptr" +) + +type TokenGetter struct { + client corev1.ServiceAccountsGetter + expirationDuration time.Duration + removeAfterExpiredDuration time.Duration + tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus + mu sync.RWMutex +} + +type TokenGetterOption func(*TokenGetter) + +const ( + RotationThresholdPercentage = 10 + DefaultExpirationDuration = 5 * time.Minute + DefaultRemoveAfterExpiredDuration = 90 * time.Minute +) + +// Returns a token getter that can fetch tokens given a service account. +// The token getter also caches tokens which helps reduce the number of requests to the API Server. +// In case a cached token is expiring a fresh token is created. +func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter { + tokenGetter := &TokenGetter{ + client: client, + expirationDuration: DefaultExpirationDuration, + removeAfterExpiredDuration: DefaultRemoveAfterExpiredDuration, + tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, + } + + for _, opt := range options { + opt(tokenGetter) + } + + return tokenGetter +} + +func WithExpirationDuration(expirationDuration time.Duration) TokenGetterOption { + return func(tg *TokenGetter) { + tg.expirationDuration = expirationDuration + } +} + +func WithRemoveAfterExpiredDuration(removeAfterExpiredDuration time.Duration) TokenGetterOption { + return func(tg *TokenGetter) { + tg.removeAfterExpiredDuration = removeAfterExpiredDuration + } +} + +// Get returns a token from the cache if available and not expiring, otherwise creates a new token +func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) { + t.mu.RLock() + token, ok := t.tokens[key] + t.mu.RUnlock() + + expireTime := time.Time{} + if ok { + expireTime = token.ExpirationTimestamp.Time + } + + // Create a new token if the cached token expires within DurationPercentage of expirationDuration from now + rotationThresholdAfterNow := metav1.Now().Add(t.expirationDuration * (RotationThresholdPercentage / 100)) + if expireTime.Before(rotationThresholdAfterNow) { + var err error + token, err = t.getToken(ctx, key) + if err != nil { + return "", err + } + t.mu.Lock() + t.tokens[key] = token + t.mu.Unlock() + } + + // Delete tokens that have been expired for more than ExpiredDuration + t.reapExpiredTokens(t.removeAfterExpiredDuration) + + return token.Token, nil +} + +func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authenticationv1.TokenRequestStatus, error) { + req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, + key.Name, + &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration))}, + }, metav1.CreateOptions{}) + if err != nil { + return nil, err + } + return &req.Status, nil +} + +func (t *TokenGetter) reapExpiredTokens(expiredDuration time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + for key, token := range t.tokens { + if metav1.Now().Sub(token.ExpirationTimestamp.Time) > expiredDuration { + delete(t.tokens, key) + } + } +} diff --git a/internal/authentication/tokengetter_test.go b/internal/authentication/tokengetter_test.go new file mode 100644 index 000000000..5b246c36a --- /dev/null +++ b/internal/authentication/tokengetter_test.go @@ -0,0 +1,88 @@ +package authentication + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/fake" + ctest "k8s.io/client-go/testing" +) + +func TestTokenGetterGet(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + fakeClient.PrependReactor("create", "serviceaccounts/token", + func(action ctest.Action) (bool, runtime.Object, error) { + act, ok := action.(ctest.CreateActionImpl) + if !ok { + return false, nil, nil + } + tokenRequest := act.GetObject().(*authenticationv1.TokenRequest) + var err error + if act.Name == "test-service-account-1" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-1", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(DefaultExpirationDuration)), + } + } + if act.Name == "test-service-account-2" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-2", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)), + } + } + if act.Name == "test-service-account-3" { + tokenRequest.Status = authenticationv1.TokenRequestStatus{ + Token: "test-token-3", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-DefaultRemoveAfterExpiredDuration)), + } + } + if act.Name == "test-service-account-4" { + tokenRequest = nil + err = fmt.Errorf("error when fetching token") + } + return true, tokenRequest, err + }) + + tg := NewTokenGetter(fakeClient.CoreV1(), + WithExpirationDuration(DefaultExpirationDuration), + WithRemoveAfterExpiredDuration(DefaultRemoveAfterExpiredDuration)) + + tests := []struct { + testName string + serviceAccountName string + namespace string + want string + errorMsg string + }{ + {"Testing getting token with fake client", "test-service-account-1", + "test-namespace-1", "test-token-1", "failed to get token"}, + {"Testing getting token from cache", "test-service-account-1", + "test-namespace-1", "test-token-1", "failed to get token"}, + {"Testing getting short lived token from fake client", "test-service-account-2", + "test-namespace-2", "test-token-2", "failed to get token"}, + {"Testing getting expired token from cache", "test-service-account-2", + "test-namespace-2", "test-token-2", "failed to refresh token"}, + {"Testing token that expired 90 minutes ago", "test-service-account-3", + "test-namespace-3", "test-token-3", "failed to get token"}, + {"Testing error when getting token from fake client", "test-service-account-4", + "test-namespace-4", "error when fetching token", "error when fetching token"}, + } + + for _, tc := range tests { + got, err := tg.Get(context.Background(), types.NamespacedName{Namespace: tc.namespace, Name: tc.serviceAccountName}) + if err != nil { + t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, err) + assert.EqualError(t, err, tc.errorMsg) + } else { + t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, got) + assert.Equal(t, tc.want, got, tc.errorMsg) + } + } +} From 00c7f845e4ede366a1cd7838a923ba702083471a Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Fri, 5 Jul 2024 14:07:20 -0400 Subject: [PATCH 2/2] Fixes and Improvements for TokenGetter --- internal/authentication/tokengetter.go | 41 ++++++++------------- internal/authentication/tokengetter_test.go | 9 ++--- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/internal/authentication/tokengetter.go b/internal/authentication/tokengetter.go index aed477560..585fc65e6 100644 --- a/internal/authentication/tokengetter.go +++ b/internal/authentication/tokengetter.go @@ -13,19 +13,17 @@ import ( ) type TokenGetter struct { - client corev1.ServiceAccountsGetter - expirationDuration time.Duration - removeAfterExpiredDuration time.Duration - tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus - mu sync.RWMutex + client corev1.ServiceAccountsGetter + expirationDuration time.Duration + tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus + mu sync.RWMutex } type TokenGetterOption func(*TokenGetter) const ( - RotationThresholdPercentage = 10 - DefaultExpirationDuration = 5 * time.Minute - DefaultRemoveAfterExpiredDuration = 90 * time.Minute + rotationThresholdFraction = 0.1 + DefaultExpirationDuration = 5 * time.Minute ) // Returns a token getter that can fetch tokens given a service account. @@ -33,10 +31,9 @@ const ( // In case a cached token is expiring a fresh token is created. func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter { tokenGetter := &TokenGetter{ - client: client, - expirationDuration: DefaultExpirationDuration, - removeAfterExpiredDuration: DefaultRemoveAfterExpiredDuration, - tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, + client: client, + expirationDuration: DefaultExpirationDuration, + tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{}, } for _, opt := range options { @@ -52,12 +49,6 @@ func WithExpirationDuration(expirationDuration time.Duration) TokenGetterOption } } -func WithRemoveAfterExpiredDuration(removeAfterExpiredDuration time.Duration) TokenGetterOption { - return func(tg *TokenGetter) { - tg.removeAfterExpiredDuration = removeAfterExpiredDuration - } -} - // Get returns a token from the cache if available and not expiring, otherwise creates a new token func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) { t.mu.RLock() @@ -69,8 +60,8 @@ func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string expireTime = token.ExpirationTimestamp.Time } - // Create a new token if the cached token expires within DurationPercentage of expirationDuration from now - rotationThresholdAfterNow := metav1.Now().Add(t.expirationDuration * (RotationThresholdPercentage / 100)) + // Create a new token if the cached token expires within rotationThresholdFraction of expirationDuration from now + rotationThresholdAfterNow := metav1.Now().Add(time.Duration(float64(t.expirationDuration) * (rotationThresholdFraction))) if expireTime.Before(rotationThresholdAfterNow) { var err error token, err = t.getToken(ctx, key) @@ -82,8 +73,8 @@ func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string t.mu.Unlock() } - // Delete tokens that have been expired for more than ExpiredDuration - t.reapExpiredTokens(t.removeAfterExpiredDuration) + // Delete tokens that have expired + t.reapExpiredTokens() return token.Token, nil } @@ -92,7 +83,7 @@ func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (* req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, key.Name, &authenticationv1.TokenRequest{ - Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration))}, + Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration / time.Second))}, }, metav1.CreateOptions{}) if err != nil { return nil, err @@ -100,11 +91,11 @@ func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (* return &req.Status, nil } -func (t *TokenGetter) reapExpiredTokens(expiredDuration time.Duration) { +func (t *TokenGetter) reapExpiredTokens() { t.mu.Lock() defer t.mu.Unlock() for key, token := range t.tokens { - if metav1.Now().Sub(token.ExpirationTimestamp.Time) > expiredDuration { + if metav1.Now().Sub(token.ExpirationTimestamp.Time) > 0 { delete(t.tokens, key) } } diff --git a/internal/authentication/tokengetter_test.go b/internal/authentication/tokengetter_test.go index 5b246c36a..b9553cac3 100644 --- a/internal/authentication/tokengetter_test.go +++ b/internal/authentication/tokengetter_test.go @@ -40,7 +40,7 @@ func TestTokenGetterGet(t *testing.T) { if act.Name == "test-service-account-3" { tokenRequest.Status = authenticationv1.TokenRequestStatus{ Token: "test-token-3", - ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-DefaultRemoveAfterExpiredDuration)), + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-10 * time.Second)), } } if act.Name == "test-service-account-4" { @@ -51,8 +51,7 @@ func TestTokenGetterGet(t *testing.T) { }) tg := NewTokenGetter(fakeClient.CoreV1(), - WithExpirationDuration(DefaultExpirationDuration), - WithRemoveAfterExpiredDuration(DefaultRemoveAfterExpiredDuration)) + WithExpirationDuration(DefaultExpirationDuration)) tests := []struct { testName string @@ -67,9 +66,9 @@ func TestTokenGetterGet(t *testing.T) { "test-namespace-1", "test-token-1", "failed to get token"}, {"Testing getting short lived token from fake client", "test-service-account-2", "test-namespace-2", "test-token-2", "failed to get token"}, - {"Testing getting expired token from cache", "test-service-account-2", + {"Testing getting nearly expired token from cache", "test-service-account-2", "test-namespace-2", "test-token-2", "failed to refresh token"}, - {"Testing token that expired 90 minutes ago", "test-service-account-3", + {"Testing token that expired 10 seconds ago", "test-service-account-3", "test-namespace-3", "test-token-3", "failed to get token"}, {"Testing error when getting token from fake client", "test-service-account-4", "test-namespace-4", "error when fetching token", "error when fetching token"},