diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f7fe7f60..9c46081bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,8 @@ ## Release (2025-XX-YY) +- `core`: + - [v0.21.0](core/CHANGELOG.md#v0210) + - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` + - **Feature:** Support Workload Identity Federation flow - `scf`: [v0.3.0](services/scf/CHANGELOG.md#v030) - **Feature:** Add new model `IsolationSegment` and `IsolationSegmentsList` - `iaas`: diff --git a/README.md b/README.md index 1260da59c..61d6894b4 100644 --- a/README.md +++ b/README.md @@ -105,13 +105,20 @@ To authenticate with the SDK, you need a [service account](https://docs.stackit. The SDK supports two authentication methods: -1. **Key Flow** (Recommended) +1. **Workload Identity Federation Flow** (Recommended) + + - Uses OIDC trusted tokens + - Provides best security through short-lived tokens without secrets + +> NOTE: This flow isn't publicly available yet. It'll be public during Q1 2026 + +2. **Key Flow** (Recommended) - Uses RSA key-pair based authentication - Provides better security through short-lived tokens - Supports both STACKIT-generated and custom key pairs -2. **Token Flow** +3. **Token Flow** - Uses long-lived service account tokens - Simpler but less secure @@ -120,10 +127,42 @@ The SDK supports two authentication methods: The SDK searches for credentials in the following order: 1. Explicit configuration in code -2. Environment variables (KEY_PATH for KEY) +2. Environment variables 3. Credentials file (`$HOME/.stackit/credentials.json`) -For each authentication method, the key flow is attempted first, followed by the token flow. +For each authentication method, the try order is: +1. Workload Identity Federation Flow +2. Key Flow +3. Token Flow + +### Using the Workload Identity Fedearion Flow + +1. Create a service account trusted relation in the STACKIT Portal: + + - Navigate to `Service Accounts` → Select account → `Federated Identity Providers` → Add a Federated Identity Provider + - Configure the trusted issuer and the required assertions to trust in. (Link to official docs here after GA) + +2. Configure authentication using any of these methods: + + **A. Code Configuration** + + ```go + // Using wokload identity federation flow + config.WithWorkloadIdentityFederationAuth() + // With the custom path for the external OIDC token + config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token") + // For the service account + config.WithServiceAccountEmail("my-sa@sa-stackit.cloud") + ``` + + **B. Environment Variables** + + ```bash + # With the custom path for the external OIDC token + STACKIT_FEDERATED_TOKEN_FILE=/path/to/your/federated/token + # For the service account + STACKIT_SERVICE_ACCOUNT_EMAIL=my-sa@sa-stackit.cloud + ``` ### Using the Key Flow diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 8b1d2fb86..aaaa0636c 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,7 @@ +## v0.21.0 +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` +- **Feature:** Support Workload Identity Federation flow + ## v0.20.1 - **Improvement:** Improve error message when passing a PEM encoded file to as service account key @@ -9,6 +13,7 @@ ## v0.18.0 - **New:** Added duration utils +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/core/VERSION b/core/VERSION index 2c80271d5..759e855fb 100644 --- a/core/VERSION +++ b/core/VERSION @@ -1 +1 @@ -v0.20.1 +v0.21.0 diff --git a/core/auth/auth.go b/core/auth/auth.go index 568847aea..e3b10bc46 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -51,6 +51,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { return nil, fmt.Errorf("configuring no auth client: %w", err) } return noAuthRoundTripper, nil + } else if cfg.WorkloadIdentityFederation { + wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) + if err != nil { + return nil, fmt.Errorf("configuring no auth client: %w", err) + } + return wifRoundTripper, nil } else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" { keyRoundTripper, err := KeyAuth(cfg) if err != nil { @@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { cfg = &config.Configuration{} } - // Key flow - rt, err = KeyAuth(cfg) + // WIF flow + rt, err = WorkloadIdentityFederationAuth(cfg) if err != nil { - keyFlowErr := err - // Token flow - rt, err = TokenAuth(cfg) + // Key flow + rt, err = KeyAuth(cfg) if err != nil { - return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + keyFlowErr := err + // Token flow + rt, err = TokenAuth(cfg) + if err != nil { + return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + } } } return rt, nil @@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) { return client, nil } +// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper +// that can be used to make authenticated requests using an access token +func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) { + wifConfig := clients.WorkloadIdentityFederationFlowConfig{ + TokenUrl: cfg.TokenCustomUrl, + BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, + ClientID: cfg.ServiceAccountEmail, + FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath, + TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration, + } + + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + wifConfig.HTTPTransport = cfg.HTTPClient.Transport + } + + client := &clients.WorkloadIdentityFederationFlow{} + if err := client.Init(&wifConfig); err != nil { + return nil, fmt.Errorf("error initializing client: %w", err) + } + + return client, nil +} + // readCredentialsFile reads the credentials file from the specified path and returns Credentials func readCredentialsFile(path string) (*Credentials, error) { if path == "" { diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index a7c776946..b861bf581 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stackitcloud/stackit-sdk-go/core/clients" "github.com/stackitcloud/stackit-sdk-go/core/config" @@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) { } }() + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -147,12 +174,19 @@ func TestSetupAuth(t *testing.T) { desc string config *config.Configuration setToken bool + setWorkloadIdentity bool setKeys bool setKeyPaths bool setCredentialsFilePathToken bool setCredentialsFilePathKey bool isValid bool }{ + { + desc: "wif_config", + config: nil, + setWorkloadIdentity: true, + isValid: true, + }, { desc: "token_config", config: nil, @@ -241,6 +275,12 @@ func TestSetupAuth(t *testing.T) { t.Setenv("STACKIT_CREDENTIALS_PATH", "") } + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") authRoundTripper, err := SetupAuth(test.config) @@ -253,7 +293,7 @@ func TestSetupAuth(t *testing.T) { t.Fatalf("Test didn't return error on invalid test case") } - if test.isValid && authRoundTripper == nil { + if authRoundTripper == nil && test.isValid { t.Fatalf("Roundtripper returned is nil for valid test case") } }) @@ -381,6 +421,32 @@ func TestDefaultAuth(t *testing.T) { t.Fatalf("Writing private key to temporary file: %s", err) } + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -409,6 +475,7 @@ func TestDefaultAuth(t *testing.T) { setKeyPaths bool setKeys bool setCredentialsFilePathKey bool + setWorkloadIdentity bool isValid bool expectedFlow string }{ @@ -418,6 +485,14 @@ func TestDefaultAuth(t *testing.T) { isValid: true, expectedFlow: "token", }, + { + desc: "wif_precedes_key_precedes_token", + setToken: true, + setKeyPaths: true, + setWorkloadIdentity: true, + isValid: true, + expectedFlow: "wif", + }, { desc: "key_precedes_token", setToken: true, @@ -475,6 +550,13 @@ func TestDefaultAuth(t *testing.T) { } else { t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "") } + + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") // Get the default authentication client and ensure that it's not nil @@ -501,6 +583,10 @@ func TestDefaultAuth(t *testing.T) { if _, ok := authClient.(*clients.KeyFlow); !ok { t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) } + case "wif": + if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok { + t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) + } } } }) diff --git a/core/clients/auth_flow.go b/core/clients/auth_flow.go new file mode 100644 index 000000000..141d75489 --- /dev/null +++ b/core/clients/auth_flow.go @@ -0,0 +1,84 @@ +package clients + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" +) + +const ( + defaultTokenExpirationLeeway = time.Second * 5 +) + +type AuthFlow interface { + RoundTrip(req *http.Request) (*http.Response, error) + GetAccessToken() (string, error) + GetBackgroundTokenRefreshContext() context.Context +} + +// TokenResponseBody is the API response +// when requesting a new token +type TokenResponseBody struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) { + if res == nil { + return nil, fmt.Errorf("received bad response from API") + } + if res.StatusCode != http.StatusOK { + body, err := io.ReadAll(res.Body) + if err != nil { + // Fail silently, omit body from error + // We're trying to show error details, so it's unnecessary to fail because of this err + body = []byte{} + } + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: res.StatusCode, + Body: body, + } + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + token := &TokenResponseBody{} + err = json.Unmarshal(body, token) + if err != nil { + return nil, fmt.Errorf("unmarshal token response: %w", err) + } + return token, nil +} + +func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { + if token == "" { + return true, nil + } + + // We can safely use ParseUnverified because we are not authenticating the user at this point. + // We're just checking the expiration time + tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + if err != nil { + return false, fmt.Errorf("parse token: %w", err) + } + + expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() + if err != nil { + return false, fmt.Errorf("get expiration timestamp: %w", err) + } + + // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring + // between retrieving the token and upstream systems validating it. + now := time.Now().Add(tokenExpirationLeeway) + return now.After(expirationTimestampNumeric.Time), nil +} diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 589774314..d18d4f0bf 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -4,11 +4,9 @@ import ( "context" "crypto/rsa" "crypto/x509" - "encoding/json" "encoding/pem" "errors" "fmt" - "io" "net/http" "net/url" "regexp" @@ -30,12 +28,10 @@ const ( ServiceAccountKeyPath = "STACKIT_SERVICE_ACCOUNT_KEY_PATH" PrivateKeyPath = "STACKIT_PRIVATE_KEY_PATH" tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive - defaultTokenType = "Bearer" - defaultScope = "" - - defaultTokenExpirationLeeway = time.Second * 5 ) +var _ AuthFlow = &KeyFlow{} + // KeyFlow handles auth with SA key type KeyFlow struct { rt http.RoundTripper @@ -65,16 +61,6 @@ type KeyFlowConfig struct { AuthHTTPClient *http.Client } -// TokenResponseBody is the API response -// when requesting a new token -type TokenResponseBody struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` -} - // ServiceAccountKeyResponse is the API response // when creating a new SA key type ServiceAccountKeyResponse struct { @@ -114,6 +100,7 @@ func (c *KeyFlow) GetServiceAccountEmail() string { } // GetToken returns the token field +// Deprecated: Use GetAccessToken instead func (c *KeyFlow) GetToken() TokenResponseBody { c.tokenMutex.RLock() defer c.tokenMutex.RUnlock() @@ -160,6 +147,7 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // SetToken can be used to set an access and refresh token manually in the client. // The other fields in the token field are determined by inspecting the token or setting default values. +// Deprecated func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { // We can safely use ParseUnverified because we are not authenticating the user, // We are parsing the token just to get the expiration time claim @@ -174,11 +162,10 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { c.tokenMutex.Lock() c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: defaultScope, - RefreshToken: refreshToken, - TokenType: defaultTokenType, + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: "", + TokenType: "Bearer", } c.tokenMutex.Unlock() return nil @@ -198,12 +185,11 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { return c.rt.RoundTrip(req) } -// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field +// GetAccessToken returns a short-lived access token and saves the access token in the token field func (c *KeyFlow) GetAccessToken() (string, error) { if c.rt == nil { return "", fmt.Errorf("nil http round tripper, please run Init()") } - var accessToken string c.tokenMutex.RLock() @@ -219,7 +205,7 @@ func (c *KeyFlow) GetAccessToken() (string, error) { if !accessTokenExpired { return accessToken, nil } - if err = c.recreateAccessToken(); err != nil { + if err = c.createAccessToken(); err != nil { var oapiErr *oapierror.GenericOpenAPIError if ok := errors.As(err, &oapiErr); ok { reg := regexp.MustCompile("Key with kid .*? was not found") @@ -237,6 +223,10 @@ func (c *KeyFlow) GetAccessToken() (string, error) { return accessToken, nil } +func (c *KeyFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + // validate the client is configured well func (c *KeyFlow) validate() error { if c.config.ServiceAccountKey == nil { @@ -269,27 +259,6 @@ func (c *KeyFlow) validate() error { // Flow auth functions -// recreateAccessToken is used to create a new access token -// when the existing one isn't valid anymore -func (c *KeyFlow) recreateAccessToken() error { - var refreshToken string - - c.tokenMutex.RLock() - if c.token != nil { - refreshToken = c.token.RefreshToken - } - c.tokenMutex.RUnlock() - - refreshTokenExpired, err := tokenExpired(refreshToken, c.tokenExpirationLeeway) - if err != nil { - return err - } - if !refreshTokenExpired { - return c.createAccessTokenWithRefreshToken() - } - return c.createAccessToken() -} - // createAccessToken creates an access token using self signed JWT func (c *KeyFlow) createAccessToken() (err error) { grant := "urn:ietf:params:oauth:grant-type:jwt-bearer" @@ -307,27 +276,14 @@ func (c *KeyFlow) createAccessToken() (err error) { err = fmt.Errorf("close request access token response: %w", tempErr) } }() - return c.parseTokenResponse(res) -} - -// createAccessTokenWithRefreshToken creates an access token using -// an existing pre-validated refresh token -func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { - c.tokenMutex.RLock() - refreshToken := c.token.RefreshToken - c.tokenMutex.RUnlock() - - res, err := c.requestToken("refresh_token", refreshToken) + token, err := parseTokenResponse(res) if err != nil { return err } - defer func() { - tempErr := res.Body.Close() - if tempErr != nil && err == nil { - err = fmt.Errorf("close request access token with refresh token response: %w", tempErr) - } - }() - return c.parseTokenResponse(res) + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil } // generateSelfSignedJWT generates JWT token @@ -338,7 +294,7 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { "jti": uuid.New(), "aud": c.key.Credentials.Aud, "iat": jwt.NewNumericDate(time.Now()), - "exp": jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), } token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims) token.Header["kid"] = c.key.Credentials.Kid @@ -353,11 +309,8 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) { body := url.Values{} body.Set("grant_type", grant) - if grant == "refresh_token" { - body.Set("refresh_token", assertion) - } else { - body.Set("assertion", assertion) - } + body.Set("assertion", assertion) + payload := strings.NewReader(body.Encode()) req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) if err != nil { @@ -367,60 +320,3 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) return c.authClient.Do(req) } - -// parseTokenResponse parses the response from the server -func (c *KeyFlow) parseTokenResponse(res *http.Response) error { - if res == nil { - return fmt.Errorf("received bad response from API") - } - if res.StatusCode != http.StatusOK { - body, err := io.ReadAll(res.Body) - if err != nil { - // Fail silently, omit body from error - // We're trying to show error details, so it's unnecessary to fail because of this err - body = []byte{} - } - return &oapierror.GenericOpenAPIError{ - StatusCode: res.StatusCode, - Body: body, - } - } - body, err := io.ReadAll(res.Body) - if err != nil { - return err - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{} - err = json.Unmarshal(body, c.token) - c.tokenMutex.Unlock() - if err != nil { - return fmt.Errorf("unmarshal token response: %w", err) - } - - return nil -} - -func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { - if token == "" { - return true, nil - } - - // We can safely use ParseUnverified because we are not authenticating the user at this point. - // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) - if err != nil { - return false, fmt.Errorf("parse token: %w", err) - } - - expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() - if err != nil { - return false, fmt.Errorf("get expiration timestamp: %w", err) - } - - // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring - // between retrieving the token and upstream systems validating it. - now := time.Now().Add(tokenExpirationLeeway) - - return now.After(expirationTimestampNumeric.Time), nil -} diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go index f5129aa02..702b3695c 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/key_flow_continuous_refresh.go @@ -20,9 +20,9 @@ var ( // Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates. // // To terminate this routine, close the context in keyFlow.config.BackgroundTokenRefreshContext. -func continuousRefreshToken(keyflow *KeyFlow) { +func continuousRefreshToken(flow AuthFlow) { refresher := &continuousTokenRefresher{ - keyFlow: keyflow, + flow: flow, timeStartBeforeTokenExpiration: defaultTimeStartBeforeTokenExpiration, timeBetweenContextCheck: defaultTimeBetweenContextCheck, timeBetweenTries: defaultTimeBetweenTries, @@ -32,7 +32,7 @@ func continuousRefreshToken(keyflow *KeyFlow) { } type continuousTokenRefresher struct { - keyFlow *KeyFlow + flow AuthFlow // Token refresh tries start at [Access token expiration timestamp] - [This duration] timeStartBeforeTokenExpiration time.Duration timeBetweenContextCheck time.Duration @@ -46,22 +46,12 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { // Compute timestamp where we'll refresh token // Access token may be empty at this point, we have to check it var startRefreshTimestamp time.Time - var accessToken string - refresher.keyFlow.tokenMutex.RLock() - if refresher.keyFlow.token != nil { - accessToken = refresher.keyFlow.token.AccessToken - } - refresher.keyFlow.tokenMutex.RUnlock() - if accessToken == "" { - startRefreshTimestamp = time.Now() - } else { - expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() - if err != nil { - return fmt.Errorf("get access token expiration timestamp: %w", err) - } - startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) + expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() + if err != nil { + return fmt.Errorf("get access token expiration timestamp: %w", err) } + startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) for { err := refresher.waitUntilTimestamp(startRefreshTimestamp) @@ -69,7 +59,7 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { return err } - err = refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err = refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -92,13 +82,14 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { } func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) { - refresher.keyFlow.tokenMutex.RLock() - token := refresher.keyFlow.token.AccessToken - refresher.keyFlow.tokenMutex.RUnlock() + accessToken, err := refresher.flow.GetAccessToken() + if err != nil { + return nil, err + } // We can safely use ParseUnverified because we are not doing authentication of any kind // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + tokenParsed, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } @@ -111,7 +102,7 @@ func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() ( func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Time) error { for time.Now().Before(timestamp) { - err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err := refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -125,7 +116,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim // - (false, nil) if not successful but should be retried. // - (_, err) if not successful and shouldn't be retried. func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { - err := refresher.keyFlow.recreateAccessToken() + _, err := refresher.flow.GetAccessToken() if err == nil { return true, nil } diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 7c7ee9565..cfd50e763 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -1,18 +1,13 @@ package clients import ( - "bytes" "context" - "encoding/json" "fmt" - "io" "net/http" - "net/url" "testing" "time" "github.com/golang-jwt/jwt/v5" - "github.com/stackitcloud/stackit-sdk-go/core/oapierror" ) @@ -22,9 +17,9 @@ func TestContinuousRefreshToken(t *testing.T) { jwt.TimePrecision = time.Millisecond // Refresher settings - timeStartBeforeTokenExpiration := 500 * time.Millisecond - timeBetweenContextCheck := 10 * time.Millisecond - timeBetweenTries := 100 * time.Millisecond + timeStartBeforeTokenExpiration := 0 * time.Second + timeBetweenContextCheck := 50 * time.Millisecond + timeBetweenTries := 500 * time.Millisecond // All generated acess tokens will have this time to live accessTokensTimeToLive := 1 * time.Second @@ -34,16 +29,20 @@ func TestContinuousRefreshToken(t *testing.T) { contextClosesIn time.Duration doError error expectedNumberDoCalls int - expectedCallRange []int // Optional: for tests that can have variable call counts }{ + { + desc: "update access token never", + contextClosesIn: 900 * time.Millisecond, // Should allow no refresh + expectedNumberDoCalls: 0, + }, { desc: "update access token once", - contextClosesIn: 700 * time.Millisecond, // Should allow one refresh + contextClosesIn: 1900 * time.Millisecond, // Should allow one refresh expectedNumberDoCalls: 1, }, { desc: "update access token twice", - contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes + contextClosesIn: 2900 * time.Millisecond, // Should allow two refreshes expectedNumberDoCalls: 2, }, { @@ -62,14 +61,14 @@ func TestContinuousRefreshToken(t *testing.T) { expectedNumberDoCalls: 0, }, { - desc: "refresh token fails - non-API error", - contextClosesIn: 700 * time.Millisecond, + desc: "refresh token fails - error", + contextClosesIn: 1900 * time.Millisecond, doError: fmt.Errorf("something went wrong"), expectedNumberDoCalls: 1, }, { desc: "refresh token fails - API non-5xx error", - contextClosesIn: 700 * time.Millisecond, + contextClosesIn: 1900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusBadRequest, }, @@ -77,92 +76,35 @@ func TestContinuousRefreshToken(t *testing.T) { }, { desc: "refresh token fails - API 5xx error", - contextClosesIn: 800 * time.Millisecond, + contextClosesIn: 2900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusInternalServerError, }, - expectedNumberDoCalls: 3, - expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition + expectedNumberDoCalls: 4, }, } for _, tt := range tests { + tt := tt t.Run(tt.desc, func(t *testing.T) { - accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create access token: %v", err) - } - - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) + t.Parallel() + accessToken, err := signToken(accessTokensTimeToLive) if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - - numberDoCalls := 0 - mockDo := func(_ *http.Request) (resp *http.Response, err error) { - numberDoCalls++ // count refresh attempts - if tt.doError != nil { - return nil, tt.doError - } - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("Do call: failed to create access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil + t.Fatalf("failed to sign access token: %v", err) } - ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) defer cancel() - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, - BackgroundTokenRefreshContext: nil, + authFlow := &fakeAuthFlow{ + backgroundTokenRefreshContext: ctx, + doError: tt.doError, + accessTokensTimeToLive: accessTokensTimeToLive, + accessToken: accessToken, } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) - } - - // Set the token after initialization - err = keyFlow.SetToken(accessToken, refreshToken) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, + flow: authFlow, timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, timeBetweenContextCheck: timeBetweenContextCheck, timeBetweenTries: timeBetweenTries, @@ -172,315 +114,56 @@ func TestContinuousRefreshToken(t *testing.T) { if err == nil { t.Fatalf("routine finished with non-nil error") } - - // Check if we have a range of expected calls (for timing-sensitive tests) - if tt.expectedCallRange != nil { - if !contains(tt.expectedCallRange, numberDoCalls) { - t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls) - } - } else if numberDoCalls != tt.expectedNumberDoCalls { + numberDoCalls := authFlow.getTokenCalls() + if numberDoCalls != tt.expectedNumberDoCalls { t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) } }) } } -// Tests if -// - continuousRefreshToken() updates access token using the refresh token -// - The access token can be accessed while continuousRefreshToken() is trying to update it -func TestContinuousRefreshTokenConcurrency(t *testing.T) { - // The times here are in the order of miliseconds (so they run faster) - // For this to work, we need to increase precision of the expiration timestamps - jwt.TimePrecision = time.Millisecond - - // Test plan: - // 1) continuousRefreshToken() will trigger a token update. It will be blocked in the mockDo() routine (defined below) - // 2) After continuousRefreshToken() is blocked, a request will be made using the key flow. That request should carry the access token (shouldn't be blocked just because continuousRefreshToken() is trying to refresh the token) - // 3) After the request is successful, continuousRefreshToken() will be unblocked - // 4) After waiting a bit, a new request will be made using the key flow. That request should carry the new access token - - // Where we're at in the test plan: - // - Starts at 0 - // - Is set to 1 before continuousRefreshToken() is called - // - Is set to 2 once the continuousRefreshToken() is blocked - // - Is set to 3 once the first request goes through and is checked - // - Is set to 4 after a small wait after continuousRefreshToken() is unblocked - currentTestPhase := 0 - - // Used to signal continuousRefreshToken() has been blocked - chanBlockContinuousRefreshToken := make(chan bool) - - // Used to signal continuousRefreshToken() should be unblocked - chanUnblockContinuousRefreshToken := make(chan bool) - - // The access token at the start - accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), - }).SignedString([]byte("token-first")) - if err != nil { - t.Fatalf("failed to create first access token: %v", err) - } - - // The access token that will replace accessTokenFirst - // Has a much longer expiration timestamp - accessTokenSecond, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("token-second")) - if err != nil { - t.Fatalf("failed to create second access token: %v", err) - } - - if accessTokenFirst == accessTokenSecond { - t.Fatalf("created tokens are equal") - } - - // The refresh token used to update the access token - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), +func signToken(expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() // This cancels the refresher goroutine - - // Extract host from tokenAPI constant for consistency - tokenURL, _ := url.Parse(tokenAPI) - tokenHost := tokenURL.Host - - // The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests - // The bools are used to make sure only one request goes through on each test phase - doTestPhase1RequestDone := false - doTestPhase2RequestDone := false - doTestPhase4RequestDone := false - mockDo := func(req *http.Request) (resp *http.Response, err error) { - // Handle auth requests (token refresh) - if req.URL.Host == tokenHost { - switch currentTestPhase { - default: - // After phase 1, allow additional auth requests but don't fail the test - // This handles the continuous nature of the refresh routine - if currentTestPhase > 1 { - // Return a valid response for any additional auth requests - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("additional-token")) - if err != nil { - t.Fatalf("Do call: failed to create additional access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal additional response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 1: // Call by continuousRefreshToken() - if doTestPhase1RequestDone { - t.Fatalf("Do call: multiple requests during test phase 1") - } - doTestPhase1RequestDone = true - - currentTestPhase = 2 - chanBlockContinuousRefreshToken <- true - - // Wait until continuousRefreshToken() is to be unblocked - <-chanUnblockContinuousRefreshToken - - if currentTestPhase != 3 { - t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase) - } - - // Check required fields are passed - err = req.ParseForm() - if err != nil { - t.Fatalf("Do call: failed to parse body form: %v", err) - } - reqGrantType := req.Form.Get("grant_type") - if reqGrantType != "refresh_token" { - t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType) - } - reqRefreshToken := req.Form.Get("refresh_token") - if reqRefreshToken != refreshToken { - t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set") - } - - // Return response with accessTokenSecond - responseBodyStruct := TokenResponseBody{ - AccessToken: accessTokenSecond, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - } - - // Handle regular HTTP requests - switch currentTestPhase { - default: - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 2: // Call by tokenFlow, first request - if doTestPhase2RequestDone { - t.Fatalf("Do call: multiple requests during test phase 2") - } - doTestPhase2RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "first-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - expectedAuthHeader := fmt.Sprintf("Bearer %s", accessTokenFirst) - if authHeader != expectedAuthHeader { - t.Fatalf("Do call: first request didn't carry first access token. Expected: %s, Got: %s", expectedAuthHeader, authHeader) - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - case 4: // Call by tokenFlow, second request - if doTestPhase4RequestDone { - t.Fatalf("Do call: multiple requests during test phase 4") - } - doTestPhase4RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "second-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: second request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - if authHeader != fmt.Sprintf("Bearer %s", accessTokenSecond) { - t.Fatalf("Do call: second request didn't carry second access token") - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - } - } - - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, // Use same mock for regular requests - // Don't start continuous refresh automatically - BackgroundTokenRefreshContext: nil, - } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) - } - - // Set the token after initialization - err = keyFlow.SetToken(accessTokenFirst, refreshToken) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx - - // Create a custom refresher with shorter timing for the test - refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, - timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration - timeBetweenContextCheck: 5 * time.Millisecond, - timeBetweenTries: 40 * time.Millisecond, - } - - // TEST START - currentTestPhase = 1 - // Ignore returned error as expected in test - go func() { - _ = refresher.continuousRefreshToken() - }() +} - // Wait until continuousRefreshToken() is blocked - <-chanBlockContinuousRefreshToken +var _ AuthFlow = &fakeAuthFlow{} - if currentTestPhase != 2 { - t.Fatalf("Unexpected test phase %d after continuousRefreshToken() was blocked", currentTestPhase) - } +type fakeAuthFlow struct { + backgroundTokenRefreshContext context.Context + tokenCounter int + doError error + accessTokensTimeToLive time.Duration + accessToken string +} - // Perform first request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://first-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create first request failed: %v", err) - } - resp, err := keyFlow.RoundTrip(req) +func (f *fakeAuthFlow) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, nil +} +func (f *fakeAuthFlow) GetAccessToken() (string, error) { + expired, err := tokenExpired(f.accessToken, 0) if err != nil { - t.Fatalf("Perform first request failed: %v", err) + return "", err } - err = resp.Body.Close() - if err != nil { - t.Fatalf("First request body failed to close: %v", err) + if !expired { + return f.accessToken, nil } - - // Unblock continuousRefreshToken() - currentTestPhase = 3 - chanUnblockContinuousRefreshToken <- true - - // Wait for a bit - time.Sleep(10 * time.Millisecond) - currentTestPhase = 4 - - // Perform second request - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://second-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create second request failed: %v", err) - } - resp, err = keyFlow.RoundTrip(req) - if err != nil { - t.Fatalf("Second request failed: %v", err) + f.tokenCounter++ + if f.doError != nil { + return "", f.doError } - err = resp.Body.Close() + accessToken, err := signToken(f.accessTokensTimeToLive) if err != nil { - t.Fatalf("Second request body failed to close: %v", err) + return "", f.doError } + f.accessToken = accessToken + return accessToken, nil +} +func (f *fakeAuthFlow) GetBackgroundTokenRefreshContext() context.Context { + return f.backgroundTokenRefreshContext } -func contains(arr []int, val int) bool { - for _, v := range arr { - if v == val { - return true - } - } - return false +func (f *fakeAuthFlow) getTokenCalls() int { + return f.tokenCounter } diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 9803f24ee..7c094331e 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -175,11 +175,10 @@ func TestSetToken(t *testing.T) { } if err == nil { expectedKeyFlowToken := &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(timestamp.Unix()), - RefreshToken: tt.refreshToken, - Scope: defaultScope, - TokenType: defaultTokenType, + AccessToken: accessToken, + ExpiresIn: int(timestamp.Unix()), + Scope: "", + TokenType: "Bearer", } if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) @@ -194,25 +193,25 @@ func TestTokenExpired(t *testing.T) { tests := []struct { desc string tokenInvalid bool - tokenExpiresAt time.Time + tokenDuration time.Duration expectedErr bool expectedIsExpired bool }{ { desc: "token valid", - tokenExpiresAt: time.Now().Add(time.Hour), + tokenDuration: time.Hour, expectedErr: false, expectedIsExpired: false, }, { desc: "token expired", - tokenExpiresAt: time.Now().Add(-time.Hour), + tokenDuration: -time.Hour, expectedErr: false, expectedIsExpired: true, }, { desc: "token almost expired", - tokenExpiresAt: time.Now().Add(tokenExpirationLeeway), + tokenDuration: tokenExpirationLeeway, expectedErr: false, expectedIsExpired: true, }, @@ -228,9 +227,7 @@ func TestTokenExpired(t *testing.T) { var err error token := "foo" if !tt.tokenInvalid { - token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(tt.tokenExpiresAt), - }).SignedString([]byte("test")) + token, err = signToken(tt.tokenDuration) if err != nil { t.Fatalf("failed to create token: %v", err) } @@ -442,10 +439,9 @@ func TestKeyFlow_Do(t *testing.T) { res.Header().Set("Content-Type", "application/json") token := &TokenResponseBody{ - AccessToken: testBearerToken, - ExpiresIn: 2147483647, - RefreshToken: testBearerToken, - TokenType: "Bearer", + AccessToken: testBearerToken, + ExpiresIn: 2147483647, + TokenType: "Bearer", } if err := json.NewEncoder(res.Body).Encode(token); err != nil { diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go new file mode 100644 index 000000000..65b6fc461 --- /dev/null +++ b/core/clients/workload_identity_flow.go @@ -0,0 +1,249 @@ +package clients + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + clientIDEnv = "STACKIT_SERVICE_ACCOUNT_EMAIL" + FederatedTokenFileEnv = "STACKIT_FEDERATED_TOKEN_FILE" + wifTokenEndpointEnv = "STACKIT_IDP_ENDPOINT" + wifTokenExpirationEnv = "STACKIT_IDP_EXPIRATION_SECONDS" + + wifClientAssertionType = "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" + wifGrantType = "client_credentials" + defaultWifTokenEndpoint = "https://accounts.stackit.cloud/oauth/v2/token" + defaultFederatedTokenPath = "/var/run/secrets/stackit.cloud/serviceaccount/token" + defaultWifExpirationToken = "1h" +) + +var ( + _ = getEnvOrDefault(wifTokenExpirationEnv, defaultWifExpirationToken) // Not used yet +) + +func getEnvOrDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +var _ AuthFlow = &WorkloadIdentityFederationFlow{} + +// WorkloadIdentityFlow handles auth with Workload Identity Federation +type WorkloadIdentityFederationFlow struct { + rt http.RoundTripper + authClient *http.Client + config *WorkloadIdentityFederationFlowConfig + + tokenMutex sync.RWMutex + token *TokenResponseBody + + parser *jwt.Parser + + // If the current access token would expire in less than TokenExpirationLeeway, + // the client will refresh it early to prevent clock skew or other timing issues. + tokenExpirationLeeway time.Duration +} + +// KeyFlowConfig is the flow config +type WorkloadIdentityFederationFlowConfig struct { + TokenUrl string + ClientID string + FederatedTokenFilePath string + TokenExpiration string // Not supported yet + BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil + HTTPTransport http.RoundTripper + AuthHTTPClient *http.Client +} + +// GetConfig returns the flow configuration +func (c *WorkloadIdentityFederationFlow) GetConfig() WorkloadIdentityFederationFlowConfig { + if c.config == nil { + return WorkloadIdentityFederationFlowConfig{} + } + return *c.config +} + +// GetAccessToken implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetAccessToken() (string, error) { + if c.rt == nil { + return "", fmt.Errorf("nil http round tripper, please run Init()") + } + var accessToken string + + c.tokenMutex.RLock() + if c.token != nil { + accessToken = c.token.AccessToken + } + c.tokenMutex.RUnlock() + + accessTokenExpired, err := tokenExpired(accessToken, c.tokenExpirationLeeway) + if err != nil { + return "", fmt.Errorf("check access token is expired: %w", err) + } + if !accessTokenExpired { + return accessToken, nil + } + if err = c.createAccessToken(); err != nil { + return "", fmt.Errorf("get new access token: %w", err) + } + + c.tokenMutex.RLock() + accessToken = c.token.AccessToken + c.tokenMutex.RUnlock() + + return accessToken, nil +} + +// RoundTrip implements the http.RoundTripper interface. +// It gets a token, adds it to the request's authorization header, and performs the request. +func (c *WorkloadIdentityFederationFlow) RoundTrip(req *http.Request) (*http.Response, error) { + if c.rt == nil { + return nil, fmt.Errorf("please run Init()") + } + + accessToken, err := c.GetAccessToken() + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + return c.rt.RoundTrip(req) +} + +// GetBackgroundTokenRefreshContext implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + +func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlowConfig) error { + // No concurrency at this point, so no mutex check needed + c.token = &TokenResponseBody{} + c.config = cfg + c.parser = jwt.NewParser() + + if c.config.TokenUrl == "" { + c.config.TokenUrl = getEnvOrDefault(wifTokenEndpointEnv, defaultWifTokenEndpoint) + } + + if c.config.ClientID == "" { + c.config.ClientID = getEnvOrDefault(clientIDEnv, "") + } + + if c.config.FederatedTokenFilePath == "" { + c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath) + } + + c.tokenExpirationLeeway = defaultTokenExpirationLeeway + + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport + } + + if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil { + c.authClient = &http.Client{ + Transport: c.rt, + Timeout: DefaultClientTimeout, + } + } + + err := c.validate() + if err != nil { + return err + } + + // // Init the token + // _, err = c.GetAccessToken() + // if err != nil { + // return err + // } + + if c.config.BackgroundTokenRefreshContext != nil { + go continuousRefreshToken(c) + } + return nil +} + +// validate the client is configured well +func (c *WorkloadIdentityFederationFlow) validate() error { + if c.config.ClientID == "" { + return fmt.Errorf("client ID cannot be empty") + } + if c.config.TokenUrl == "" { + return fmt.Errorf("token URL cannot be empty") + } + if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) + } + if c.tokenExpirationLeeway < 0 { + return fmt.Errorf("token expiration leeway cannot be negative") + } + + return nil +} + +// createAccessToken creates an access token using self signed JWT +func (c *WorkloadIdentityFederationFlow) createAccessToken() (err error) { + clientAssertion, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) + if err != nil { + return fmt.Errorf("error reading service account assertion - %w", err) + } + + res, err := c.requestToken(c.config.ClientID, clientAssertion) + if err != nil { + return err + } + defer func() { + tempErr := res.Body.Close() + if tempErr != nil && err == nil { + err = fmt.Errorf("close request access token response: %w", tempErr) + } + }() + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil +} + +func (c *WorkloadIdentityFederationFlow) requestToken(clientID, assertion string) (*http.Response, error) { + body := url.Values{} + body.Set("grant_type", wifGrantType) + body.Set("client_assertion_type", wifClientAssertionType) + body.Set("client_assertion", assertion) + body.Set("client_id", clientID) + + payload := strings.NewReader(body.Encode()) + req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + return c.authClient.Do(req) +} + +func (c *WorkloadIdentityFederationFlow) readJWTFromFileSystem(tokenFilePath string) (string, error) { + token, err := os.ReadFile(tokenFilePath) + if err != nil { + return "", err + } + tokenStr := string(token) + _, _, err = c.parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", err + } + return tokenStr, nil +} diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go new file mode 100644 index 000000000..ef8f7a15f --- /dev/null +++ b/core/clients/workload_identity_flow_test.go @@ -0,0 +1,566 @@ +package clients + +import ( + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func TestWorkloadIdentityFlowInit(t *testing.T) { + tests := []struct { + name string + clientID string + clientIDAsEnv bool + customTokenUrl string + customTokenUrlEnv bool + tokenExpiration string + validAssertion bool + tokenFilePathAsEnv bool + missingTokenFilePath bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "missing client id", + validAssertion: true, + wantErr: true, + }, + { + name: "missing assertion", + clientID: "test@stackit.cloud", + missingTokenFilePath: true, + wantErr: true, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + if tt.customTokenUrl != "" { + if tt.customTokenUrlEnv { + t.Setenv("STACKIT_IDP_ENDPOINT", tt.customTokenUrl) + } else { + flowConfig.TokenUrl = tt.customTokenUrl + } + } + + if tt.clientID != "" { + if tt.clientIDAsEnv { + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", tt.clientID) + } else { + flowConfig.ClientID = tt.clientID + } + } + if tt.tokenExpiration != "" { + flowConfig.TokenExpiration = tt.tokenExpiration + } + + if !tt.missingTokenFilePath { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + if tt.validAssertion { + token, err := signTokenWithSubject("subject", time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + } + if tt.tokenFilePathAsEnv { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name()) + } else { + flowConfig.FederatedTokenFilePath = file.Name() + } + } + + if err := flow.Init(flowConfig); (err != nil) != tt.wantErr { + t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr) + } + if flow.config == nil { + t.Error("config is nil") + } + + if flow.config.ClientID != tt.clientID { + t.Errorf("clientID mismatch, want %s, got %s", tt.clientID, flow.config.ClientID) + } + + if tt.customTokenUrl != "" && flow.config.TokenUrl != tt.customTokenUrl { + t.Errorf("tokenUrl mismatch, want %s, got %s", tt.customTokenUrl, flow.config.TokenUrl) + } + + if tt.customTokenUrl == "" && flow.config.TokenUrl != "https://accounts.stackit.cloud/oauth/v2/token" { + t.Errorf("tokenUrl mismatch, want %s, got %s", "https://accounts.stackit.cloud/oauth/v2/token", flow.config.TokenUrl) + } + + if tt.missingTokenFilePath && flow.config.FederatedTokenFilePath != "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want %s, got %s", "/var/run/secrets/stackit.cloud/serviceaccount/token", flow.config.FederatedTokenFilePath) + } + + if !tt.missingTokenFilePath && flow.config.FederatedTokenFilePath == "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want different from %s", flow.config.FederatedTokenFilePath) + } + + if tt.tokenExpiration != "" && flow.config.TokenExpiration != tt.tokenExpiration { + t.Errorf("tokenExpiration mismatch, want %s, got %s", tt.tokenExpiration, flow.config.TokenExpiration) + } + }) + } +} + +func signTokenWithSubject(sub string, expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + Subject: sub, + }).SignedString([]byte("test")) +} + +func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { + validSub := "valid-sub" + serviceAccountSub := "sa-sub" + tests := []struct { + name string + clientID string + validAssertion bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + validAssertion: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertionType := r.PostForm.Get("client_assertion_type") + if assertionType != "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" { + t.Fatalf("invalid assertion type: %s", assertionType) + } + grantType := r.PostForm.Get("grant_type") + if grantType != "client_credentials" { + t.Fatalf("invalid grant type: %s", assertionType) + } + context, _, err := jwt.NewParser().ParseUnverified(r.PostForm.Get("client_assertion"), jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != validSub { + w.WriteHeader(http.StatusUnauthorized) + return + } + + token, err := signTokenWithSubject(serviceAccountSub, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + + tokenResponse := &TokenResponseBody{ + AccessToken: token, + ExpiresIn: 60, + TokenType: "Bearer", + } + + payload, err := json.Marshal(tokenResponse) + if err != nil { + t.Fatalf("failed to create token payload: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(payload) + })) + t.Cleanup(authServer.Close) + + protectedResource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + context, _, err := jwt.NewParser().ParseUnverified(strings.Fields(r.Header.Get("Authorization"))[1], jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != serviceAccountSub { + t.Fatalf("invalid token on protected resource: %v", err) + } + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(protectedResource.Close) + + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + flowConfig.TokenUrl = authServer.URL + + flowConfig.ClientID = tt.clientID + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + flowConfig.FederatedTokenFilePath = file.Name() + + subject := "wrong" + if tt.validAssertion { + subject = validSub + } + token, err := signTokenWithSubject(subject, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + + if err := flow.Init(flowConfig); err != nil { + t.Errorf("KeyFlow.Init() error = %v", err) + } + if flow.config == nil { + t.Error("config is nil") + } + + client := http.Client{ + Transport: flow, + } + resp, err := client.Get(protectedResource.URL) + if (err != nil || resp.StatusCode != http.StatusOK) && !tt.wantErr { + t.Fatalf("failed request to protected resource: %v", err) + } + }) + } +} + +// func TestRequestToken(t *testing.T) { +// testCases := []struct { +// name string +// grant string +// assertion string +// mockResponse *http.Response +// mockError error +// expectedError error +// }{ +// { +// name: "Success", +// grant: "test_grant", +// assertion: "test_assertion", +// mockResponse: &http.Response{ +// StatusCode: 200, +// Body: io.NopCloser(strings.NewReader(`{"access_token": "test_token"}`)), +// }, +// mockError: nil, +// expectedError: nil, +// }, +// { +// name: "Error", +// grant: "test_grant", +// assertion: "test_assertion", +// mockResponse: nil, +// mockError: fmt.Errorf("request error"), +// expectedError: fmt.Errorf("request error"), +// }, +// } + +// for _, tt := range testCases { +// t.Run(tt.name, func(t *testing.T) { +// keyFlow := &KeyFlow{} +// privateKeyBytes, err := generatePrivateKey() +// if err != nil { +// t.Fatalf("Error generating private key: %s", err) +// } +// keyFlowConfig := &KeyFlowConfig{ +// AuthHTTPClient: &http.Client{ +// Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) { +// return tt.mockResponse, tt.mockError +// }}, +// }, +// ServiceAccountKey: fixtureServiceAccountKey(), +// PrivateKey: string(privateKeyBytes), +// HTTPTransport: http.DefaultTransport, +// } +// err = keyFlow.Init(keyFlowConfig) +// if err != nil { +// t.Fatalf("failed to initialize key flow: %v", err) +// } + +// res, err := keyFlow.requestToken(tt.grant, tt.assertion) +// defer func() { +// if res != nil { +// tempErr := res.Body.Close() +// if tempErr != nil { +// t.Errorf("closing request token response: %s", tempErr.Error()) +// } +// } +// }() +// if tt.expectedError != nil { +// if err == nil { +// t.Errorf("Expected error '%v' but no error was returned", tt.expectedError) +// } else if errors.Is(err, tt.expectedError) { +// t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err) +// } +// } else { +// if err != nil { +// t.Errorf("Expected no error but error was returned: %v", err) +// } +// if !cmp.Equal(tt.mockResponse, res, cmp.AllowUnexported(strings.Reader{})) { +// t.Errorf("The returned result is wrong. Expected %v, got %v", tt.mockResponse, res) +// } +// } +// }) +// } +// } + +// func TestKeyFlow_Do(t *testing.T) { +// t.Parallel() + +// tests := []struct { +// name string +// handlerFn func(tb testing.TB) http.HandlerFunc +// want int +// wantErr bool +// }{ +// { +// name: "success", +// handlerFn: func(tb testing.TB) http.HandlerFunc { +// tb.Helper() + +// return func(w http.ResponseWriter, r *http.Request) { +// if r.Header.Get("Authorization") != "Bearer "+testBearerToken { +// tb.Errorf("expected Authorization header to be 'Bearer %s', but got %s", testBearerToken, r.Header.Get("Authorization")) +// } + +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: http.StatusOK, +// wantErr: false, +// }, +// { +// name: "success with code 500", +// handlerFn: func(_ testing.TB) http.HandlerFunc { +// return func(w http.ResponseWriter, _ *http.Request) { +// w.Header().Set("Content-Type", "text/html") +// w.WriteHeader(http.StatusInternalServerError) +// _, _ = fmt.Fprintln(w, `Internal Server Error`) +// } +// }, +// want: http.StatusInternalServerError, +// wantErr: false, +// }, +// { +// name: "success with custom transport", +// handlerFn: func(tb testing.TB) http.HandlerFunc { +// tb.Helper() + +// return func(w http.ResponseWriter, r *http.Request) { +// if r.Header.Get("User-Agent") != "custom_transport" { +// tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) +// } + +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: http.StatusOK, +// wantErr: false, +// }, +// { +// name: "fail with custom proxy", +// handlerFn: func(testing.TB) http.HandlerFunc { +// return func(w http.ResponseWriter, _ *http.Request) { +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: 0, +// wantErr: true, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// ctx := context.Background() +// ctx, cancel := context.WithCancel(ctx) +// t.Cleanup(cancel) // This cancels the refresher goroutine + +// privateKeyBytes, err := generatePrivateKey() +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// keyFlow := &KeyFlow{} +// keyFlowConfig := &KeyFlowConfig{ +// ServiceAccountKey: fixtureServiceAccountKey(), +// PrivateKey: string(privateKeyBytes), +// BackgroundTokenRefreshContext: ctx, +// HTTPTransport: func() http.RoundTripper { +// switch tt.name { +// case "success with custom transport": +// return mockTransportFn{ +// fn: func(req *http.Request) (*http.Response, error) { +// req.Header.Set("User-Agent", "custom_transport") +// return http.DefaultTransport.RoundTrip(req) +// }, +// } +// case "fail with custom proxy": +// return &http.Transport{ +// Proxy: func(_ *http.Request) (*url.URL, error) { +// return nil, fmt.Errorf("proxy error") +// }, +// } +// default: +// return http.DefaultTransport +// } +// }(), +// AuthHTTPClient: &http.Client{ +// Transport: mockTransportFn{ +// fn: func(_ *http.Request) (*http.Response, error) { +// res := httptest.NewRecorder() +// res.WriteHeader(http.StatusOK) +// res.Header().Set("Content-Type", "application/json") + +// token := &TokenResponseBody{ +// AccessToken: testBearerToken, +// ExpiresIn: 2147483647, +// TokenType: "Bearer", +// } + +// if err := json.NewEncoder(res.Body).Encode(token); err != nil { +// t.Logf("no error is expected, but got %v", err) +// } + +// return res.Result(), nil +// }, +// }, +// }, +// } +// err = keyFlow.Init(keyFlowConfig) +// if err != nil { +// t.Fatalf("failed to initialize key flow: %v", err) +// } + +// go continuousRefreshToken(keyFlow) + +// tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second) + +// token: +// for { +// select { +// case <-tokenCtx.Done(): +// t.Error(tokenCtx.Err()) +// case <-time.After(50 * time.Millisecond): +// keyFlow.tokenMutex.RLock() +// if keyFlow.token != nil { +// keyFlow.tokenMutex.RUnlock() +// tokenCancel() +// break token +// } + +// keyFlow.tokenMutex.RUnlock() +// } +// } + +// server := httptest.NewServer(tt.handlerFn(t)) +// t.Cleanup(server.Close) + +// u, err := url.Parse(server.URL) +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// httpClient := &http.Client{ +// Transport: keyFlow, +// } + +// res, err := httpClient.Do(req) + +// if tt.wantErr { +// if err == nil { +// t.Errorf("error is expected, but got %v", err) +// } +// } else { +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// if res.StatusCode != tt.want { +// t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) +// } + +// // Defer discard and close the body +// t.Cleanup(func() { +// if _, err := io.Copy(io.Discard, res.Body); err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// if err := res.Body.Close(); err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } +// }) +// } +// }) +// } +// } + +// type mockTransportFn struct { +// fn func(req *http.Request) (*http.Response, error) +// } + +// func (m mockTransportFn) RoundTrip(req *http.Request) (*http.Response, error) { +// return m.fn(req) +// } diff --git a/core/config/config.go b/core/config/config.go index 93002c02a..ae2d8c498 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -75,26 +75,29 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` // Deprecated: ServiceAccountEmail is not required and will be removed after 12th June 2025. - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` - CustomAuth http.RoundTripper - Servers ServerConfigurations - OperationServers map[string]ServerConfigurations - HTTPClient *http.Client - Middleware []Middleware + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + WorkloadIdentityFederationTokenExpiration string `json:"workloadIdentityFederationTokenExpiration,omitempty"` + WorkloadIdentityFederationFederatedTokenPath string `json:"workloadIdentityFederationFederatedTokenPath,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` + CustomAuth http.RoundTripper + Servers ServerConfigurations + OperationServers map[string]ServerConfigurations + HTTPClient *http.Client + Middleware []Middleware // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. // The goroutine is killed whenever this context is canceled. @@ -176,8 +179,6 @@ func WithTokenEndpoint(url string) ConfigurationOption { } // WithServiceAccountEmail returns a ConfigurationOption that sets the service account email -// -// Deprecated: WithServiceAccountEmail is not required and will be removed after 12th June 2025. func WithServiceAccountEmail(serviceAccountEmail string) ConfigurationOption { return func(config *Configuration) error { config.ServiceAccountEmail = serviceAccountEmail @@ -237,6 +238,30 @@ func WithToken(token string) ConfigurationOption { } } +// WithWorkloadIdentityFederationAuth returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationAuth() ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederation = true + return nil + } +} + +// WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederationFederatedTokenPath = path + return nil + } +} + +// WithWorkloadIdentityFederationTokenExpiration returns a ConfigurationOption that sets the token expiration for workload identity federation flow +func WithWorkloadIdentityFederationTokenExpiration(expiration string) ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederationTokenExpiration = expiration + return nil + } +} + // Deprecated: retry options were removed to reduce complexity of the client. If this functionality is needed, you can provide your own custom HTTP client. This option has no effect, and will be removed in a later update func WithMaxRetries(_ int) ConfigurationOption { return func(_ *Configuration) error { diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 839999938..64758bd87 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -14,7 +14,8 @@ func main() { // When creating a new API client without providing any configuration, it will setup default authentication. // The SDK will search for a valid service account key or token in several locations. - // It will first try to use the key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, + // It will first try to use the workload identity federation flow by looking into the variables STACKIT_FEDERATED_TOKEN_FILE, STACKIT_SERVICE_ACCOUNT_EMAIL and their default values, + // Then, it will try key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, // STACKIT_PRIVATE_KEY and STACKIT_PRIVATE_KEY_PATH. If the keys cannot be retrieved, it will check the credentials file located in STACKIT_CREDENTIALS_PATH, if specified, or in // $HOME/.stackit/credentials.json as a fallback. If the key are found and are valid, the KeyAuth flow is used. // If the key flow cannot be used, it will try to find a token in the STACKIT_SERVICE_ACCOUNT_TOKEN. If not present, it will @@ -35,18 +36,27 @@ func main() { // Create a new API client, that will authenticate using the provided bearer token token := "TOKEN" - _, err = dns.NewAPIClient(config.WithToken(token)) + dnsClient, err := dns.NewAPIClient(config.WithToken(token)) if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) os.Exit(1) } + // Check that you can make an authenticated request + getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + // Create a new API client, that will authenticate using the key flow // If you created a service account key and provided your own RSA key pair, // you need to add the path to a PEM encoded file including the private key // using config.WithPrivateKeyPath("path/to/private_key.pem") saKeyPath := "/path/to/service_account_key.json" - dnsClient, err := dns.NewAPIClient( + dnsClient, err = dns.NewAPIClient( config.WithServiceAccountKeyPath(saKeyPath), ) if err != nil { @@ -55,7 +65,30 @@ func main() { } // Check that you can make an authenticated request - getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + + // Create a new API client, that will authenticate using the wif flow + // You need to create a service account key and configure the federate identity provider, + // then you can init the SDK setting fields + dnsClient, err = dns.NewAPIClient( + config.WithWorkloadIdentityFederationAuth(), + config.WithTokenEndpoint("custom token endpoint"), + config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token"), + config.WithServiceAccountEmail("my-sa@sa-stackit.cloud"), + ) + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) + os.Exit(1) + } + + // Check that you can make an authenticated request + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err)