Skip to content

Commit 4226f51

Browse files
committed
fix
1 parent 3531e9d commit 4226f51

File tree

12 files changed

+127
-70
lines changed

12 files changed

+127
-70
lines changed

modules/session/mem.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2025 The Gitea Authors. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
package session
5+
6+
import (
7+
"bytes"
8+
"encoding/gob"
9+
"net/http"
10+
11+
"gitea.com/go-chi/session"
12+
)
13+
14+
type MemStore struct {
15+
s *session.MemStore
16+
}
17+
18+
var _ session.RawStore = (*MemStore)(nil)
19+
20+
func (m *MemStore) Set(k, v any) error {
21+
var buf bytes.Buffer
22+
if err := gob.NewEncoder(&buf).Encode(map[string]any{"v": v}); err != nil {
23+
return err
24+
}
25+
return m.s.Set(k, buf.Bytes())
26+
}
27+
28+
func (m *MemStore) Get(k any) (ret any) {
29+
v, ok := m.s.Get(k).([]byte)
30+
if !ok {
31+
return nil
32+
}
33+
var w map[string]any
34+
_ = gob.NewDecoder(bytes.NewBuffer(v)).Decode(&w)
35+
return w["v"]
36+
}
37+
38+
func (m *MemStore) Delete(k any) error {
39+
return m.s.Delete(k)
40+
}
41+
42+
func (m *MemStore) ID() string {
43+
return m.s.ID()
44+
}
45+
46+
func (m *MemStore) Release() error {
47+
return m.s.Release()
48+
}
49+
50+
func (m *MemStore) Flush() error {
51+
return m.s.Flush()
52+
}
53+
54+
type mockMemStore struct {
55+
*MemStore
56+
}
57+
58+
var _ Store = (*mockMemStore)(nil)
59+
60+
func (m mockMemStore) Destroy(writer http.ResponseWriter, request *http.Request) error {
61+
return nil
62+
}
63+
64+
func NewMockStore(sid string) Store {
65+
return &mockMemStore{&MemStore{session.NewMemStore(sid)}}
66+
}

modules/session/mock.go

Lines changed: 0 additions & 26 deletions
This file was deleted.

modules/session/store.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,34 @@ import (
1111
"gitea.com/go-chi/session"
1212
)
1313

14-
// Store represents a session store
14+
type RawStore = session.RawStore
15+
1516
type Store interface {
16-
Get(any) any
17-
Set(any, any) error
18-
Delete(any) error
19-
ID() string
20-
Release() error
21-
Flush() error
17+
RawStore
2218
Destroy(http.ResponseWriter, *http.Request) error
2319
}
2420

21+
type mockStoreContextKeyStruct struct{}
22+
23+
var MockStoreContextKey = mockStoreContextKeyStruct{}
24+
2525
// RegenerateSession regenerates the underlying session and returns the new store
2626
func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) {
2727
for _, f := range BeforeRegenerateSession {
2828
f(resp, req)
2929
}
3030
if setting.IsInTesting {
31-
if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok {
32-
return store, nil
31+
if store := req.Context().Value(MockStoreContextKey); store != nil {
32+
return store.(Store), nil
3333
}
3434
}
3535
return session.RegenerateSession(resp, req)
3636
}
3737

3838
func GetContextSession(req *http.Request) Store {
3939
if setting.IsInTesting {
40-
if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok {
41-
return store
40+
if store := req.Context().Value(MockStoreContextKey); store != nil {
41+
return store.(Store)
4242
}
4343
}
4444
return session.GetSession(req)

modules/session/virtual.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ type VirtualSessionProvider struct {
2222
provider session.Provider
2323
}
2424

25-
// Init initializes the cookie session provider with given root path.
26-
func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error {
25+
// Init initializes the cookie session provider with the given config.
26+
func (o *VirtualSessionProvider) Init(gcLifetime int64, config string) error {
2727
var opts session.Options
2828
if err := json.Unmarshal([]byte(config), &opts); err != nil {
2929
return err
@@ -52,7 +52,7 @@ func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error {
5252
default:
5353
return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider)
5454
}
55-
return o.provider.Init(gclifetime, opts.ProviderConfig)
55+
return o.provider.Init(gcLifetime, opts.ProviderConfig)
5656
}
5757

5858
// Read returns raw session store by session ID.

routers/web/auth/auth.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ func createUserInContext(ctx *context.Context, tpl templates.TplName, form any,
565565
oauth2LinkAccount(ctx, user, possibleLinkAccountData, true)
566566
return false // user is already created here, all redirects are handled
567567
case setting.OAuth2AccountLinkingLogin:
568-
showLinkingLogin(ctx, &possibleLinkAccountData.AuthSource, possibleLinkAccountData.GothUser)
568+
showLinkingLogin(ctx, possibleLinkAccountData.AuthSourceID, possibleLinkAccountData.GothUser)
569569
return false // user will be created only after linking login
570570
}
571571
}
@@ -633,7 +633,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, possibleLinkAcc
633633

634634
// update external user information
635635
if possibleLinkAccountData != nil {
636-
if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSource.ID, u, possibleLinkAccountData.GothUser); err != nil {
636+
if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSourceID, u, possibleLinkAccountData.GothUser); err != nil {
637637
log.Error("EnsureLinkExternalToUser failed: %v", err)
638638
}
639639
}

routers/web/auth/auth_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ func TestUserLogin(t *testing.T) {
6464
func TestSignUpOAuth2Login(t *testing.T) {
6565
defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)()
6666

67+
_ = oauth2.Init(t.Context())
6768
addOAuth2Source(t, "dummy-auth-source", oauth2.Source{})
6869

6970
t.Run("OAuth2MissingField", func(t *testing.T) {

routers/web/auth/linkaccount.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ func LinkAccountPostSignIn(ctx *context.Context) {
170170
}
171171

172172
func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData *LinkAccountData, remember bool) {
173-
oauth2SignInSync(ctx, &linkAccountData.AuthSource, u, linkAccountData.GothUser)
173+
oauth2SignInSync(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
174174
if ctx.Written() {
175175
return
176176
}
@@ -185,7 +185,7 @@ func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData
185185
return
186186
}
187187

188-
err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, u, linkAccountData.GothUser)
188+
err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
189189
if err != nil {
190190
ctx.ServerError("UserLinkAccount", err)
191191
return
@@ -295,7 +295,7 @@ func LinkAccountPostRegister(ctx *context.Context) {
295295
Email: form.Email,
296296
Passwd: form.Password,
297297
LoginType: auth.OAuth2,
298-
LoginSource: linkAccountData.AuthSource.ID,
298+
LoginSource: linkAccountData.AuthSourceID,
299299
LoginName: linkAccountData.GothUser.UserID,
300300
}
301301

@@ -304,7 +304,12 @@ func LinkAccountPostRegister(ctx *context.Context) {
304304
return
305305
}
306306

307-
source := linkAccountData.AuthSource.Cfg.(*oauth2.Source)
307+
authSource, err := auth.GetSourceByID(ctx, linkAccountData.AuthSourceID)
308+
if err != nil {
309+
ctx.ServerError("GetSourceByID", err)
310+
return
311+
}
312+
source := authSource.Cfg.(*oauth2.Source)
308313
if err := syncGroupsToTeams(ctx, source, &linkAccountData.GothUser, u); err != nil {
309314
ctx.ServerError("SyncGroupsToTeams", err)
310315
return
@@ -318,5 +323,5 @@ func linkAccountFromContext(ctx *context.Context, user *user_model.User) error {
318323
if linkAccountData == nil {
319324
return errors.New("not in LinkAccount session")
320325
}
321-
return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, user, linkAccountData.GothUser)
326+
return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, user, linkAccountData.GothUser)
322327
}

routers/web/auth/oauth.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package auth
55

66
import (
7+
"encoding/gob"
78
"errors"
89
"fmt"
910
"html"
@@ -171,7 +172,7 @@ func SignInOAuthCallback(ctx *context.Context) {
171172
gothUser.RawData = make(map[string]any)
172173
}
173174
gothUser.RawData["__giteaAutoRegMissingFields"] = missingFields
174-
showLinkingLogin(ctx, authSource, gothUser)
175+
showLinkingLogin(ctx, authSource.ID, gothUser)
175176
return
176177
}
177178
u = &user_model.User{
@@ -192,7 +193,7 @@ func SignInOAuthCallback(ctx *context.Context) {
192193
u.IsAdmin = isAdmin.ValueOrDefault(user_service.UpdateOptionField[bool]{FieldValue: false}).FieldValue
193194
u.IsRestricted = isRestricted.ValueOrDefault(setting.Service.DefaultUserIsRestricted)
194195

195-
linkAccountData := &LinkAccountData{*authSource, gothUser}
196+
linkAccountData := &LinkAccountData{authSource.ID, gothUser}
196197
if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingDisabled {
197198
linkAccountData = nil
198199
}
@@ -207,7 +208,7 @@ func SignInOAuthCallback(ctx *context.Context) {
207208
}
208209
} else {
209210
// no existing user is found, request attach or new account
210-
showLinkingLogin(ctx, authSource, gothUser)
211+
showLinkingLogin(ctx, authSource.ID, gothUser)
211212
return
212213
}
213214
}
@@ -272,23 +273,29 @@ func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *g
272273
}
273274

274275
type LinkAccountData struct {
275-
AuthSource auth.Source
276-
GothUser goth.User
276+
AuthSourceID int64
277+
GothUser goth.User
277278
}
278279

279280
func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData {
281+
gob.Register(LinkAccountData{})
280282
v, ok := ctx.Session.Get("linkAccountData").(LinkAccountData)
281283
if !ok {
282284
return nil
283285
}
284286
return &v
285287
}
286288

287-
func showLinkingLogin(ctx *context.Context, authSource *auth.Source, gothUser goth.User) {
288-
if err := updateSession(ctx, nil, map[string]any{
289-
"linkAccountData": LinkAccountData{*authSource, gothUser},
290-
}); err != nil {
291-
ctx.ServerError("updateSession", err)
289+
func Oauth2SetLinkAccountData(ctx *context.Context, linkAccountData LinkAccountData) error {
290+
gob.Register(LinkAccountData{})
291+
return updateSession(ctx, nil, map[string]any{
292+
"linkAccountData": linkAccountData,
293+
})
294+
}
295+
296+
func showLinkingLogin(ctx *context.Context, authSourceID int64, gothUser goth.User) {
297+
if err := Oauth2SetLinkAccountData(ctx, LinkAccountData{authSourceID, gothUser}); err != nil {
298+
ctx.ServerError("Oauth2SetLinkAccountData", err)
292299
return
293300
}
294301
ctx.Redirect(setting.AppSubURL + "/user/link_account")
@@ -313,7 +320,7 @@ func oauth2UpdateAvatarIfNeed(ctx *context.Context, url string, u *user_model.Us
313320
}
314321

315322
func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
316-
oauth2SignInSync(ctx, authSource, u, gothUser)
323+
oauth2SignInSync(ctx, authSource.ID, u, gothUser)
317324
if ctx.Written() {
318325
return
319326
}

routers/web/auth/oauth_signin_sync.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@ import (
1818
"github.com/markbates/goth"
1919
)
2020

21-
func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
21+
func oauth2SignInSync(ctx *context.Context, authSourceID int64, u *user_model.User, gothUser goth.User) {
2222
oauth2UpdateAvatarIfNeed(ctx, gothUser.AvatarURL, u)
2323

24+
authSource, err := auth.GetSourceByID(ctx, authSourceID)
25+
if err != nil {
26+
ctx.ServerError("GetSourceByID", err)
27+
return
28+
}
2429
oauth2Source, _ := authSource.Cfg.(*oauth2.Source)
2530
if !authSource.IsOAuth2() || oauth2Source == nil {
2631
ctx.ServerError("oauth2SignInSync", fmt.Errorf("source %s is not an OAuth2 source", gothUser.Provider))
@@ -45,7 +50,7 @@ func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_mod
4550
}
4651
}
4752

48-
err := oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
53+
err = oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
4954
if err != nil {
5055
log.Error("Unable to sync OAuth2 SSH public key %s: %v", gothUser.Provider, err)
5156
}

services/auth/source/oauth2/store.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"code.gitea.io/gitea/modules/log"
1212
session_module "code.gitea.io/gitea/modules/session"
1313

14-
chiSession "gitea.com/go-chi/session"
1514
"github.com/gorilla/sessions"
1615
)
1716

@@ -35,11 +34,11 @@ func (st *SessionsStore) New(r *http.Request, name string) (*sessions.Session, e
3534

3635
// getOrNew gets the session from the chi-session if it exists. Override permits the overriding of an unexpected object.
3736
func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) (*sessions.Session, error) {
38-
chiStore := chiSession.GetSession(r)
37+
store := session_module.GetContextSession(r)
3938

4039
session := sessions.NewSession(st, name)
4140

42-
rawData := chiStore.Get(name)
41+
rawData := store.Get(name)
4342
if rawData != nil {
4443
oldSession, ok := rawData.(*sessions.Session)
4544
if ok {
@@ -56,21 +55,21 @@ func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) (
5655
}
5756

5857
session.IsNew = override
59-
session.ID = chiStore.ID() // Simply copy the session id from the chi store
58+
session.ID = store.ID() // Simply copy the session id from the chi store
6059

61-
return session, chiStore.Set(name, session)
60+
return session, store.Set(name, session)
6261
}
6362

6463
// Save should persist session to the underlying store implementation.
6564
func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
66-
chiStore := chiSession.GetSession(r)
65+
store := session_module.GetContextSession(r)
6766

6867
if session.IsNew {
6968
_, _ = session_module.RegenerateSession(w, r)
7069
session.IsNew = false
7170
}
7271

73-
if err := chiStore.Set(session.Name(), session); err != nil {
72+
if err := store.Set(session.Name(), session); err != nil {
7473
return err
7574
}
7675

@@ -83,7 +82,7 @@ func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *s
8382
}
8483
}
8584

86-
return chiStore.Release()
85+
return store.Release()
8786
}
8887

8988
type sizeWriter struct {

0 commit comments

Comments
 (0)