Skip to content

Commit 3e75e01

Browse files
committed
feat: Support Workload Identity Federation flow
Signed-off-by: Jorge Turrado <[email protected]>
1 parent 4cb99cc commit 3e75e01

File tree

11 files changed

+1230
-547
lines changed

11 files changed

+1230
-547
lines changed

core/auth/auth.go

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
4545

4646
if cfg.CustomAuth != nil {
4747
return cfg.CustomAuth, nil
48+
} else if useWorkloadIdentityFederation(cfg) {
49+
wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg)
50+
if err != nil {
51+
return nil, fmt.Errorf("configuring no auth client: %w", err)
52+
}
53+
return wifRoundTripper, nil
4854
} else if cfg.NoAuth {
4955
noAuthRoundTripper, err := NoAuth(cfg)
5056
if err != nil {
@@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
8490
cfg = &config.Configuration{}
8591
}
8692

87-
// Key flow
88-
rt, err = KeyAuth(cfg)
93+
// WIF flow
94+
rt, err = WorkloadIdentityFederationAuth(cfg)
8995
if err != nil {
90-
keyFlowErr := err
91-
// Token flow
92-
rt, err = TokenAuth(cfg)
96+
// Key flow
97+
rt, err = KeyAuth(cfg)
9398
if err != nil {
94-
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
99+
keyFlowErr := err
100+
// Token flow
101+
rt, err = TokenAuth(cfg)
102+
if err != nil {
103+
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
104+
}
95105
}
96106
}
97107
return rt, nil
@@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
221231
return client, nil
222232
}
223233

234+
// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper
235+
// that can be used to make authenticated requests using an access token
236+
func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) {
237+
wifConfig := clients.WorkloadIdentityFederationFlowConfig{
238+
TokenUrl: cfg.TokenCustomUrl,
239+
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
240+
ClientID: cfg.ServiceAccountEmail,
241+
FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath,
242+
TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration,
243+
}
244+
245+
if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
246+
wifConfig.HTTPTransport = cfg.HTTPClient.Transport
247+
}
248+
249+
client := &clients.WorkloadIdentityFederationFlow{}
250+
if err := client.Init(&wifConfig); err != nil {
251+
return nil, fmt.Errorf("error initializing client: %w", err)
252+
}
253+
254+
return client, nil
255+
}
256+
224257
// readCredentialsFile reads the credentials file from the specified path and returns Credentials
225258
func readCredentialsFile(path string) (*Credentials, error) {
226259
if path == "" {
@@ -361,3 +394,11 @@ func getServiceAccountKey(cfg *config.Configuration) error {
361394
func getPrivateKey(cfg *config.Configuration) error {
362395
return getKey(&cfg.PrivateKey, &cfg.PrivateKeyPath, "STACKIT_PRIVATE_KEY_PATH", "STACKIT_PRIVATE_KEY", privateKeyPathCredentialType, privateKeyCredentialType, cfg.CredentialsFilePath)
363396
}
397+
398+
func useWorkloadIdentityFederation(cfg *config.Configuration) bool {
399+
if cfg != nil && cfg.WorkloadIdentityFederation {
400+
return true
401+
}
402+
val, exists := os.LookupEnv(clients.FederatedTokenFileEnv)
403+
return exists && val != ""
404+
}

core/auth/auth_test.go

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/golang-jwt/jwt/v5"
1617
"github.com/google/uuid"
1718
"github.com/stackitcloud/stackit-sdk-go/core/clients"
1819
"github.com/stackitcloud/stackit-sdk-go/core/config"
@@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) {
121122
}
122123
}()
123124

125+
// create a wif assertion file
126+
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
127+
if errs != nil {
128+
t.Fatalf("Creating temporary file: %s", err)
129+
}
130+
defer func() {
131+
_ = wifAssertionFile.Close()
132+
err := os.Remove(wifAssertionFile.Name())
133+
if err != nil {
134+
t.Fatalf("Removing temporary file: %s", err)
135+
}
136+
}()
137+
138+
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
139+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
140+
Subject: "sub",
141+
}).SignedString([]byte("test"))
142+
if err != nil {
143+
t.Fatalf("Removing temporary file: %s", err)
144+
}
145+
146+
_, errs = wifAssertionFile.WriteString(string(token))
147+
if errs != nil {
148+
t.Fatalf("Writing wif assertion to temporary file: %s", err)
149+
}
150+
124151
// create a credentials file with saKey and private key
125152
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
126153
if errs != nil {
@@ -147,48 +174,48 @@ func TestSetupAuth(t *testing.T) {
147174
desc string
148175
config *config.Configuration
149176
setToken bool
177+
setWorkloadIdentity bool
150178
setKeys bool
151179
setKeyPaths bool
152180
setCredentialsFilePathToken bool
153181
setCredentialsFilePathKey bool
154-
isValid bool
155182
}{
183+
{
184+
desc: "wif_config",
185+
config: nil,
186+
setWorkloadIdentity: true,
187+
},
156188
{
157189
desc: "token_config",
158190
config: nil,
159191
setToken: true,
160192
setCredentialsFilePathToken: false,
161-
isValid: true,
162193
},
163194
{
164195
desc: "key_config",
165196
config: nil,
166197
setKeys: true,
167198
setCredentialsFilePathToken: false,
168-
isValid: true,
169199
},
170200
{
171201
desc: "key_config_path",
172202
config: nil,
173203
setKeys: false,
174204
setKeyPaths: true,
175205
setCredentialsFilePathToken: false,
176-
isValid: true,
177206
},
178207
{
179208
desc: "key_config_credentials_path",
180209
config: nil,
181210
setKeys: false,
182211
setKeyPaths: false,
183212
setCredentialsFilePathKey: true,
184-
isValid: true,
185213
},
186214
{
187215
desc: "valid_path_to_file",
188216
config: nil,
189217
setToken: false,
190218
setCredentialsFilePathToken: true,
191-
isValid: true,
192219
},
193220
{
194221
desc: "custom_config_token",
@@ -197,7 +224,6 @@ func TestSetupAuth(t *testing.T) {
197224
},
198225
setToken: false,
199226
setCredentialsFilePathToken: false,
200-
isValid: true,
201227
},
202228
{
203229
desc: "custom_config_path",
@@ -206,7 +232,6 @@ func TestSetupAuth(t *testing.T) {
206232
},
207233
setToken: false,
208234
setCredentialsFilePathToken: false,
209-
isValid: true,
210235
},
211236
} {
212237
t.Run(test.desc, func(t *testing.T) {
@@ -241,19 +266,21 @@ func TestSetupAuth(t *testing.T) {
241266
t.Setenv("STACKIT_CREDENTIALS_PATH", "")
242267
}
243268

269+
if test.setWorkloadIdentity {
270+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
271+
} else {
272+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
273+
}
274+
244275
t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")
245276

246277
authRoundTripper, err := SetupAuth(test.config)
247278

248-
if err != nil && test.isValid {
279+
if err != nil {
249280
t.Fatalf("Test returned error on valid test case: %v", err)
250281
}
251282

252-
if err == nil && !test.isValid {
253-
t.Fatalf("Test didn't return error on invalid test case")
254-
}
255-
256-
if test.isValid && authRoundTripper == nil {
283+
if authRoundTripper == nil {
257284
t.Fatalf("Roundtripper returned is nil for valid test case")
258285
}
259286
})
@@ -381,6 +408,32 @@ func TestDefaultAuth(t *testing.T) {
381408
t.Fatalf("Writing private key to temporary file: %s", err)
382409
}
383410

411+
// create a wif assertion file
412+
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
413+
if errs != nil {
414+
t.Fatalf("Creating temporary file: %s", err)
415+
}
416+
defer func() {
417+
_ = wifAssertionFile.Close()
418+
err := os.Remove(wifAssertionFile.Name())
419+
if err != nil {
420+
t.Fatalf("Removing temporary file: %s", err)
421+
}
422+
}()
423+
424+
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
425+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
426+
Subject: "sub",
427+
}).SignedString([]byte("test"))
428+
if err != nil {
429+
t.Fatalf("Removing temporary file: %s", err)
430+
}
431+
432+
_, errs = wifAssertionFile.WriteString(string(token))
433+
if errs != nil {
434+
t.Fatalf("Writing wif assertion to temporary file: %s", err)
435+
}
436+
384437
// create a credentials file with saKey and private key
385438
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
386439
if errs != nil {
@@ -409,6 +462,7 @@ func TestDefaultAuth(t *testing.T) {
409462
setKeyPaths bool
410463
setKeys bool
411464
setCredentialsFilePathKey bool
465+
setWorkloadIdentity bool
412466
isValid bool
413467
expectedFlow string
414468
}{
@@ -418,6 +472,14 @@ func TestDefaultAuth(t *testing.T) {
418472
isValid: true,
419473
expectedFlow: "token",
420474
},
475+
{
476+
desc: "wif_precedes_key_precedes_token",
477+
setToken: true,
478+
setKeyPaths: true,
479+
setWorkloadIdentity: true,
480+
isValid: true,
481+
expectedFlow: "wif",
482+
},
421483
{
422484
desc: "key_precedes_token",
423485
setToken: true,
@@ -475,6 +537,13 @@ func TestDefaultAuth(t *testing.T) {
475537
} else {
476538
t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "")
477539
}
540+
541+
if test.setWorkloadIdentity {
542+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
543+
} else {
544+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
545+
}
546+
478547
t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")
479548

480549
// Get the default authentication client and ensure that it's not nil
@@ -501,6 +570,10 @@ func TestDefaultAuth(t *testing.T) {
501570
if _, ok := authClient.(*clients.KeyFlow); !ok {
502571
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
503572
}
573+
case "wif":
574+
if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok {
575+
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
576+
}
504577
}
505578
}
506579
})

core/clients/auth_flow.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package clients
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"time"
10+
11+
"github.com/golang-jwt/jwt/v5"
12+
"github.com/stackitcloud/stackit-sdk-go/core/oapierror"
13+
)
14+
15+
const (
16+
defaultTokenExpirationLeeway = time.Second * 5
17+
)
18+
19+
type AuthFlow interface {
20+
RoundTrip(req *http.Request) (*http.Response, error)
21+
GetAccessToken() (string, error)
22+
GetBackgroundTokenRefreshContext() context.Context
23+
}
24+
25+
// TokenResponseBody is the API response
26+
// when requesting a new token
27+
type TokenResponseBody struct {
28+
AccessToken string `json:"access_token"`
29+
ExpiresIn int `json:"expires_in"`
30+
Scope string `json:"scope"`
31+
TokenType string `json:"token_type"`
32+
}
33+
34+
func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) {
35+
if res == nil {
36+
return nil, fmt.Errorf("received bad response from API")
37+
}
38+
if res.StatusCode != http.StatusOK {
39+
body, err := io.ReadAll(res.Body)
40+
if err != nil {
41+
// Fail silently, omit body from error
42+
// We're trying to show error details, so it's unnecessary to fail because of this err
43+
body = []byte{}
44+
}
45+
return nil, &oapierror.GenericOpenAPIError{
46+
StatusCode: res.StatusCode,
47+
Body: body,
48+
}
49+
}
50+
body, err := io.ReadAll(res.Body)
51+
if err != nil {
52+
return nil, err
53+
}
54+
55+
token := &TokenResponseBody{}
56+
err = json.Unmarshal(body, token)
57+
if err != nil {
58+
return nil, fmt.Errorf("unmarshal token response: %w", err)
59+
}
60+
return token, nil
61+
}
62+
63+
func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) {
64+
if token == "" {
65+
return true, nil
66+
}
67+
68+
// We can safely use ParseUnverified because we are not authenticating the user at this point.
69+
// We're just checking the expiration time
70+
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
71+
if err != nil {
72+
return false, fmt.Errorf("parse token: %w", err)
73+
}
74+
75+
expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
76+
if err != nil {
77+
return false, fmt.Errorf("get expiration timestamp: %w", err)
78+
}
79+
80+
// Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring
81+
// between retrieving the token and upstream systems validating it.
82+
now := time.Now().Add(tokenExpirationLeeway)
83+
return now.After(expirationTimestampNumeric.Time), nil
84+
}

0 commit comments

Comments
 (0)