Skip to content

Fix session gob #35128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions modules/session/mem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2025 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package session

import (
"bytes"
"encoding/gob"
"net/http"

"gitea.com/go-chi/session"
)

type mockMemRawStore struct {
s *session.MemStore
}

var _ session.RawStore = (*mockMemRawStore)(nil)

func (m *mockMemRawStore) Set(k, v any) error {
// We need to use gob to encode the value, to make it have the same behavior as other stores and catch abuses.
// Because gob needs to "Register" the type before it can encode it, and it's unable to decode a struct to "any" so use a map to help to decode the value.
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(map[string]any{"v": v}); err != nil {
return err
}
return m.s.Set(k, buf.Bytes())
}

func (m *mockMemRawStore) Get(k any) (ret any) {
v, ok := m.s.Get(k).([]byte)
if !ok {
return nil
}
var w map[string]any
_ = gob.NewDecoder(bytes.NewBuffer(v)).Decode(&w)
return w["v"]
}

func (m *mockMemRawStore) Delete(k any) error {
return m.s.Delete(k)
}

func (m *mockMemRawStore) ID() string {
return m.s.ID()
}

func (m *mockMemRawStore) Release() error {
return m.s.Release()
}

func (m *mockMemRawStore) Flush() error {
return m.s.Flush()
}

type mockMemStore struct {
*mockMemRawStore
}

var _ Store = (*mockMemStore)(nil)

func (m mockMemStore) Destroy(writer http.ResponseWriter, request *http.Request) error {
return nil
}

func NewMockMemStore(sid string) Store {
return &mockMemStore{&mockMemRawStore{session.NewMemStore(sid)}}
}
26 changes: 0 additions & 26 deletions modules/session/mock.go

This file was deleted.

22 changes: 11 additions & 11 deletions modules/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,34 @@ import (
"gitea.com/go-chi/session"
)

// Store represents a session store
type RawStore = session.RawStore

type Store interface {
Get(any) any
Set(any, any) error
Delete(any) error
ID() string
Release() error
Flush() error
RawStore
Destroy(http.ResponseWriter, *http.Request) error
}

type mockStoreContextKeyStruct struct{}

var MockStoreContextKey = mockStoreContextKeyStruct{}

// RegenerateSession regenerates the underlying session and returns the new store
func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) {
for _, f := range BeforeRegenerateSession {
f(resp, req)
}
if setting.IsInTesting {
if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok {
return store, nil
if store := req.Context().Value(MockStoreContextKey); store != nil {
return store.(Store), nil
}
}
return session.RegenerateSession(resp, req)
}

func GetContextSession(req *http.Request) Store {
if setting.IsInTesting {
if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok {
return store
if store := req.Context().Value(MockStoreContextKey); store != nil {
return store.(Store)
}
}
return session.GetSession(req)
Expand Down
6 changes: 3 additions & 3 deletions modules/session/virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ type VirtualSessionProvider struct {
provider session.Provider
}

// Init initializes the cookie session provider with given root path.
func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error {
// Init initializes the cookie session provider with the given config.
func (o *VirtualSessionProvider) Init(gcLifetime int64, config string) error {
var opts session.Options
if err := json.Unmarshal([]byte(config), &opts); err != nil {
return err
Expand Down Expand Up @@ -52,7 +52,7 @@ func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error {
default:
return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider)
}
return o.provider.Init(gclifetime, opts.ProviderConfig)
return o.provider.Init(gcLifetime, opts.ProviderConfig)
}

// Read returns raw session store by session ID.
Expand Down
4 changes: 2 additions & 2 deletions routers/web/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ func createUserInContext(ctx *context.Context, tpl templates.TplName, form any,
oauth2LinkAccount(ctx, user, possibleLinkAccountData, true)
return false // user is already created here, all redirects are handled
case setting.OAuth2AccountLinkingLogin:
showLinkingLogin(ctx, &possibleLinkAccountData.AuthSource, possibleLinkAccountData.GothUser)
showLinkingLogin(ctx, possibleLinkAccountData.AuthSourceID, possibleLinkAccountData.GothUser)
return false // user will be created only after linking login
}
}
Expand Down Expand Up @@ -633,7 +633,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, possibleLinkAcc

// update external user information
if possibleLinkAccountData != nil {
if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSource.ID, u, possibleLinkAccountData.GothUser); err != nil {
if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSourceID, u, possibleLinkAccountData.GothUser); err != nil {
log.Error("EnsureLinkExternalToUser failed: %v", err)
}
}
Expand Down
5 changes: 3 additions & 2 deletions routers/web/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,14 @@ func TestUserLogin(t *testing.T) {
func TestSignUpOAuth2Login(t *testing.T) {
defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)()

_ = oauth2.Init(t.Context())
addOAuth2Source(t, "dummy-auth-source", oauth2.Source{})

t.Run("OAuth2MissingField", func(t *testing.T) {
defer test.MockVariableValue(&gothic.CompleteUserAuth, func(res http.ResponseWriter, req *http.Request) (goth.User, error) {
return goth.User{Provider: "dummy-auth-source", UserID: "dummy-user"}, nil
})()
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")}
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")}
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback?code=dummy-code", mockOpt)
ctx.SetPathParam("provider", "dummy-auth-source")
SignInOAuthCallback(ctx)
Expand All @@ -84,7 +85,7 @@ func TestSignUpOAuth2Login(t *testing.T) {
})

t.Run("OAuth2CallbackError", func(t *testing.T) {
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")}
mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")}
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback", mockOpt)
ctx.SetPathParam("provider", "dummy-auth-source")
SignInOAuthCallback(ctx)
Expand Down
15 changes: 10 additions & 5 deletions routers/web/auth/linkaccount.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func LinkAccountPostSignIn(ctx *context.Context) {
}

func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData *LinkAccountData, remember bool) {
oauth2SignInSync(ctx, &linkAccountData.AuthSource, u, linkAccountData.GothUser)
oauth2SignInSync(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
if ctx.Written() {
return
}
Expand All @@ -185,7 +185,7 @@ func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData
return
}

err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, u, linkAccountData.GothUser)
err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
if err != nil {
ctx.ServerError("UserLinkAccount", err)
return
Expand Down Expand Up @@ -295,7 +295,7 @@ func LinkAccountPostRegister(ctx *context.Context) {
Email: form.Email,
Passwd: form.Password,
LoginType: auth.OAuth2,
LoginSource: linkAccountData.AuthSource.ID,
LoginSource: linkAccountData.AuthSourceID,
LoginName: linkAccountData.GothUser.UserID,
}

Expand All @@ -304,7 +304,12 @@ func LinkAccountPostRegister(ctx *context.Context) {
return
}

source := linkAccountData.AuthSource.Cfg.(*oauth2.Source)
authSource, err := auth.GetSourceByID(ctx, linkAccountData.AuthSourceID)
if err != nil {
ctx.ServerError("GetSourceByID", err)
return
}
source := authSource.Cfg.(*oauth2.Source)
if err := syncGroupsToTeams(ctx, source, &linkAccountData.GothUser, u); err != nil {
ctx.ServerError("SyncGroupsToTeams", err)
return
Expand All @@ -318,5 +323,5 @@ func linkAccountFromContext(ctx *context.Context, user *user_model.User) error {
if linkAccountData == nil {
return errors.New("not in LinkAccount session")
}
return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, user, linkAccountData.GothUser)
return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, user, linkAccountData.GothUser)
}
29 changes: 18 additions & 11 deletions routers/web/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package auth

import (
"encoding/gob"
"errors"
"fmt"
"html"
Expand Down Expand Up @@ -171,7 +172,7 @@ func SignInOAuthCallback(ctx *context.Context) {
gothUser.RawData = make(map[string]any)
}
gothUser.RawData["__giteaAutoRegMissingFields"] = missingFields
showLinkingLogin(ctx, authSource, gothUser)
showLinkingLogin(ctx, authSource.ID, gothUser)
return
}
u = &user_model.User{
Expand All @@ -192,7 +193,7 @@ func SignInOAuthCallback(ctx *context.Context) {
u.IsAdmin = isAdmin.ValueOrDefault(user_service.UpdateOptionField[bool]{FieldValue: false}).FieldValue
u.IsRestricted = isRestricted.ValueOrDefault(setting.Service.DefaultUserIsRestricted)

linkAccountData := &LinkAccountData{*authSource, gothUser}
linkAccountData := &LinkAccountData{authSource.ID, gothUser}
if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingDisabled {
linkAccountData = nil
}
Expand All @@ -207,7 +208,7 @@ func SignInOAuthCallback(ctx *context.Context) {
}
} else {
// no existing user is found, request attach or new account
showLinkingLogin(ctx, authSource, gothUser)
showLinkingLogin(ctx, authSource.ID, gothUser)
return
}
}
Expand Down Expand Up @@ -272,23 +273,29 @@ func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *g
}

type LinkAccountData struct {
AuthSource auth.Source
GothUser goth.User
AuthSourceID int64
GothUser goth.User
}

func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData {
gob.Register(LinkAccountData{})
v, ok := ctx.Session.Get("linkAccountData").(LinkAccountData)
if !ok {
return nil
}
return &v
}

func showLinkingLogin(ctx *context.Context, authSource *auth.Source, gothUser goth.User) {
if err := updateSession(ctx, nil, map[string]any{
"linkAccountData": LinkAccountData{*authSource, gothUser},
}); err != nil {
ctx.ServerError("updateSession", err)
func Oauth2SetLinkAccountData(ctx *context.Context, linkAccountData LinkAccountData) error {
gob.Register(LinkAccountData{})
return updateSession(ctx, nil, map[string]any{
"linkAccountData": linkAccountData,
})
}

func showLinkingLogin(ctx *context.Context, authSourceID int64, gothUser goth.User) {
if err := Oauth2SetLinkAccountData(ctx, LinkAccountData{authSourceID, gothUser}); err != nil {
ctx.ServerError("Oauth2SetLinkAccountData", err)
return
}
ctx.Redirect(setting.AppSubURL + "/user/link_account")
Expand All @@ -313,7 +320,7 @@ func oauth2UpdateAvatarIfNeed(ctx *context.Context, url string, u *user_model.Us
}

func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
oauth2SignInSync(ctx, authSource, u, gothUser)
oauth2SignInSync(ctx, authSource.ID, u, gothUser)
if ctx.Written() {
return
}
Expand Down
9 changes: 7 additions & 2 deletions routers/web/auth/oauth_signin_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ import (
"github.com/markbates/goth"
)

func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
func oauth2SignInSync(ctx *context.Context, authSourceID int64, u *user_model.User, gothUser goth.User) {
oauth2UpdateAvatarIfNeed(ctx, gothUser.AvatarURL, u)

authSource, err := auth.GetSourceByID(ctx, authSourceID)
if err != nil {
ctx.ServerError("GetSourceByID", err)
return
}
oauth2Source, _ := authSource.Cfg.(*oauth2.Source)
if !authSource.IsOAuth2() || oauth2Source == nil {
ctx.ServerError("oauth2SignInSync", fmt.Errorf("source %s is not an OAuth2 source", gothUser.Provider))
Expand All @@ -45,7 +50,7 @@ func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_mod
}
}

err := oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
err = oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
if err != nil {
log.Error("Unable to sync OAuth2 SSH public key %s: %v", gothUser.Provider, err)
}
Expand Down
Loading