Skip to content

Commit 88abb0d

Browse files
authored
Decoupled code from DefaultSigningKey (#16743)
Decoupled code from `DefaultSigningKey`. Makes testing a little bit easier and is cleaner.
1 parent cd8db3a commit 88abb0d

File tree

4 files changed

+27
-27
lines changed

4 files changed

+27
-27
lines changed

routers/web/user/oauth.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ type AccessTokenResponse struct {
115115
IDToken string `json:"id_token,omitempty"`
116116
}
117117

118-
func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
118+
func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
119119
if setting.OAuth2.InvalidateRefreshTokens {
120120
if err := grant.IncreaseCounter(); err != nil {
121121
return nil, &AccessTokenError{
@@ -133,7 +133,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
133133
ExpiresAt: expirationDate.AsTime().Unix(),
134134
},
135135
}
136-
signedAccessToken, err := accessToken.SignToken()
136+
signedAccessToken, err := accessToken.SignToken(serverKey)
137137
if err != nil {
138138
return nil, &AccessTokenError{
139139
ErrorCode: AccessTokenErrorCodeInvalidRequest,
@@ -151,7 +151,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
151151
ExpiresAt: refreshExpirationDate,
152152
},
153153
}
154-
signedRefreshToken, err := refreshToken.SignToken()
154+
signedRefreshToken, err := refreshToken.SignToken(serverKey)
155155
if err != nil {
156156
return nil, &AccessTokenError{
157157
ErrorCode: AccessTokenErrorCodeInvalidRequest,
@@ -207,7 +207,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
207207
idToken.EmailVerified = user.IsActive
208208
}
209209

210-
signedIDToken, err = idToken.SignToken(signingKey)
210+
signedIDToken, err = idToken.SignToken(clientKey)
211211
if err != nil {
212212
return nil, &AccessTokenError{
213213
ErrorCode: AccessTokenErrorCodeInvalidRequest,
@@ -265,7 +265,7 @@ func IntrospectOAuth(ctx *context.Context) {
265265
}
266266

267267
form := web.GetForm(ctx).(*forms.IntrospectTokenForm)
268-
token, err := oauth2.ParseToken(form.Token)
268+
token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey)
269269
if err == nil {
270270
if token.Valid() == nil {
271271
grant, err := models.GetOAuth2GrantByID(token.GrantID)
@@ -544,24 +544,25 @@ func AccessTokenOAuth(ctx *context.Context) {
544544
}
545545
}
546546

547-
signingKey := oauth2.DefaultSigningKey
548-
if signingKey.IsSymmetric() {
549-
clientKey, err := oauth2.CreateJWTSigningKey(signingKey.SigningMethod().Alg(), []byte(form.ClientSecret))
547+
serverKey := oauth2.DefaultSigningKey
548+
clientKey := serverKey
549+
if serverKey.IsSymmetric() {
550+
var err error
551+
clientKey, err = oauth2.CreateJWTSigningKey(serverKey.SigningMethod().Alg(), []byte(form.ClientSecret))
550552
if err != nil {
551553
handleAccessTokenError(ctx, AccessTokenError{
552554
ErrorCode: AccessTokenErrorCodeInvalidRequest,
553555
ErrorDescription: "Error creating signing key",
554556
})
555557
return
556558
}
557-
signingKey = clientKey
558559
}
559560

560561
switch form.GrantType {
561562
case "refresh_token":
562-
handleRefreshToken(ctx, form, signingKey)
563+
handleRefreshToken(ctx, form, serverKey, clientKey)
563564
case "authorization_code":
564-
handleAuthorizationCode(ctx, form, signingKey)
565+
handleAuthorizationCode(ctx, form, serverKey, clientKey)
565566
default:
566567
handleAccessTokenError(ctx, AccessTokenError{
567568
ErrorCode: AccessTokenErrorCodeUnsupportedGrantType,
@@ -570,8 +571,8 @@ func AccessTokenOAuth(ctx *context.Context) {
570571
}
571572
}
572573

573-
func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) {
574-
token, err := oauth2.ParseToken(form.RefreshToken)
574+
func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
575+
token, err := oauth2.ParseToken(form.RefreshToken, serverKey)
575576
if err != nil {
576577
handleAccessTokenError(ctx, AccessTokenError{
577578
ErrorCode: AccessTokenErrorCodeUnauthorizedClient,
@@ -598,15 +599,15 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin
598599
log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID)
599600
return
600601
}
601-
accessToken, tokenErr := newAccessTokenResponse(grant, signingKey)
602+
accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey)
602603
if tokenErr != nil {
603604
handleAccessTokenError(ctx, *tokenErr)
604605
return
605606
}
606607
ctx.JSON(http.StatusOK, accessToken)
607608
}
608609

609-
func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) {
610+
func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
610611
app, err := models.GetOAuth2ApplicationByClientID(form.ClientID)
611612
if err != nil {
612613
handleAccessTokenError(ctx, AccessTokenError{
@@ -660,7 +661,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s
660661
ErrorDescription: "cannot proceed your request",
661662
})
662663
}
663-
resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, signingKey)
664+
resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey)
664665
if tokenErr != nil {
665666
handleAccessTokenError(ctx, *tokenErr)
666667
return

routers/web/user/oauth_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCTo
1818
signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32))
1919
assert.NoError(t, err)
2020
assert.NotNil(t, signingKey)
21-
oauth2.DefaultSigningKey = signingKey
2221

23-
response, terr := newAccessTokenResponse(grant, signingKey)
22+
response, terr := newAccessTokenResponse(grant, signingKey, signingKey)
2423
assert.Nil(t, terr)
2524
assert.NotNil(t, response)
2625

services/auth/oauth2.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ func CheckOAuthAccessToken(accessToken string) int64 {
2929
if !strings.Contains(accessToken, ".") {
3030
return 0
3131
}
32-
token, err := oauth2.ParseToken(accessToken)
32+
token, err := oauth2.ParseToken(accessToken, oauth2.DefaultSigningKey)
3333
if err != nil {
34-
log.Trace("ParseOAuth2Token: %v", err)
34+
log.Trace("oauth2.ParseToken: %v", err)
3535
return 0
3636
}
3737
var grant *models.OAuth2Grant

services/auth/source/oauth2/token.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ type Token struct {
4040
}
4141

4242
// ParseToken parses a signed jwt string
43-
func ParseToken(jwtToken string) (*Token, error) {
43+
func ParseToken(jwtToken string, signingKey JWTSigningKey) (*Token, error) {
4444
parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) {
45-
if token.Method == nil || token.Method.Alg() != DefaultSigningKey.SigningMethod().Alg() {
45+
if token.Method == nil || token.Method.Alg() != signingKey.SigningMethod().Alg() {
4646
return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"])
4747
}
48-
return DefaultSigningKey.VerifyKey(), nil
48+
return signingKey.VerifyKey(), nil
4949
})
5050
if err != nil {
5151
return nil, err
@@ -59,11 +59,11 @@ func ParseToken(jwtToken string) (*Token, error) {
5959
}
6060

6161
// SignToken signs the token with the JWT secret
62-
func (token *Token) SignToken() (string, error) {
62+
func (token *Token) SignToken(signingKey JWTSigningKey) (string, error) {
6363
token.IssuedAt = time.Now().Unix()
64-
jwtToken := jwt.NewWithClaims(DefaultSigningKey.SigningMethod(), token)
65-
DefaultSigningKey.PreProcessToken(jwtToken)
66-
return jwtToken.SignedString(DefaultSigningKey.SignKey())
64+
jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token)
65+
signingKey.PreProcessToken(jwtToken)
66+
return jwtToken.SignedString(signingKey.SignKey())
6767
}
6868

6969
// OIDCToken represents an OpenID Connect id_token

0 commit comments

Comments
 (0)