diff --git a/modules/context/api.go b/modules/context/api.go index b9d130e2a8ac0..517d0cf1dd64c 100644 --- a/modules/context/api.go +++ b/modules/context/api.go @@ -8,6 +8,7 @@ package context import ( "context" "fmt" + "html" "net/http" "net/url" "strings" @@ -21,6 +22,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/web/middleware" auth_service "code.gitea.io/gitea/services/auth" + "gitea.com/go-chi/session" ) // APIContext is a specific context for API service @@ -190,6 +192,17 @@ func (ctx *APIContext) SetLinkHeader(total, pageSize int) { } } +// RequireCSRF requires a validated a CSRF token +func (ctx *APIContext) RequireCSRF() { + headerToken := ctx.Req.Header.Get(ctx.csrf.GetHeaderName()) + formValueToken := ctx.Req.FormValue(ctx.csrf.GetFormName()) + if len(headerToken) > 0 || len(formValueToken) > 0 { + ctx.csrf.Validate(ctx.Context) + } else { + ctx.Context.Error(http.StatusUnauthorized, "Missing CSRF token.") + } +} + // CheckForOTP validates OTP func (ctx *APIContext) CheckForOTP() { if skip, ok := ctx.Data["SkipLocalTwoFA"]; ok && skip.(bool) { @@ -241,15 +254,18 @@ func APIAuth(authMethod auth_service.Method) func(*APIContext) { // APIContexter returns apicontext as middleware func APIContexter() func(http.Handler) http.Handler { + csrfOpts := getCsrfOpts() + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { locale := middleware.Locale(w, req) ctx := APIContext{ Context: &Context{ - Resp: NewResponse(w), - Data: map[string]interface{}{}, - Locale: locale, - Cache: cache.GetCache(), + Resp: NewResponse(w), + Data: map[string]interface{}{}, + Locale: locale, + Session: session.GetSession(req), + Cache: cache.GetCache(), Repo: &Repository{ PullRequest: &PullRequest{}, }, @@ -260,6 +276,7 @@ func APIContexter() func(http.Handler) http.Handler { defer ctx.Close() ctx.Req = WithAPIContext(WithContext(req, ctx.Context), &ctx) + ctx.csrf = PrepareCSRFProtector(csrfOpts, ctx.Context) // If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid. if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") { @@ -272,6 +289,7 @@ func APIContexter() func(http.Handler) http.Handler { httpcache.AddCacheControlToHeader(ctx.Resp.Header(), 0, "no-transform") ctx.Resp.Header().Set(`X-Frame-Options`, setting.CORSConfig.XFrameOptions) + ctx.Data["CsrfToken"] = html.EscapeString(ctx.csrf.GetToken()) ctx.Data["Context"] = &ctx next.ServeHTTP(ctx.Resp, ctx.Req) diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index e1478fa2aa99a..2b8d1c4b44434 100644 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -217,6 +217,7 @@ func reqToken() func(ctx *context.APIContext) { return } if ctx.IsSigned { + ctx.RequireCSRF() return } ctx.Error(http.StatusUnauthorized, "reqToken", "token is required") @@ -595,6 +596,7 @@ func buildAuthGroup() *auth.Group { &auth.OAuth2{}, &auth.HTTPSign{}, &auth.Basic{}, // FIXME: this should be removed once we don't allow basic auth in API + &auth.Session{}, ) if setting.Service.EnableReverseProxyAuth { group.Add(&auth.ReverseProxy{}) @@ -605,10 +607,12 @@ func buildAuthGroup() *auth.Group { } // Routes registers all v1 APIs routes to web application. -func Routes() *web.Route { +func Routes(sessioner func(http.Handler) http.Handler) *web.Route { m := web.NewRoute() m.Use(securityHeaders()) + m.Use(sessioner) + if setting.CORSConfig.Enabled { m.Use(cors.Handler(cors.Options{ // Scheme: setting.CORSConfig.Scheme, // FIXME: the cors middleware needs scheme option diff --git a/routers/init.go b/routers/init.go index e640ca48453bc..082ccb1841b2d 100644 --- a/routers/init.go +++ b/routers/init.go @@ -47,6 +47,7 @@ import ( "code.gitea.io/gitea/services/repository/archiver" "code.gitea.io/gitea/services/task" "code.gitea.io/gitea/services/webhook" + "gitea.com/go-chi/session" ) func mustInit(fn func() error) { @@ -171,8 +172,20 @@ func NormalRoutes() *web.Route { r.Use(middle) } - r.Mount("/", web_routers.Routes()) - r.Mount("/api/v1", apiv1.Routes()) + sessioner := session.Sessioner(session.Options{ + Provider: setting.SessionConfig.Provider, + ProviderConfig: setting.SessionConfig.ProviderConfig, + CookieName: setting.SessionConfig.CookieName, + CookiePath: setting.SessionConfig.CookiePath, + Gclifetime: setting.SessionConfig.Gclifetime, + Maxlifetime: setting.SessionConfig.Maxlifetime, + Secure: setting.SessionConfig.Secure, + SameSite: setting.SessionConfig.SameSite, + Domain: setting.SessionConfig.Domain, + }) + + r.Mount("/", web_routers.Routes(sessioner)) + r.Mount("/api/v1", apiv1.Routes(sessioner)) r.Mount("/api/internal", private.Routes()) if setting.Packages.Enabled { r.Mount("/api/packages", packages_router.Routes()) diff --git a/routers/web/web.go b/routers/web/web.go index a9f43fb2c4feb..18871fd994714 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -48,7 +48,6 @@ import ( _ "code.gitea.io/gitea/modules/session" // to registers all internal adapters "gitea.com/go-chi/captcha" - "gitea.com/go-chi/session" "github.com/NYTimes/gziphandler" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" @@ -99,7 +98,7 @@ func buildAuthGroup() *auth_service.Group { } // Routes returns all web routes -func Routes() *web.Route { +func Routes(sessioner func(http.Handler) http.Handler) *web.Route { routes := web.NewRoute() routes.Use(web.WrapWithPrefix(public.AssetsURLPathPrefix, public.AssetsHandlerFunc(&public.Options{ @@ -108,17 +107,6 @@ func Routes() *web.Route { CorsHandler: CorsHandler(), }), "AssetsHandler")) - sessioner := session.Sessioner(session.Options{ - Provider: setting.SessionConfig.Provider, - ProviderConfig: setting.SessionConfig.ProviderConfig, - CookieName: setting.SessionConfig.CookieName, - CookiePath: setting.SessionConfig.CookiePath, - Gclifetime: setting.SessionConfig.Gclifetime, - Maxlifetime: setting.SessionConfig.Maxlifetime, - Secure: setting.SessionConfig.Secure, - SameSite: setting.SessionConfig.SameSite, - Domain: setting.SessionConfig.Domain, - }) routes.Use(sessioner) routes.Use(Recovery())