Skip to content

Refactor request context #32956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 18 additions & 48 deletions modules/gitrepo/gitrepo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"

"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
)
Expand Down Expand Up @@ -38,63 +39,32 @@ func OpenWikiRepository(ctx context.Context, repo Repository) (*git.Repository,

// contextKey is a value for use with context.WithValue.
type contextKey struct {
name string
}

// RepositoryContextKey is a context key. It is used with context.Value() to get the current Repository for the context
var RepositoryContextKey = &contextKey{"repository"}

// RepositoryFromContext attempts to get the repository from the context
func repositoryFromContext(ctx context.Context, repo Repository) *git.Repository {
value := ctx.Value(RepositoryContextKey)
if value == nil {
return nil
}

if gitRepo, ok := value.(*git.Repository); ok && gitRepo != nil {
if gitRepo.Path == repoPath(repo) {
return gitRepo
}
}

return nil
repoPath string
}

// RepositoryFromContextOrOpen attempts to get the repository from the context or just opens it
func RepositoryFromContextOrOpen(ctx context.Context, repo Repository) (*git.Repository, io.Closer, error) {
gitRepo := repositoryFromContext(ctx, repo)
if gitRepo != nil {
return gitRepo, util.NopCloser{}, nil
ds := reqctx.GetRequestDataStore(ctx)
if ds != nil {
gitRepo, err := RepositoryFromRequestContextOrOpen(ctx, ds, repo)
return gitRepo, util.NopCloser{}, err
}

gitRepo, err := OpenRepository(ctx, repo)
return gitRepo, gitRepo, err
}

// repositoryFromContextPath attempts to get the repository from the context
func repositoryFromContextPath(ctx context.Context, path string) *git.Repository {
value := ctx.Value(RepositoryContextKey)
if value == nil {
return nil
// RepositoryFromRequestContextOrOpen opens the repository at the given relative path in the provided request context
// The repo will be automatically closed when the request context is done
func RepositoryFromRequestContextOrOpen(ctx context.Context, ds reqctx.RequestDataStore, repo Repository) (*git.Repository, error) {
ck := contextKey{repoPath: repoPath(repo)}
if gitRepo, ok := ctx.Value(ck).(*git.Repository); ok {
return gitRepo, nil
}

if repo, ok := value.(*git.Repository); ok && repo != nil {
if repo.Path == path {
return repo
}
gitRepo, err := git.OpenRepository(ctx, ck.repoPath)
if err != nil {
return nil, err
}

return nil
}

// RepositoryFromContextOrOpenPath attempts to get the repository from the context or just opens it
// Deprecated: Use RepositoryFromContextOrOpen instead
func RepositoryFromContextOrOpenPath(ctx context.Context, path string) (*git.Repository, io.Closer, error) {
gitRepo := repositoryFromContextPath(ctx, path)
if gitRepo != nil {
return gitRepo, util.NopCloser{}, nil
}

gitRepo, err := git.OpenRepository(ctx, path)
return gitRepo, gitRepo, err
ds.AddCloser(gitRepo)
ds.SetContextValue(ck, gitRepo)
return gitRepo, nil
}
12 changes: 4 additions & 8 deletions modules/gitrepo/walk_gogit.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@ import (
// WalkReferences walks all the references from the repository
// refname is empty, ObjectTag or ObjectBranch. All other values should be treated as equivalent to empty.
func WalkReferences(ctx context.Context, repo Repository, walkfn func(sha1, refname string) error) (int, error) {
gitRepo := repositoryFromContext(ctx, repo)
if gitRepo == nil {
var err error
gitRepo, err = OpenRepository(ctx, repo)
if err != nil {
return 0, err
}
defer gitRepo.Close()
gitRepo, closer, err := RepositoryFromContextOrOpen(ctx, repo)
if err != nil {
return 0, err
}
defer closer.Close()

i := 0
iter, err := gitRepo.GoGitRepo().References()
Expand Down
123 changes: 123 additions & 0 deletions modules/reqctx/datastore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package reqctx

import (
"context"
"io"
"sync"

"code.gitea.io/gitea/modules/process"
)

type ContextDataProvider interface {
GetData() ContextData
}

type ContextData map[string]any

func (ds ContextData) GetData() ContextData {
return ds
}

func (ds ContextData) MergeFrom(other ContextData) ContextData {
for k, v := range other {
ds[k] = v
}
return ds
}

// RequestDataStore is a short-lived context-related object that is used to store request-specific data.
type RequestDataStore interface {
GetData() ContextData
SetContextValue(k, v any)
GetContextValue(key any) any
AddCleanUp(f func())
AddCloser(c io.Closer)
}

type requestDataStoreKeyType struct{}

var RequestDataStoreKey requestDataStoreKeyType

type requestDataStore struct {
data ContextData

mu sync.RWMutex
values map[any]any
cleanUpFuncs []func()
}

func (r *requestDataStore) GetContextValue(key any) any {
if key == RequestDataStoreKey {
return r
}
r.mu.RLock()
defer r.mu.RUnlock()
return r.values[key]
}

func (r *requestDataStore) SetContextValue(k, v any) {
r.mu.Lock()
r.values[k] = v
r.mu.Unlock()
}

// GetData and the underlying ContextData are not thread-safe, callers should ensure thread-safety.
func (r *requestDataStore) GetData() ContextData {
if r.data == nil {
r.data = make(ContextData)
}
return r.data
}

func (r *requestDataStore) AddCleanUp(f func()) {
r.mu.Lock()
r.cleanUpFuncs = append(r.cleanUpFuncs, f)
r.mu.Unlock()
}

func (r *requestDataStore) AddCloser(c io.Closer) {
r.AddCleanUp(func() { _ = c.Close() })
}

func (r *requestDataStore) cleanUp() {
for _, f := range r.cleanUpFuncs {
f()
}
}

func GetRequestDataStore(ctx context.Context) RequestDataStore {
if req, ok := ctx.Value(RequestDataStoreKey).(*requestDataStore); ok {
return req
}
return nil
}

type requestContext struct {
context.Context
dataStore *requestDataStore
}

func (c *requestContext) Value(key any) any {
if v := c.dataStore.GetContextValue(key); v != nil {
return v
}
return c.Context.Value(key)
}

func NewRequestContext(parentCtx context.Context, profDesc string) (_ context.Context, finished func()) {
ctx, _, processFinished := process.GetManager().AddTypedContext(parentCtx, profDesc, process.RequestProcessType, true)
reqCtx := &requestContext{Context: ctx, dataStore: &requestDataStore{values: make(map[any]any)}}
return reqCtx, func() {
reqCtx.dataStore.cleanUp()
processFinished()
}
}

// NewRequestContextForTest creates a new RequestContext for testing purposes
// It doesn't add the context to the process manager, nor do cleanup
func NewRequestContextForTest(parentCtx context.Context) context.Context {
return &requestContext{Context: parentCtx, dataStore: &requestDataStore{values: make(map[any]any)}}
}
26 changes: 6 additions & 20 deletions modules/web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package web

import (
goctx "context"
"fmt"
"net/http"
"reflect"
Expand Down Expand Up @@ -51,7 +50,6 @@ func (r *responseWriter) WriteHeader(statusCode int) {
var (
httpReqType = reflect.TypeOf((*http.Request)(nil))
respWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
cancelFuncType = reflect.TypeOf((*goctx.CancelFunc)(nil)).Elem()
)

// preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup
Expand All @@ -65,11 +63,8 @@ func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
if !hasStatusProvider {
panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type()))
}
if fn.Type().NumOut() != 0 && fn.Type().NumIn() != 1 {
panic(fmt.Sprintf("handler should have no return value or only one argument, but got %s", fn.Type()))
}
if fn.Type().NumOut() == 1 && fn.Type().Out(0) != cancelFuncType {
panic(fmt.Sprintf("handler should return a cancel function, but got %s", fn.Type()))
if fn.Type().NumOut() != 0 {
panic(fmt.Sprintf("handler should have no return value other than registered ones, but got %s", fn.Type()))
}
}

Expand Down Expand Up @@ -105,16 +100,10 @@ func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect
return argsIn
}

func handleResponse(fn reflect.Value, ret []reflect.Value) goctx.CancelFunc {
if len(ret) == 1 {
if cancelFunc, ok := ret[0].Interface().(goctx.CancelFunc); ok {
return cancelFunc
}
panic(fmt.Sprintf("unsupported return type: %s", ret[0].Type()))
} else if len(ret) > 1 {
func handleResponse(fn reflect.Value, ret []reflect.Value) {
if len(ret) != 0 {
panic(fmt.Sprintf("unsupported return values: %s", fn.Type()))
}
return nil
}

func hasResponseBeenWritten(argsIn []reflect.Value) bool {
Expand Down Expand Up @@ -171,11 +160,8 @@ func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ret := fn.Call(argsIn)

// handle the return value, and defer the cancel function if there is one
cancelFunc := handleResponse(fn, ret)
if cancelFunc != nil {
defer cancelFunc()
}
// handle the return value (no-op at the moment)
handleResponse(fn, ret)

// if the response has not been written, call the next handler
if next != nil && !hasResponseBeenWritten(argsIn) {
Expand Down
37 changes: 6 additions & 31 deletions modules/web/middleware/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,21 @@ import (
"context"
"time"

"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/setting"
)

// ContextDataStore represents a data store
type ContextDataStore interface {
GetData() ContextData
}

type ContextData map[string]any

func (ds ContextData) GetData() ContextData {
return ds
}

func (ds ContextData) MergeFrom(other ContextData) ContextData {
for k, v := range other {
ds[k] = v
}
return ds
}

const ContextDataKeySignedUser = "SignedUser"

type contextDataKeyType struct{}

var contextDataKey contextDataKeyType

func WithContextData(c context.Context) context.Context {
return context.WithValue(c, contextDataKey, make(ContextData, 10))
}

func GetContextData(c context.Context) ContextData {
if ds, ok := c.Value(contextDataKey).(ContextData); ok {
return ds
func GetContextData(c context.Context) reqctx.ContextData {
if rc := reqctx.GetRequestDataStore(c); rc != nil {
return rc.GetData()
}
return nil
}

func CommonTemplateContextData() ContextData {
return ContextData{
func CommonTemplateContextData() reqctx.ContextData {
return reqctx.ContextData{
"IsLandingPageOrganizations": setting.LandingPageURL == setting.LandingPageOrganizations,

"ShowRegistrationButton": setting.Service.ShowRegistrationButton,
Expand Down
4 changes: 3 additions & 1 deletion modules/web/middleware/flash.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"fmt"
"html/template"
"net/url"

"code.gitea.io/gitea/modules/reqctx"
)

// Flash represents a one time data transfer between two requests.
type Flash struct {
DataStore ContextDataStore
DataStore reqctx.RequestDataStore
url.Values
ErrorMsg, WarningMsg, InfoMsg, SuccessMsg string
}
Expand Down
5 changes: 3 additions & 2 deletions modules/web/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"

"code.gitea.io/gitea/modules/htmlutil"
"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/middleware"

Expand All @@ -29,12 +30,12 @@ func Bind[T any](_ T) http.HandlerFunc {
}

// SetForm set the form object
func SetForm(dataStore middleware.ContextDataStore, obj any) {
func SetForm(dataStore reqctx.ContextDataProvider, obj any) {
dataStore.GetData()["__form"] = obj
}

// GetForm returns the validate form information
func GetForm(dataStore middleware.ContextDataStore) any {
func GetForm(dataStore reqctx.RequestDataStore) any {
return dataStore.GetData()["__form"]
}

Expand Down
Loading
Loading