diff --git a/models/oauth.go b/models/oauth.go new file mode 100644 index 0000000000000..23706e3e06473 --- /dev/null +++ b/models/oauth.go @@ -0,0 +1,125 @@ +package models + +import ( + "fmt" + "strings" + + "code.gitea.io/gitea/models/db" +) + +// OAuth Login Source +type OAuth struct { + ID int64 `xorm:"pk autoincr"` + Name string +} + +func init() { + db.RegisterModel(new(OAuth)) +} + +func getOAuthByID(e db.Engine, id int64) (*OAuth, error) { + o := new(OAuth) + has, err := e.ID(id).Get(o) + if err != nil { + return nil, err + } else if !has { + return nil, ErrUserNotExist{id, "", 0} + } + return o, nil +} + +// GetOAuthByID returns the oauth object by given ID if exists. +func GetOAuthByID(id int64) (*OAuth, error) { + return getOAuthByID(db.GetEngine(db.DefaultContext), id) +} + +// GetOAuthByName returns oauth by given name. +func GetOAuthByName(name string) (*OAuth, error) { + return getOAuthByName(db.GetEngine(db.DefaultContext), name) +} + +func getOAuthByName(e db.Engine, name string) (*OAuth, error) { + if len(name) == 0 { + return nil, ErrUserNotExist{0, name, 0} + } + o := &OAuth{Name: strings.ToLower(name)} + has, err := e.Get(o) + if err != nil { + return nil, err + } else if !has { + return nil, ErrUserNotExist{0, name, 0} + } + return o, nil +} + +// CreateOAuth creates record of a new oauth. +func CreateOAuth(o *OAuth) (err error) { + sess := db.NewSession(db.DefaultContext) + defer sess.Close() + if err = sess.Begin(); err != nil { + return err + } + + if _, err = sess.Insert(o); err != nil { + return err + } + + return sess.Commit() +} + +func validateOAuth(o *OAuth) error { + + return nil +} + +func updateOAuth(e db.Engine, o *OAuth) error { + if err := validateOAuth(o); err != nil { + return err + } + + _, err := e.ID(o.ID).AllCols().Update(o) + return err +} + +// UpdateOAuth updates oauth's information. +func UpdateOAuth(o *OAuth) error { + return updateOAuth(db.GetEngine(db.DefaultContext), o) +} + +// UpdateOAuthCols update user according special columns +func UpdateOAuthCols(o *OAuth, cols ...string) error { + return updateOAuthCols(db.GetEngine(db.DefaultContext), o, cols...) +} + +func updateOAuthCols(e db.Engine, o *OAuth, cols ...string) error { + if err := validateOAuth(o); err != nil { + return err + } + + _, err := e.ID(o.ID).Cols(cols...).Update(o) + return err +} + +func deleteOAuth(e db.Engine, o *OAuth) error { + if _, err := e.ID(o.ID).Delete(new(OAuth)); err != nil { + return fmt.Errorf("Delete: %v", err) + } + + return nil +} + +// DeleteOAuth deletes the record of oauth +func DeleteOAuth(o *OAuth) (err error) { + sess := db.NewSession(db.DefaultContext) + defer sess.Close() + if err = sess.Begin(); err != nil { + return err + } + + if err = deleteOAuth(sess, o); err != nil { + // Note: don't wrapper error here. + return err + } + + return sess.Commit() +} diff --git a/models/user.go b/models/user.go index 934b834e96328..f3e7aaf03b610 100644 --- a/models/user.go +++ b/models/user.go @@ -162,6 +162,9 @@ type User struct { DiffViewStyle string `xorm:"NOT NULL DEFAULT ''"` Theme string `xorm:"NOT NULL DEFAULT ''"` KeepActivityPrivate bool `xorm:"NOT NULL DEFAULT false"` + + // OAuth + OAuthProvider int64 } func init() { diff --git a/routers/web/user/auth.go b/routers/web/user/auth.go index 65515402cf5d6..80a3454b96d0b 100644 --- a/routers/web/user/auth.go +++ b/routers/web/user/auth.go @@ -649,14 +649,28 @@ func SignInOAuthCallback(ctx *context.Context) { ctx.ServerError("CreateUser", err) return } + + oauthProviderName := strings.ToLower(provider) + + oauth, err := models.GetOAuthByName(oauthProviderName) + + if err != nil { + oauth = &models.OAuth{ + Name: oauthProviderName, + } + + models.CreateOAuth(oauth) + } + u = &models.User{ - Name: getUserName(&gothUser), - FullName: gothUser.Name, - Email: gothUser.Email, - IsActive: !setting.OAuth2Client.RegisterEmailConfirm, - LoginType: login.OAuth2, - LoginSource: loginSource.ID, - LoginName: gothUser.UserID, + Name: getUserName(&gothUser), + FullName: gothUser.Name, + Email: gothUser.Email, + IsActive: !setting.OAuth2Client.RegisterEmailConfirm, + LoginType: login.OAuth2, + LoginSource: loginSource.ID, + LoginName: gothUser.UserID, + OAuthProvider: oauth.ID, } if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled) {