Skip to content

Commit 39e569c

Browse files
committed
simplify configuration of oauth clients, support identity federation
OAuth can now be provided directly to the Client type. Using OAuthConfig.HTTPClient() is now deprecated. Updates tailscale/terraform-provider-tailscale#485 Signed-off-by: mcoulombe <[email protected]>
1 parent 4c357f9 commit 39e569c

File tree

5 files changed

+750
-3
lines changed

5 files changed

+750
-3
lines changed

README.md

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,67 @@ import (
4343
func main() {
4444
client := &tailscale.Client{
4545
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
46-
HTTP: tailscale.OAuthConfig{
46+
Auth: &tailscale.OAuth{
4747
ClientID: os.Getenv("TAILSCALE_OAUTH_CLIENT_ID"),
4848
ClientSecret: os.Getenv("TAILSCALE_OAUTH_CLIENT_SECRET"),
4949
Scopes: []string{"all:write"},
50-
}.HTTPClient(),
50+
},
51+
}
52+
53+
devices, err := client.Devices().List(context.Background())
54+
}
55+
```
56+
57+
## Example (Using Identity Federation)
58+
59+
```go
60+
package main
61+
62+
import (
63+
"context"
64+
"os"
65+
66+
"tailscale.com/client/tailscale/v2"
67+
)
68+
69+
func main() {
70+
client := &tailscale.Client{
71+
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
72+
Auth: &tailscale.IdentityFederation{
73+
ClientID: os.Getenv("TAILSCALE_OAUTH_CLIENT_ID"),
74+
IDTokenFunc: func() (string, error) {
75+
return os.Getenv("IDENTITY_TOKEN"), nil
76+
},
77+
},
78+
}
79+
80+
devices, err := client.Devices().List(context.Background())
81+
}
82+
```
83+
84+
## Example (Using Your Own Authentication Mechanism)
85+
86+
```go
87+
package main
88+
89+
import (
90+
"context"
91+
"os"
92+
93+
"tailscale.com/client/tailscale/v2"
94+
)
95+
96+
type MyAuth struct {...}
97+
98+
func (a *MyAuth) HTTPClient(orig *http.Client, baseURL string) *http.Client {
99+
// build an HTTP client that adds authentication to outgoing requests
100+
// see tailscale.OAuth for an example.
101+
}
102+
103+
func main() {
104+
client := &tailscale.Client{
105+
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
106+
Auth: &MyAuth{...},
51107
}
52108

53109
devices, err := client.Devices().List(context.Background())

client.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,27 @@ import (
2121
"github.com/tailscale/hujson"
2222
)
2323

24+
// Auth is a pluggable mechanism for authenticating requests.
25+
type Auth interface {
26+
// HTTPClient builds an http.Client that uses orig as a starting point and
27+
// adds its own authentication to outgoing requests. baseURL is the base URL
28+
// of the API server to which we will be authenticating.
29+
HTTPClient(orig *http.Client, baseURL string) *http.Client
30+
}
31+
2432
// Client is used to perform actions against the Tailscale API.
2533
type Client struct {
2634
// BaseURL is the base URL for accessing the Tailscale API server. Defaults to https://api.tailscale.com.
2735
BaseURL *url.URL
2836
// UserAgent configures the User-Agent HTTP header for requests. Defaults to "tailscale-client-go".
2937
UserAgent string
3038
// APIKey allows specifying an APIKey to use for authentication.
31-
// To use OAuth Client credentials, construct an [http.Client] using [OAuthConfig] and specify that below.
39+
// To use OAuth Client credentials, specify OAuth in the Auth field instead.
40+
// To use Identity Federation, specify IdentityFederation in the Auth field instead.
3241
APIKey string
42+
// Auth specifies a mechanism for adding authentication to outgoing requests.
43+
// If provided, APIKey is ignored.
44+
Auth Auth
3345
// Tailnet allows specifying a specific tailnet by name, to which this Client will connect by default.
3446
// If Tailnet is left blank, the client will connect to default tailnet based on the client's credential,
3547
// using the "-" (dash) default tailnet path.
@@ -97,6 +109,10 @@ func (c *Client) init() {
97109
if c.Tailnet == "" {
98110
c.Tailnet = "-"
99111
}
112+
if c.Auth != nil {
113+
c.APIKey = ""
114+
c.HTTP = c.Auth.HTTPClient(c.HTTP, c.BaseURL.String())
115+
}
100116
c.contacts = &ContactsResource{c}
101117
c.devicePosture = &DevicePostureResource{c}
102118
c.devices = &DevicesResource{c}

identityfederation.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
// Copyright (c) David Bond, Tailscale Inc, & Contributors
2+
// SPDX-License-Identifier: MIT
3+
4+
package tailscale
5+
6+
import (
7+
"context"
8+
"encoding/base64"
9+
"encoding/json"
10+
"fmt"
11+
"io"
12+
"net/http"
13+
"net/url"
14+
"strings"
15+
"sync"
16+
"time"
17+
18+
"golang.org/x/oauth2"
19+
)
20+
21+
var _ Auth = &IdentityFederation{}
22+
23+
// tokenExchangeResponse represents the response from the Tailscale token exchange endpoint.
24+
type tokenExchangeResponse struct {
25+
AccessToken string `json:"access_token"`
26+
TokenType string `json:"token_type"`
27+
ExpiresIn int `json:"expires_in"` // in seconds
28+
Scope string `json:"scope"`
29+
}
30+
31+
// jwtClaims represents the claims in a JWT token (minimal set for validation).
32+
type jwtClaims struct {
33+
Exp int64 `json:"exp"`
34+
}
35+
36+
// IdentityFederation configures identity federation authentication.
37+
type IdentityFederation struct {
38+
// ClientID is the ID of the Tailscale OAuth client.
39+
ClientID string
40+
// IDTokenFunc returns an identity token from the IdP to exchange for a Tailscale API token.
41+
// The client calls this function to obtain a fresh ID token and reauthenticate when the API token
42+
// and cached ID token have expired. For static tokens, return the token directly. If a static token
43+
// expires, the client cannot automatically refresh the API token; the consumer is responsible to create a new client
44+
// with a fresh ID token.
45+
IDTokenFunc func() (string, error)
46+
}
47+
48+
// identityFederationTokenSource implements oauth2.TokenSource using identity federation.
49+
type identityFederationTokenSource struct {
50+
http *http.Client
51+
baseURL string
52+
clientID string
53+
idTokenFunc func() (string, error)
54+
55+
mu sync.Mutex // protects the below fields
56+
idToken string
57+
}
58+
59+
// HTTPClient implements the [Auth] interface.
60+
func (i *IdentityFederation) HTTPClient(orig *http.Client, baseURL string) *http.Client {
61+
s := &identityFederationTokenSource{
62+
http: orig,
63+
baseURL: baseURL,
64+
clientID: i.ClientID,
65+
idTokenFunc: i.IDTokenFunc,
66+
}
67+
68+
return &http.Client{
69+
Transport: &oauth2.Transport{
70+
Base: orig.Transport,
71+
Source: oauth2.ReuseTokenSource(nil, s),
72+
},
73+
CheckRedirect: orig.CheckRedirect,
74+
Jar: orig.Jar,
75+
Timeout: orig.Timeout,
76+
}
77+
}
78+
79+
// Token implements oauth2.TokenSource by exchanging an ID token for an API access token.
80+
func (i *identityFederationTokenSource) Token() (*oauth2.Token, error) {
81+
i.mu.Lock()
82+
defer i.mu.Unlock()
83+
84+
if i.idToken == "" || validateIDToken(i.idToken) != nil {
85+
idToken, err := i.idTokenFunc()
86+
if err != nil {
87+
return nil, fmt.Errorf("failed to fetch ID token: %w", err)
88+
}
89+
if err := validateIDToken(idToken); err != nil {
90+
return nil, fmt.Errorf("fetched ID token is invalid: %w", err)
91+
}
92+
i.idToken = idToken
93+
}
94+
95+
exchangeURL := fmt.Sprintf("%s/api/v2/oauth/token-exchange", i.baseURL)
96+
values := url.Values{
97+
"client_id": {i.clientID},
98+
"jwt": {i.idToken},
99+
}.Encode()
100+
101+
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, exchangeURL, strings.NewReader(values))
102+
if err != nil {
103+
return nil, fmt.Errorf("failed to create token exchange request: %w", err)
104+
}
105+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
106+
107+
resp, err := i.http.Do(req)
108+
if err != nil {
109+
return nil, fmt.Errorf("unexpected token exchange request error: %w", err)
110+
}
111+
defer resp.Body.Close()
112+
113+
if resp.StatusCode >= http.StatusBadRequest {
114+
b, _ := io.ReadAll(resp.Body)
115+
return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(b))
116+
}
117+
118+
var tokenResp tokenExchangeResponse
119+
if err = json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
120+
return nil, fmt.Errorf("failed to decode token exchange response: %w", err)
121+
}
122+
123+
return &oauth2.Token{
124+
AccessToken: tokenResp.AccessToken,
125+
TokenType: tokenResp.TokenType,
126+
Expiry: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
127+
}, nil
128+
}
129+
130+
// validateIDToken decodes and validates the ID token's expiration claim
131+
// to give a more helpful error if the token is expired or malformed.
132+
func validateIDToken(idToken string) error {
133+
parts := strings.Split(idToken, ".")
134+
if len(parts) != 3 {
135+
return fmt.Errorf("invalid JWT format: expected 3 parts separated by '.', got %d", len(parts))
136+
}
137+
138+
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
139+
if err != nil {
140+
return fmt.Errorf("failed to decode JWT payload: %w", err)
141+
}
142+
143+
var claims jwtClaims
144+
if err := json.Unmarshal(payload, &claims); err != nil {
145+
return fmt.Errorf("failed to parse JWT claims: %w", err)
146+
}
147+
148+
if claims.Exp == 0 {
149+
return fmt.Errorf("JWT is missing 'exp' (expiration) claim")
150+
}
151+
152+
expirationTime := time.Unix(claims.Exp, 0)
153+
if time.Now().After(expirationTime) {
154+
return fmt.Errorf("ID token has expired (expired at %s)", expirationTime.Format(time.RFC3339))
155+
}
156+
157+
return nil
158+
}

0 commit comments

Comments
 (0)