Skip to content

Fixes and Improvements for TokenGetter #1013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions internal/authentication/tokengetter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package authentication

import (
"context"
"sync"
"time"

authenticationv1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/utils/ptr"
)

type TokenGetter struct {
client corev1.ServiceAccountsGetter
expirationDuration time.Duration
tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus
mu sync.RWMutex
}

type TokenGetterOption func(*TokenGetter)

const (
rotationThresholdFraction = 0.1
DefaultExpirationDuration = 5 * time.Minute
)

// Returns a token getter that can fetch tokens given a service account.
// The token getter also caches tokens which helps reduce the number of requests to the API Server.
// In case a cached token is expiring a fresh token is created.
func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter {
tokenGetter := &TokenGetter{
client: client,
expirationDuration: DefaultExpirationDuration,
tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{},
}

for _, opt := range options {
opt(tokenGetter)
}

return tokenGetter
}

func WithExpirationDuration(expirationDuration time.Duration) TokenGetterOption {
return func(tg *TokenGetter) {
tg.expirationDuration = expirationDuration
}
}

// Get returns a token from the cache if available and not expiring, otherwise creates a new token
func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) {
t.mu.RLock()
token, ok := t.tokens[key]
t.mu.RUnlock()

expireTime := time.Time{}
if ok {
expireTime = token.ExpirationTimestamp.Time
}

// Create a new token if the cached token expires within rotationThresholdFraction of expirationDuration from now
rotationThresholdAfterNow := metav1.Now().Add(time.Duration(float64(t.expirationDuration) * (rotationThresholdFraction)))
if expireTime.Before(rotationThresholdAfterNow) {
var err error
token, err = t.getToken(ctx, key)
if err != nil {
return "", err
}
t.mu.Lock()
t.tokens[key] = token
t.mu.Unlock()
}

// Delete tokens that have expired
t.reapExpiredTokens()

return token.Token, nil
}

func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authenticationv1.TokenRequestStatus, error) {
req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx,
key.Name,
&authenticationv1.TokenRequest{
Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration / time.Second))},
}, metav1.CreateOptions{})
if err != nil {
return nil, err
}
return &req.Status, nil
}

func (t *TokenGetter) reapExpiredTokens() {
t.mu.Lock()
defer t.mu.Unlock()
for key, token := range t.tokens {
if metav1.Now().Sub(token.ExpirationTimestamp.Time) > 0 {
delete(t.tokens, key)
}
}
}
87 changes: 87 additions & 0 deletions internal/authentication/tokengetter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package authentication

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
authenticationv1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes/fake"
ctest "k8s.io/client-go/testing"
)

func TestTokenGetterGet(t *testing.T) {
fakeClient := fake.NewSimpleClientset()
fakeClient.PrependReactor("create", "serviceaccounts/token",
func(action ctest.Action) (bool, runtime.Object, error) {
act, ok := action.(ctest.CreateActionImpl)
if !ok {
return false, nil, nil
}
tokenRequest := act.GetObject().(*authenticationv1.TokenRequest)
var err error
if act.Name == "test-service-account-1" {
tokenRequest.Status = authenticationv1.TokenRequestStatus{
Token: "test-token-1",
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(DefaultExpirationDuration)),
}
}
if act.Name == "test-service-account-2" {
tokenRequest.Status = authenticationv1.TokenRequestStatus{
Token: "test-token-2",
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)),
}
}
if act.Name == "test-service-account-3" {
tokenRequest.Status = authenticationv1.TokenRequestStatus{
Token: "test-token-3",
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-10 * time.Second)),
}
}
if act.Name == "test-service-account-4" {
tokenRequest = nil
err = fmt.Errorf("error when fetching token")
}
return true, tokenRequest, err
})

tg := NewTokenGetter(fakeClient.CoreV1(),
WithExpirationDuration(DefaultExpirationDuration))

tests := []struct {
testName string
serviceAccountName string
namespace string
want string
errorMsg string
}{
{"Testing getting token with fake client", "test-service-account-1",
"test-namespace-1", "test-token-1", "failed to get token"},
{"Testing getting token from cache", "test-service-account-1",
"test-namespace-1", "test-token-1", "failed to get token"},
{"Testing getting short lived token from fake client", "test-service-account-2",
"test-namespace-2", "test-token-2", "failed to get token"},
{"Testing getting nearly expired token from cache", "test-service-account-2",
"test-namespace-2", "test-token-2", "failed to refresh token"},
{"Testing token that expired 10 seconds ago", "test-service-account-3",
"test-namespace-3", "test-token-3", "failed to get token"},
{"Testing error when getting token from fake client", "test-service-account-4",
"test-namespace-4", "error when fetching token", "error when fetching token"},
}

for _, tc := range tests {
got, err := tg.Get(context.Background(), types.NamespacedName{Namespace: tc.namespace, Name: tc.serviceAccountName})
if err != nil {
t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, err)
assert.EqualError(t, err, tc.errorMsg)
} else {
t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, got)
assert.Equal(t, tc.want, got, tc.errorMsg)
}
}
}
Loading