diff --git a/CHANGELOG.md b/CHANGELOG.md index 04516c203..df08fd9bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - `core`: [v0.y.z] - **Feature:** Add package `runtime`, which implements methods to be used when performing API requests. - **Feature:** Add method `WithCaptureHTTPResponse` to package `runtime`, which does the same as `config.WithCaptureHTTPResponse`. Method was moved to avoid confusion due to it not being a configuration option, and will be removed in a later release. +- **Feature:** Add configuration option that, for the key flow, enables a goroutine to be spawned that will refresh the access token when it's close to expiring - **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead. - **Deprecation:** Mark method `config.WithJWKSEndpoint` and field `config.Configuration.JWKSCustomUrl` as deprecated. Validation using JWKS was removed, for being redundant with token validation done in the APIs. These have no effect. - **Breaking Change:** Remove method `KeyFlow.Clone`, that was no longer being used. diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index ffe7c3aad..07e9e770a 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,7 @@ +## v0.10.0 (YYYY-MM-DD) + +- **Feature:** Add configuration option that, for the key flow, enables a goroutine to be spawned that will refresh the access token when it's close to expiring + ## v0.9.0 (2024-02-19) - **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead. diff --git a/core/auth/auth.go b/core/auth/auth.go index 56fa855bb..45d2e046d 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -179,10 +179,11 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) { } keyCfg := clients.KeyFlowConfig{ - ServiceAccountKey: serviceAccountKey, - PrivateKey: cfg.PrivateKey, - ClientRetry: cfg.RetryOptions, - TokenUrl: cfg.TokenCustomUrl, + ServiceAccountKey: serviceAccountKey, + PrivateKey: cfg.PrivateKey, + ClientRetry: cfg.RetryOptions, + TokenUrl: cfg.TokenCustomUrl, + BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, } client := &clients.KeyFlow{} diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index ab49b682b..e935abdb6 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -1,6 +1,7 @@ package clients import ( + "context" "crypto/rsa" "crypto/x509" "encoding/json" @@ -10,8 +11,11 @@ import ( "net/http" "net/url" "strings" + "sync" "time" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) @@ -36,15 +40,18 @@ type KeyFlow struct { key *ServiceAccountKeyResponse privateKey *rsa.PrivateKey privateKeyPEM []byte - token *TokenResponseBody + + tokenMutex sync.RWMutex + token *TokenResponseBody } // KeyFlowConfig is the flow config type KeyFlowConfig struct { - ServiceAccountKey *ServiceAccountKeyResponse - PrivateKey string - ClientRetry *RetryConfig - TokenUrl string + ServiceAccountKey *ServiceAccountKeyResponse + PrivateKey string + ClientRetry *RetryConfig + TokenUrl string + BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil } // TokenResponseBody is the API response @@ -97,13 +104,19 @@ func (c *KeyFlow) GetServiceAccountEmail() string { // GetToken returns the token field func (c *KeyFlow) GetToken() TokenResponseBody { + c.tokenMutex.RLock() + defer c.tokenMutex.RUnlock() + if c.token == nil { return TokenResponseBody{} } + // Returned struct is passed by value (because it's a struct) + // So no deepy copy needed return *c.token } func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { + // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} c.config = cfg c.doer = Do @@ -115,7 +128,14 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { if c.config.ClientRetry == nil { c.config.ClientRetry = NewRetryConfig() } - return c.validate() + err := c.validate() + if err != nil { + return err + } + if c.config.BackgroundTokenRefreshContext != nil { + go continuousRefreshToken(c) + } + return nil } // SetToken can be used to set an access and refresh token manually in the client. @@ -132,6 +152,7 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { return fmt.Errorf("get expiration time from access token: %w", err) } + c.tokenMutex.Lock() c.token = &TokenResponseBody{ AccessToken: accessToken, ExpiresIn: int(exp.Time.Unix()), @@ -139,6 +160,7 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { RefreshToken: refreshToken, TokenType: defaultTokenType, } + c.tokenMutex.Unlock() return nil } @@ -158,17 +180,21 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { // GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field func (c *KeyFlow) GetAccessToken() (string, error) { - accessTokenExpired, err := tokenExpired(c.token.AccessToken) + c.tokenMutex.RLock() + accessToken := c.token.AccessToken + c.tokenMutex.RUnlock() + + accessTokenExpired, err := tokenExpired(accessToken) if err != nil { - return "", fmt.Errorf("failed initial validation: %w", err) + return "", fmt.Errorf("check access token is expired: %w", err) } if !accessTokenExpired { - return c.token.AccessToken, nil + return accessToken, nil } if err := c.recreateAccessToken(); err != nil { - return "", fmt.Errorf("failed during token recreation: %w", err) + return "", fmt.Errorf("get new access token: %w", err) } - return c.token.AccessToken, nil + return accessToken, nil } // configureHTTPClient configures the HTTP client @@ -191,7 +217,7 @@ func (c *KeyFlow) validate() error { var err error c.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM([]byte(c.config.PrivateKey)) if err != nil { - return fmt.Errorf("parsing private key from PEM file: %w", err) + return fmt.Errorf("parse private key from PEM file: %w", err) } // Encode the private key in PEM format @@ -209,7 +235,11 @@ func (c *KeyFlow) validate() error { // recreateAccessToken is used to create a new access token // when the existing one isn't valid anymore func (c *KeyFlow) recreateAccessToken() error { - refreshTokenExpired, err := tokenExpired(c.token.RefreshToken) + c.tokenMutex.RLock() + refreshToken := c.token.RefreshToken + c.tokenMutex.RUnlock() + + refreshTokenExpired, err := tokenExpired(refreshToken) if err != nil { return err } @@ -232,8 +262,8 @@ func (c *KeyFlow) createAccessToken() (err error) { } defer func() { tempErr := res.Body.Close() - if tempErr != nil { - err = fmt.Errorf("closing request access token response: %w", tempErr) + if tempErr != nil && err == nil { + err = fmt.Errorf("close request access token response: %w", tempErr) } }() return c.parseTokenResponse(res) @@ -242,14 +272,18 @@ func (c *KeyFlow) createAccessToken() (err error) { // createAccessTokenWithRefreshToken creates an access token using // an existing pre-validated refresh token func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { - res, err := c.requestToken("refresh_token", c.token.RefreshToken) + c.tokenMutex.RLock() + refreshToken := c.token.RefreshToken + c.tokenMutex.RUnlock() + + res, err := c.requestToken("refresh_token", refreshToken) if err != nil { return err } defer func() { tempErr := res.Body.Close() - if tempErr != nil { - err = fmt.Errorf("closing request access token with refresh token response: %w", tempErr) + if tempErr != nil && err == nil { + err = fmt.Errorf("close request access token with refresh token response: %w", tempErr) } }() return c.parseTokenResponse(res) @@ -294,14 +328,32 @@ func (c *KeyFlow) parseTokenResponse(res *http.Response) error { return fmt.Errorf("received bad response from API") } if res.StatusCode != http.StatusOK { - return fmt.Errorf("received: %+v", res) + body, err := io.ReadAll(res.Body) + if err != nil { + // Fail silently, omit body from error + // We're trying to show error details, so it's unnecessary to fail because of this err + body = []byte{} + } + return &oapierror.GenericOpenAPIError{ + StatusCode: res.StatusCode, + Body: body, + ErrorMessage: err.Error(), + } } body, err := io.ReadAll(res.Body) if err != nil { return err } + + c.tokenMutex.Lock() c.token = &TokenResponseBody{} - return json.Unmarshal(body, c.token) + err = json.Unmarshal(body, c.token) + c.tokenMutex.Unlock() + if err != nil { + return fmt.Errorf("unmarshal token response: %w", err) + } + + return nil } func tokenExpired(token string) (bool, error) { @@ -309,11 +361,11 @@ func tokenExpired(token string) (bool, error) { // We're just checking the expiration time tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) if err != nil { - return false, fmt.Errorf("parse access token: %w", err) + return false, fmt.Errorf("parse token: %w", err) } expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() if err != nil { - return false, fmt.Errorf("get expiration timestamp from access token: %w", err) + return false, fmt.Errorf("get expiration timestamp: %w", err) } expirationTimestamp := expirationTimestampNumeric.Time now := time.Now() diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go new file mode 100644 index 000000000..540e9d155 --- /dev/null +++ b/core/clients/key_flow_continuous_refresh.go @@ -0,0 +1,126 @@ +package clients + +import ( + "errors" + "fmt" + "os" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" +) + +var ( + defaultTimeStartBeforeTokenExpiration = 30 * time.Minute + defaultTimeBetweenContextCheck = time.Second + defaultTimeBetweenTries = 5 * time.Minute +) + +// Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates. +// +// To terminate this routine, close the context in keyFlow.config.BackgroundTokenRefreshContext. +func continuousRefreshToken(keyflow *KeyFlow) { + refresher := &continuousTokenRefresher{ + keyFlow: keyflow, + timeStartBeforeTokenExpiration: defaultTimeStartBeforeTokenExpiration, + timeBetweenContextCheck: defaultTimeBetweenContextCheck, + timeBetweenTries: defaultTimeBetweenTries, + } + err := refresher.continuousRefreshToken() + fmt.Fprintf(os.Stderr, "Token refreshing terminated: %v", err) +} + +type continuousTokenRefresher struct { + keyFlow *KeyFlow + // Token refresh tries start at [Access token expiration timestamp] - [This duration] + timeStartBeforeTokenExpiration time.Duration + timeBetweenContextCheck time.Duration + timeBetweenTries time.Duration +} + +// Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Always returns with a non-nil error. +// +// To terminate this routine, close the context in refresher.keyFlow.config.BackgroundTokenRefreshContext. +func (refresher *continuousTokenRefresher) continuousRefreshToken() error { + expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() + if err != nil { + return fmt.Errorf("get access token expiration timestamp: %w", err) + } + startRefreshTimestamp := expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) + + for { + err = refresher.waitUntilTimestamp(startRefreshTimestamp) + if err != nil { + return err + } + + err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + if err != nil { + return fmt.Errorf("check context: %w", err) + } + + ok, err := refresher.refreshToken() + if err != nil { + return fmt.Errorf("refresh tokens: %w", err) + } + if !ok { + startRefreshTimestamp = startRefreshTimestamp.Add(refresher.timeBetweenTries) + continue + } + + expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() + if err != nil { + return fmt.Errorf("get access token expiration timestamp: %w", err) + } + startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) + } +} + +func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) { + token := refresher.keyFlow.token.AccessToken + + // We can safely use ParseUnverified because we are not doing authentication of any kind + // We're just checking the expiration time + tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + if err != nil { + return nil, fmt.Errorf("parse token: %w", err) + } + expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() + if err != nil { + return nil, fmt.Errorf("get expiration timestamp: %w", err) + } + return &expirationTimestampNumeric.Time, nil +} + +func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Time) error { + for time.Now().Before(timestamp) { + err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + if err != nil { + return fmt.Errorf("check context: %w", err) + } + time.Sleep(refresher.timeBetweenContextCheck) + } + return nil +} + +// Returns: +// - (true, nil) if successful. +// - (false, nil) if not successful but should be retried. +// - (_, err) if not successful and shouldn't be retried. +func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { + err := refresher.keyFlow.createAccessTokenWithRefreshToken() + if err == nil { + return true, nil + } + + // Should be retried if this is an API error with status code 5xx + oapiErr := &oapierror.GenericOpenAPIError{} + if !errors.As(err, &oapiErr) { + return false, err + } + if oapiErr.StatusCode < 500 { + return false, err + } + return false, nil +} diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go new file mode 100644 index 000000000..d439f91cb --- /dev/null +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -0,0 +1,356 @@ +package clients + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" +) + +func TestContinuousRefreshToken(t *testing.T) { + // The times here are in the order of miliseconds (so they run faster) + // For this to work, we need to increase precision of the expiration timestamps + jwt.TimePrecision = time.Millisecond + + // Refresher settings + timeStartBeforeTokenExpiration := 100 * time.Millisecond + timeBetweenContextCheck := 5 * time.Millisecond + timeBetweenTries := 40 * time.Millisecond + + // All generated acess tokens will have this time to live + accessTokensTimeToLive := 200 * time.Millisecond + + tests := []struct { + desc string + contextClosesIn time.Duration + doError error + expectedNumberDoCalls int + }{ + { + desc: "update access token once", + contextClosesIn: 150 * time.Millisecond, + expectedNumberDoCalls: 1, + }, + { + desc: "update access token twice", + contextClosesIn: 250 * time.Millisecond, + expectedNumberDoCalls: 2, + }, + { + desc: "context canceled at start", + contextClosesIn: 0, + expectedNumberDoCalls: 0, + }, + { + desc: "context canceled before start", + contextClosesIn: -150 * time.Millisecond, + expectedNumberDoCalls: 0, + }, + { + desc: "context canceled before token refresh", + contextClosesIn: 50 * time.Millisecond, + expectedNumberDoCalls: 0, + }, + { + desc: "refresh token fails - non-API error", + contextClosesIn: 250 * time.Millisecond, + doError: fmt.Errorf("something went wrong"), + expectedNumberDoCalls: 1, + }, + { + desc: "refresh token fails - API non-5xx error", + contextClosesIn: 250 * time.Millisecond, + doError: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusBadRequest, + }, + expectedNumberDoCalls: 1, + }, + { + desc: "refresh token fails - API 5xx error", + contextClosesIn: 200 * time.Millisecond, + doError: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusInternalServerError, + }, + expectedNumberDoCalls: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + numberDoCalls := 0 + mockDo := func(client *http.Client, req *http.Request, cfg *RetryConfig) (resp *http.Response, err error) { + numberDoCalls++ + + if tt.doError != nil { + return nil, tt.doError + } + + accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Do call: failed to create access token: %v", err) + } + + responseBodyStruct := TokenResponseBody{ + AccessToken: accessToken, + } + responseBody, err := json.Marshal(responseBodyStruct) + if err != nil { + t.Fatalf("Do call: failed to marshal response: %v", err) + } + response := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(responseBody)), + } + return response, nil + } + + accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("failed to create access token: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) + defer cancel() + + keyFlow := &KeyFlow{ + config: &KeyFlowConfig{ + ClientRetry: NewRetryConfig(), + BackgroundTokenRefreshContext: ctx, + }, + doer: mockDo, + token: &TokenResponseBody{ + AccessToken: accessToken, + }, + } + + refresher := &continuousTokenRefresher{ + keyFlow: keyFlow, + timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, + timeBetweenContextCheck: timeBetweenContextCheck, + timeBetweenTries: timeBetweenTries, + } + + err = refresher.continuousRefreshToken() + if err == nil { + t.Fatalf("routine finished with non-nil error") + } + if numberDoCalls != tt.expectedNumberDoCalls { + t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) + } + }) + } +} + +// Tests if +// - continuousRefreshToken() changes the token +// - The access token can be accessed while continuousRefreshToken() is trying to update it +func TestContinuousRefreshTokenConcurrency(t *testing.T) { + // The times here are in the order of miliseconds (so they run faster) + // For this to work, we need to increase precision of the expiration timestamps + jwt.TimePrecision = time.Millisecond + + // Test plan: + // 1) continuousRefreshToken() will trigger a token update. It will be blocked in the mockDo() routine (defined below) + // 2) After continuousRefreshToken() is blocked, a request will be made using the key flow. That request should carry the access token (shouldn't be blocked just because continuousRefreshToken() is trying to refresh the token) + // 3) After the request is successful, continuousRefreshToken() will be unblocked + // 4) After waiting a bit, a new request will be made using the key flow. That request should carry the new access token + + // Where we're at in the test plan: + // - Starts at 0 + // - Is set to 1 before continuousRefreshToken() is called + // - Is set to 2 once the continuousRefreshToken() is blocked + // - Is set to 3 once the first request goes through and is checked + // - Is set to 4 after a small wait after continuousRefreshToken() is unblocked + currentTestPhase := 0 + + // Used to signal continuousRefreshToken() has been blocked + chanBlockContinuousRefreshToken := make(chan bool) + + // Used to signal continuousRefreshToken() should be unblocked + chanUnblockContinuousRefreshToken := make(chan bool) + + // The access token at the start + accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(100 * time.Millisecond)), + }).SignedString([]byte("token-first")) + if err != nil { + t.Fatalf("failed to create first access token: %v", err) + } + + // The access token that will replace accessTokenFirst + // Has a much longer expiration timestamp + accessTokenSecond, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }).SignedString([]byte("token-second")) + if err != nil { + t.Fatalf("failed to create second access token: %v", err) + } + + if accessTokenFirst == accessTokenSecond { + t.Fatalf("created tokens are equal") + } + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() // This cancels the refresher goroutine + + // The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests + // The bools are used to make sure only one request goes through on each test phase + doTestPhase1RequestDone := false + doTestPhase2RequestDone := false + doTestPhase4RequestDone := false + mockDo := func(client *http.Client, req *http.Request, cfg *RetryConfig) (resp *http.Response, err error) { + switch currentTestPhase { + default: + t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) + return nil, nil + case 1: // Call by continuousRefreshToken() + if doTestPhase1RequestDone { + t.Fatalf("Do call: multiple requests during test phase 1") + } + doTestPhase1RequestDone = true + + currentTestPhase = 2 + chanBlockContinuousRefreshToken <- true + + // Wait until continuousRefreshToken() is to be unblocked + <-chanUnblockContinuousRefreshToken + + if currentTestPhase != 3 { + t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase) + } + + // Return response with accessTokenSecond + responseBodyStruct := TokenResponseBody{ + AccessToken: accessTokenSecond, + } + responseBody, err := json.Marshal(responseBodyStruct) + if err != nil { + t.Fatalf("Do call: failed to marshal access token response: %v", err) + } + response := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(responseBody)), + } + return response, nil + case 2: // Call by tokenFlow, first request + if doTestPhase2RequestDone { + t.Fatalf("Do call: multiple requests during test phase 2") + } + doTestPhase2RequestDone = true + + // Check host and access token + host := req.Host + expectedHost := "first-request-url.com" + if host != expectedHost { + t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host) + } + authHeader := req.Header.Get("Authorization") + if authHeader != fmt.Sprintf("Bearer %s", accessTokenFirst) { + t.Fatalf("Do call: first request didn't carry first access token") + } + + // Return empty response + response := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte{})), + } + return response, nil + case 4: // Call by tokenFlow, second request + if doTestPhase4RequestDone { + t.Fatalf("Do call: multiple requests during test phase 4") + } + doTestPhase4RequestDone = true + + // Check host and access token + host := req.Host + expectedHost := "second-request-url.com" + if host != expectedHost { + t.Fatalf("Do call: second request expected to have host %q, found %q", expectedHost, host) + } + authHeader := req.Header.Get("Authorization") + if authHeader != fmt.Sprintf("Bearer %s", accessTokenSecond) { + t.Fatalf("Do call: second request didn't carry second access token") + } + + // Return empty response + response := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte{})), + } + return response, nil + } + } + + keyFlow := &KeyFlow{ + client: &http.Client{}, + config: &KeyFlowConfig{ + ClientRetry: NewRetryConfig(), + BackgroundTokenRefreshContext: ctx, + }, + doer: mockDo, + token: &TokenResponseBody{ + AccessToken: accessTokenFirst, + }, + } + + // TEST START + currentTestPhase = 1 + go continuousRefreshToken(keyFlow) + + // Wait until continuousRefreshToken() is blocked + <-chanBlockContinuousRefreshToken + + if currentTestPhase != 2 { + t.Fatalf("Unexpected test phase %d after continuousRefreshToken() was blocked", currentTestPhase) + } + + // Perform first request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://first-request-url.com", http.NoBody) + if err != nil { + t.Fatalf("Create first request failed: %v", err) + } + resp, err := keyFlow.RoundTrip(req) + if err != nil { + t.Fatalf("Perform first request failed: %v", err) + } + err = resp.Body.Close() + if err != nil { + t.Fatalf("First request body failed to close: %v", err) + } + + // Unblock continuousRefreshToken() + currentTestPhase = 3 + chanUnblockContinuousRefreshToken <- true + + // Wait for a bit + time.Sleep(10 * time.Millisecond) + currentTestPhase = 4 + + // Perform second request + req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://second-request-url.com", http.NoBody) + if err != nil { + t.Fatalf("Create second request failed: %v", err) + } + resp, err = keyFlow.RoundTrip(req) + if err != nil { + t.Fatalf("Second request failed: %v", err) + } + err = resp.Body.Close() + if err != nil { + t.Fatalf("Second request body failed to close: %v", err) + } +} diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 30c744bac..afca6d106 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -182,12 +182,7 @@ func TestSetToken(t *testing.T) { } } -func TestKeyFlowValidateToken(t *testing.T) { - // Generate a random private key - privateKey := make([]byte, 32) - if _, err := rand.Read(privateKey); err != nil { - t.Fatal(err) - } +func TestTokenExpired(t *testing.T) { tests := []struct { desc string tokenInvalid bool diff --git a/core/config/config.go b/core/config/config.go index c490ddcae..aa4e33dbb 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -85,6 +85,12 @@ type Configuration struct { HTTPClient *http.Client RetryOptions *clients.RetryConfig + // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. + // The goroutine is killed whenever this context is canceled. + // + // Only has effect for key flow + BackgroundTokenRefreshContext context.Context + // Deprecated: validation using JWKS was removed, for being redundant with token validation done in the APIs. This field has no effect, and will be removed in a later update JWKSCustomUrl string `json:"jwksCustomUrl,omitempty"` @@ -283,6 +289,22 @@ func WithJar(jar http.CookieJar) ConfigurationOption { } } +// WithBackgroundTokenRefresh returns a ConfigurationOption that enables access token refreshing in backgound. +// +// If enabled, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. +// The goroutine is killed whenever the given context is canceled. +// +// Only has effect for key flow +func WithBackgroundTokenRefresh(ctx context.Context) ConfigurationOption { + return func(c *Configuration) error { + if ctx == nil { + return fmt.Errorf("context for token refresh in background cannot be empty") + } + c.BackgroundTokenRefreshContext = ctx + return nil + } +} + // WithCustomConfiguration returns a ConfigurationOption that sets a custom Configuration func WithCustomConfiguration(cfg *Configuration) ConfigurationOption { return func(config *Configuration) error { @@ -305,6 +327,7 @@ func WithCustomConfiguration(cfg *Configuration) ConfigurationOption { config.setCustomEndpoint = (len(cfg.Servers) > 0) config.OperationServers = cfg.OperationServers config.HTTPClient = cfg.HTTPClient + config.BackgroundTokenRefreshContext = cfg.BackgroundTokenRefreshContext return nil } }