Skip to content

Commit 53e8ff6

Browse files
committed
feat: Support Workload Identity Federation flow
Signed-off-by: Jorge Turrado <[email protected]>
1 parent 0e1474e commit 53e8ff6

File tree

11 files changed

+1230
-547
lines changed

11 files changed

+1230
-547
lines changed

core/auth/auth.go

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
4444

4545
if cfg.CustomAuth != nil {
4646
return cfg.CustomAuth, nil
47+
} else if useWorkloadIdentityFederation(cfg) {
48+
wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg)
49+
if err != nil {
50+
return nil, fmt.Errorf("configuring no auth client: %w", err)
51+
}
52+
return wifRoundTripper, nil
4753
} else if cfg.NoAuth {
4854
noAuthRoundTripper, err := NoAuth(cfg)
4955
if err != nil {
@@ -83,14 +89,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
8389
cfg = &config.Configuration{}
8490
}
8591

86-
// Key flow
87-
rt, err = KeyAuth(cfg)
92+
// WIF flow
93+
rt, err = WorkloadIdentityFederationAuth(cfg)
8894
if err != nil {
89-
keyFlowErr := err
90-
// Token flow
91-
rt, err = TokenAuth(cfg)
95+
// Key flow
96+
rt, err = KeyAuth(cfg)
9297
if err != nil {
93-
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
98+
keyFlowErr := err
99+
// Token flow
100+
rt, err = TokenAuth(cfg)
101+
if err != nil {
102+
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
103+
}
94104
}
95105
}
96106
return rt, nil
@@ -216,6 +226,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
216226
return client, nil
217227
}
218228

229+
// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper
230+
// that can be used to make authenticated requests using an access token
231+
func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) {
232+
wifConfig := clients.WorkloadIdentityFederationFlowConfig{
233+
TokenUrl: cfg.TokenCustomUrl,
234+
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
235+
ClientID: cfg.ServiceAccountEmail,
236+
FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath,
237+
TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration,
238+
}
239+
240+
if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
241+
wifConfig.HTTPTransport = cfg.HTTPClient.Transport
242+
}
243+
244+
client := &clients.WorkloadIdentityFederationFlow{}
245+
if err := client.Init(&wifConfig); err != nil {
246+
return nil, fmt.Errorf("error initializing client: %w", err)
247+
}
248+
249+
return client, nil
250+
}
251+
219252
// readCredentialsFile reads the credentials file from the specified path and returns Credentials
220253
func readCredentialsFile(path string) (*Credentials, error) {
221254
if path == "" {
@@ -356,3 +389,11 @@ func getServiceAccountKey(cfg *config.Configuration) error {
356389
func getPrivateKey(cfg *config.Configuration) error {
357390
return getKey(&cfg.PrivateKey, &cfg.PrivateKeyPath, "STACKIT_PRIVATE_KEY_PATH", "STACKIT_PRIVATE_KEY", privateKeyPathCredentialType, privateKeyCredentialType, cfg.CredentialsFilePath)
358391
}
392+
393+
func useWorkloadIdentityFederation(cfg *config.Configuration) bool {
394+
if cfg != nil && cfg.WorkloadIdentityFederation {
395+
return true
396+
}
397+
val, exists := os.LookupEnv(clients.FederatedTokenFileEnv)
398+
return exists && val != ""
399+
}

core/auth/auth_test.go

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313
"time"
1414

15+
"github.com/golang-jwt/jwt/v5"
1516
"github.com/google/uuid"
1617
"github.com/stackitcloud/stackit-sdk-go/core/clients"
1718
"github.com/stackitcloud/stackit-sdk-go/core/config"
@@ -120,6 +121,32 @@ func TestSetupAuth(t *testing.T) {
120121
}
121122
}()
122123

124+
// create a wif assertion file
125+
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
126+
if errs != nil {
127+
t.Fatalf("Creating temporary file: %s", err)
128+
}
129+
defer func() {
130+
_ = wifAssertionFile.Close()
131+
err := os.Remove(wifAssertionFile.Name())
132+
if err != nil {
133+
t.Fatalf("Removing temporary file: %s", err)
134+
}
135+
}()
136+
137+
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
138+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
139+
Subject: "sub",
140+
}).SignedString([]byte("test"))
141+
if err != nil {
142+
t.Fatalf("Removing temporary file: %s", err)
143+
}
144+
145+
_, errs = wifAssertionFile.WriteString(string(token))
146+
if errs != nil {
147+
t.Fatalf("Writing wif assertion to temporary file: %s", err)
148+
}
149+
123150
// create a credentials file with saKey and private key
124151
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
125152
if errs != nil {
@@ -146,48 +173,48 @@ func TestSetupAuth(t *testing.T) {
146173
desc string
147174
config *config.Configuration
148175
setToken bool
176+
setWorkloadIdentity bool
149177
setKeys bool
150178
setKeyPaths bool
151179
setCredentialsFilePathToken bool
152180
setCredentialsFilePathKey bool
153-
isValid bool
154181
}{
182+
{
183+
desc: "wif_config",
184+
config: nil,
185+
setWorkloadIdentity: true,
186+
},
155187
{
156188
desc: "token_config",
157189
config: nil,
158190
setToken: true,
159191
setCredentialsFilePathToken: false,
160-
isValid: true,
161192
},
162193
{
163194
desc: "key_config",
164195
config: nil,
165196
setKeys: true,
166197
setCredentialsFilePathToken: false,
167-
isValid: true,
168198
},
169199
{
170200
desc: "key_config_path",
171201
config: nil,
172202
setKeys: false,
173203
setKeyPaths: true,
174204
setCredentialsFilePathToken: false,
175-
isValid: true,
176205
},
177206
{
178207
desc: "key_config_credentials_path",
179208
config: nil,
180209
setKeys: false,
181210
setKeyPaths: false,
182211
setCredentialsFilePathKey: true,
183-
isValid: true,
184212
},
185213
{
186214
desc: "valid_path_to_file",
187215
config: nil,
188216
setToken: false,
189217
setCredentialsFilePathToken: true,
190-
isValid: true,
191218
},
192219
{
193220
desc: "custom_config_token",
@@ -196,7 +223,6 @@ func TestSetupAuth(t *testing.T) {
196223
},
197224
setToken: false,
198225
setCredentialsFilePathToken: false,
199-
isValid: true,
200226
},
201227
{
202228
desc: "custom_config_path",
@@ -205,7 +231,6 @@ func TestSetupAuth(t *testing.T) {
205231
},
206232
setToken: false,
207233
setCredentialsFilePathToken: false,
208-
isValid: true,
209234
},
210235
} {
211236
t.Run(test.desc, func(t *testing.T) {
@@ -240,19 +265,21 @@ func TestSetupAuth(t *testing.T) {
240265
t.Setenv("STACKIT_CREDENTIALS_PATH", "")
241266
}
242267

268+
if test.setWorkloadIdentity {
269+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
270+
} else {
271+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
272+
}
273+
243274
t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")
244275

245276
authRoundTripper, err := SetupAuth(test.config)
246277

247-
if err != nil && test.isValid {
278+
if err != nil {
248279
t.Fatalf("Test returned error on valid test case: %v", err)
249280
}
250281

251-
if err == nil && !test.isValid {
252-
t.Fatalf("Test didn't return error on invalid test case")
253-
}
254-
255-
if test.isValid && authRoundTripper == nil {
282+
if authRoundTripper == nil {
256283
t.Fatalf("Roundtripper returned is nil for valid test case")
257284
}
258285
})
@@ -380,6 +407,32 @@ func TestDefaultAuth(t *testing.T) {
380407
t.Fatalf("Writing private key to temporary file: %s", err)
381408
}
382409

410+
// create a wif assertion file
411+
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
412+
if errs != nil {
413+
t.Fatalf("Creating temporary file: %s", err)
414+
}
415+
defer func() {
416+
_ = wifAssertionFile.Close()
417+
err := os.Remove(wifAssertionFile.Name())
418+
if err != nil {
419+
t.Fatalf("Removing temporary file: %s", err)
420+
}
421+
}()
422+
423+
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
424+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
425+
Subject: "sub",
426+
}).SignedString([]byte("test"))
427+
if err != nil {
428+
t.Fatalf("Removing temporary file: %s", err)
429+
}
430+
431+
_, errs = wifAssertionFile.WriteString(string(token))
432+
if errs != nil {
433+
t.Fatalf("Writing wif assertion to temporary file: %s", err)
434+
}
435+
383436
// create a credentials file with saKey and private key
384437
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
385438
if errs != nil {
@@ -408,6 +461,7 @@ func TestDefaultAuth(t *testing.T) {
408461
setKeyPaths bool
409462
setKeys bool
410463
setCredentialsFilePathKey bool
464+
setWorkloadIdentity bool
411465
isValid bool
412466
expectedFlow string
413467
}{
@@ -417,6 +471,14 @@ func TestDefaultAuth(t *testing.T) {
417471
isValid: true,
418472
expectedFlow: "token",
419473
},
474+
{
475+
desc: "wif_precedes_key_precedes_token",
476+
setToken: true,
477+
setKeyPaths: true,
478+
setWorkloadIdentity: true,
479+
isValid: true,
480+
expectedFlow: "wif",
481+
},
420482
{
421483
desc: "key_precedes_token",
422484
setToken: true,
@@ -474,6 +536,13 @@ func TestDefaultAuth(t *testing.T) {
474536
} else {
475537
t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "")
476538
}
539+
540+
if test.setWorkloadIdentity {
541+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
542+
} else {
543+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
544+
}
545+
477546
t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")
478547

479548
// Get the default authentication client and ensure that it's not nil
@@ -500,6 +569,10 @@ func TestDefaultAuth(t *testing.T) {
500569
if _, ok := authClient.(*clients.KeyFlow); !ok {
501570
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
502571
}
572+
case "wif":
573+
if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok {
574+
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
575+
}
503576
}
504577
}
505578
})

core/clients/auth_flow.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package clients
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"time"
10+
11+
"github.com/golang-jwt/jwt/v5"
12+
"github.com/stackitcloud/stackit-sdk-go/core/oapierror"
13+
)
14+
15+
const (
16+
defaultTokenExpirationLeeway = time.Second * 5
17+
)
18+
19+
type AuthFlow interface {
20+
RoundTrip(req *http.Request) (*http.Response, error)
21+
GetAccessToken() (string, error)
22+
GetBackgroundTokenRefreshContext() context.Context
23+
}
24+
25+
// TokenResponseBody is the API response
26+
// when requesting a new token
27+
type TokenResponseBody struct {
28+
AccessToken string `json:"access_token"`
29+
ExpiresIn int `json:"expires_in"`
30+
Scope string `json:"scope"`
31+
TokenType string `json:"token_type"`
32+
}
33+
34+
func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) {
35+
if res == nil {
36+
return nil, fmt.Errorf("received bad response from API")
37+
}
38+
if res.StatusCode != http.StatusOK {
39+
body, err := io.ReadAll(res.Body)
40+
if err != nil {
41+
// Fail silently, omit body from error
42+
// We're trying to show error details, so it's unnecessary to fail because of this err
43+
body = []byte{}
44+
}
45+
return nil, &oapierror.GenericOpenAPIError{
46+
StatusCode: res.StatusCode,
47+
Body: body,
48+
}
49+
}
50+
body, err := io.ReadAll(res.Body)
51+
if err != nil {
52+
return nil, err
53+
}
54+
55+
token := &TokenResponseBody{}
56+
err = json.Unmarshal(body, token)
57+
if err != nil {
58+
return nil, fmt.Errorf("unmarshal token response: %w", err)
59+
}
60+
return token, nil
61+
}
62+
63+
func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) {
64+
if token == "" {
65+
return true, nil
66+
}
67+
68+
// We can safely use ParseUnverified because we are not authenticating the user at this point.
69+
// We're just checking the expiration time
70+
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
71+
if err != nil {
72+
return false, fmt.Errorf("parse token: %w", err)
73+
}
74+
75+
expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
76+
if err != nil {
77+
return false, fmt.Errorf("get expiration timestamp: %w", err)
78+
}
79+
80+
// Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring
81+
// between retrieving the token and upstream systems validating it.
82+
now := time.Now().Add(tokenExpirationLeeway)
83+
return now.After(expirationTimestampNumeric.Time), nil
84+
}

0 commit comments

Comments
 (0)