Skip to content

Commit 46d734e

Browse files
easyCZroboquat
authored andcommitted
[public-api] Authentication interceptors for connect API
1 parent 0a857e5 commit 46d734e

File tree

9 files changed

+353
-4
lines changed

9 files changed

+353
-4
lines changed

components/public-api-server/go.sum

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License-AGPL.txt in the project root for license information.
4+
5+
package auth
6+
7+
import (
8+
"errors"
9+
"fmt"
10+
"net/http"
11+
"strings"
12+
)
13+
14+
var (
15+
NoAccessToken = errors.New("missing access token")
16+
InvalidAccessToken = errors.New("invalid access token")
17+
)
18+
19+
const bearerPrefix = "Bearer "
20+
const authorizationHeaderKey = "Authorization"
21+
22+
func BearerTokenFromHeaders(h http.Header) (string, error) {
23+
authorization := strings.TrimSpace(h.Get(authorizationHeaderKey))
24+
if authorization == "" {
25+
return "", fmt.Errorf("empty authorization header: %w", NoAccessToken)
26+
}
27+
28+
if !strings.HasPrefix(authorization, bearerPrefix) {
29+
return "", fmt.Errorf("authorization header does not have a Bearer prefix: %w", NoAccessToken)
30+
}
31+
32+
return strings.TrimPrefix(authorization, bearerPrefix), nil
33+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License-AGPL.txt in the project root for license information.
4+
5+
package auth
6+
7+
import (
8+
"net/http"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestBearerTokenFromHeaders(t *testing.T) {
15+
type Scenario struct {
16+
Name string
17+
18+
// Input
19+
Header http.Header
20+
21+
// Output
22+
Token string
23+
Error error
24+
}
25+
26+
for _, s := range []Scenario{
27+
{
28+
Name: "happy case",
29+
Header: addToHeader(http.Header{}, "Authorization", "Bearer foo"),
30+
Token: "foo",
31+
},
32+
{
33+
Name: "leading and trailing spaces are trimmed",
34+
Header: addToHeader(http.Header{}, "Authorization", " Bearer foo "),
35+
Token: "foo",
36+
},
37+
{
38+
Name: "anything after Bearer is extracted",
39+
Header: addToHeader(http.Header{}, "Authorization", "Bearer foo bar"),
40+
Token: "foo bar",
41+
},
42+
{
43+
Name: "duplicate bearer",
44+
Header: addToHeader(http.Header{}, "Authorization", "Bearer Bearer foo"),
45+
Token: "Bearer foo",
46+
},
47+
{
48+
Name: "missing Bearer prefix fails",
49+
Header: addToHeader(http.Header{}, "Authorization", "foo"),
50+
Error: NoAccessToken,
51+
},
52+
{
53+
Name: "missing Authorization header fails",
54+
Header: http.Header{},
55+
Error: NoAccessToken,
56+
},
57+
} {
58+
t.Run(s.Name, func(t *testing.T) {
59+
token, err := BearerTokenFromHeaders(s.Header)
60+
require.ErrorIs(t, err, s.Error)
61+
require.Equal(t, s.Token, token)
62+
})
63+
}
64+
}
65+
66+
func addToHeader(h http.Header, key, value string) http.Header {
67+
h.Add(key, value)
68+
return h
69+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License-AGPL.txt in the project root for license information.
4+
5+
package auth
6+
7+
import "context"
8+
9+
type contextKey int
10+
11+
const (
12+
authContextKey contextKey = iota
13+
)
14+
15+
func TokenToContext(ctx context.Context, token string) context.Context {
16+
return context.WithValue(ctx, authContextKey, token)
17+
}
18+
19+
func TokenFromContext(ctx context.Context) string {
20+
if val, ok := ctx.Value(authContextKey).(string); ok {
21+
return val
22+
}
23+
24+
return ""
25+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License-AGPL.txt in the project root for license information.
4+
5+
package auth
6+
7+
import (
8+
"context"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestTokenToAndFromContext(t *testing.T) {
15+
token := "my_token"
16+
17+
extracted := TokenFromContext(TokenToContext(context.Background(), token))
18+
require.Equal(t, token, extracted)
19+
}
20+
21+
func TestTokenFromContext_EmptyWhenNotSet(t *testing.T) {
22+
require.Equal(t, "", TokenFromContext(context.Background()))
23+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License-AGPL.txt in the project root for license information.
4+
5+
package auth
6+
7+
import (
8+
"context"
9+
10+
"github.com/bufbuild/connect-go"
11+
)
12+
13+
// NewServerInterceptor creates a server-side interceptor which validates that an incoming request contains a Bearer Authorization header
14+
func NewServerInterceptor() connect.UnaryInterceptorFunc {
15+
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
16+
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
17+
18+
if req.Spec().IsClient {
19+
return next(ctx, req)
20+
}
21+
22+
headers := req.Header()
23+
24+
token, err := BearerTokenFromHeaders(headers)
25+
if err != nil {
26+
return nil, connect.NewError(connect.CodeUnauthenticated, err)
27+
28+
}
29+
return next(TokenToContext(ctx, token), req)
30+
})
31+
}
32+
33+
return connect.UnaryInterceptorFunc(interceptor)
34+
}
35+
36+
// NewClientInterceptor creates a client-side interceptor which injects token as a Bearer Authorization header
37+
func NewClientInterceptor(token string) connect.UnaryInterceptorFunc {
38+
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
39+
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
40+
41+
if !req.Spec().IsClient {
42+
return next(ctx, req)
43+
}
44+
45+
req.Header().Add(authorizationHeaderKey, bearerPrefix+token)
46+
return next(ctx, req)
47+
})
48+
}
49+
50+
return connect.UnaryInterceptorFunc(interceptor)
51+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License-AGPL.txt in the project root for license information.
4+
5+
package auth
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"net/http"
11+
"net/http/httptest"
12+
"testing"
13+
14+
"github.com/bufbuild/connect-go"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func TestNewServerInterceptor(t *testing.T) {
19+
requestPayload := "request"
20+
type TokenResponse struct {
21+
Token string `json:"token"`
22+
}
23+
24+
type Header struct {
25+
Key string
26+
Value string
27+
}
28+
29+
handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
30+
token := TokenFromContext(ctx)
31+
return connect.NewResponse(&TokenResponse{Token: token}), nil
32+
})
33+
34+
scenarios := []struct {
35+
Name string
36+
37+
Headers []Header
38+
39+
ExpectedError error
40+
ExpectedToken string
41+
}{
42+
{
43+
Name: "no headers return Unathenticated",
44+
Headers: nil,
45+
ExpectedError: connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("empty authorization header: %w", NoAccessToken)),
46+
},
47+
{
48+
Name: "authorization header with bearer token returns ok",
49+
Headers: []Header{{Key: "Authorization", Value: "Bearer foo"}},
50+
ExpectedToken: "foo",
51+
},
52+
}
53+
54+
for _, s := range scenarios {
55+
t.Run(s.Name, func(t *testing.T) {
56+
ctx := context.Background()
57+
request := connect.NewRequest(&requestPayload)
58+
59+
for _, header := range s.Headers {
60+
request.Header().Add(header.Key, header.Value)
61+
}
62+
63+
resp, err := NewServerInterceptor()(handler)(ctx, request)
64+
65+
require.Equal(t, s.ExpectedError, err)
66+
if err == nil {
67+
require.Equal(t, &TokenResponse{
68+
Token: s.ExpectedToken,
69+
}, resp.Any())
70+
}
71+
72+
})
73+
}
74+
}
75+
76+
func TestNewClientInterceptor(t *testing.T) {
77+
expectedToken := "my_token"
78+
79+
tokenOnRequest := ""
80+
// Setup a test server where we capture the token supplied, we don't actually care for the response.
81+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82+
fmt.Println(r.Header)
83+
token, err := BearerTokenFromHeaders(r.Header)
84+
require.NoError(t, err)
85+
86+
// Capture the token supplied in the request so we can test for it
87+
tokenOnRequest = token
88+
w.WriteHeader(http.StatusNotFound)
89+
}))
90+
91+
client := connect.NewClient[any, any](http.DefaultClient, srv.URL, connect.WithInterceptors(
92+
NewClientInterceptor(expectedToken),
93+
))
94+
95+
_, _ = client.CallUnary(context.Background(), connect.NewRequest[any](nil))
96+
require.Equal(t, expectedToken, tokenOnRequest)
97+
}

components/public-api-server/pkg/server/integration_test.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/bufbuild/connect-go"
1414
"github.com/gitpod-io/gitpod/common-go/baseserver"
15+
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
1516
v1 "github.com/gitpod-io/gitpod/public-api/v1"
1617
"github.com/gitpod-io/gitpod/public-api/v1/v1connect"
1718
"github.com/stretchr/testify/require"
@@ -113,7 +114,7 @@ func TestPublicAPIServer_WorkspaceServiceHandler(t *testing.T) {
113114
require.NoError(t, register(srv, gitpodAPI))
114115
baseserver.StartServerForTests(t, srv)
115116

116-
client := v1connect.NewWorkspacesServiceClient(http.DefaultClient, srv.HTTPAddress())
117+
client := v1connect.NewWorkspacesServiceClient(http.DefaultClient, srv.HTTPAddress(), connect.WithInterceptors(auth.NewClientInterceptor("token")))
117118

118119
_, err = client.ListWorkspaces(ctx, connect.NewRequest(&v1.ListWorkspacesRequest{}))
119120
require.Equal(t, connect.CodeUnimplemented.String(), connect.CodeOf(err).String())
@@ -158,3 +159,42 @@ func requireErrorStatusCode(t *testing.T, expected codes.Code, err error) {
158159
require.True(t, ok)
159160
require.Equalf(t, expected, st.Code(), "expected: %s but got: %s", expected.String(), st.String())
160161
}
162+
163+
func TestConnectWorkspaceService_RequiresAuth(t *testing.T) {
164+
srv := baseserver.NewForTests(t,
165+
baseserver.WithHTTP(baseserver.MustUseRandomLocalAddress(t)),
166+
baseserver.WithGRPC(baseserver.MustUseRandomLocalAddress(t)),
167+
)
168+
169+
gitpodAPI, err := url.Parse("wss://main.preview.gitpod-dev.com/api/v1")
170+
require.NoError(t, err)
171+
172+
require.NoError(t, register(srv, gitpodAPI))
173+
174+
baseserver.StartServerForTests(t, srv)
175+
176+
clientWithoutAuth := v1connect.NewWorkspacesServiceClient(http.DefaultClient, srv.HTTPAddress())
177+
_, err = clientWithoutAuth.GetWorkspace(context.Background(), connect.NewRequest(&v1.GetWorkspaceRequest{WorkspaceId: "123"}))
178+
require.Error(t, err)
179+
require.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
180+
181+
}
182+
183+
func TestConnectPrebuildsService_RequiresAuth(t *testing.T) {
184+
srv := baseserver.NewForTests(t,
185+
baseserver.WithHTTP(baseserver.MustUseRandomLocalAddress(t)),
186+
baseserver.WithGRPC(baseserver.MustUseRandomLocalAddress(t)),
187+
)
188+
189+
gitpodAPI, err := url.Parse("wss://main.preview.gitpod-dev.com/api/v1")
190+
require.NoError(t, err)
191+
192+
require.NoError(t, register(srv, gitpodAPI))
193+
194+
baseserver.StartServerForTests(t, srv)
195+
196+
clientWithoutAuth := v1connect.NewPrebuildsServiceClient(http.DefaultClient, srv.HTTPAddress())
197+
_, err = clientWithoutAuth.GetPrebuild(context.Background(), connect.NewRequest(&v1.GetPrebuildRequest{PrebuildId: "123"}))
198+
require.Error(t, err)
199+
require.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
200+
}

0 commit comments

Comments
 (0)