diff --git a/models/issue_xref_test.go b/models/issue_xref_test.go index c13577e905aeb..d7a0a88f78ab9 100644 --- a/models/issue_xref_test.go +++ b/models/issue_xref_test.go @@ -143,7 +143,7 @@ func testCreatePR(t *testing.T, repo, doer int64, title, content string) *PullRe d := AssertExistsAndLoadBean(t, &User{ID: doer}).(*User) i := &Issue{RepoID: r.ID, PosterID: d.ID, Poster: d, Title: title, Content: content, IsPull: true} pr := &PullRequest{HeadRepoID: repo, BaseRepoID: repo, HeadBranch: "head", BaseBranch: "base"} - assert.NoError(t, NewPullRequest(r, i, nil, nil, pr, nil)) + assert.NoError(t, NewPullRequest(r, i, nil, nil, pr, 0, "unknown")) pr.Issue = i return pr } diff --git a/models/pull.go b/models/pull.go index 2bd79202f094b..82fb80e82b9c8 100644 --- a/models/pull.go +++ b/models/pull.go @@ -8,6 +8,7 @@ package models import ( "bufio" "fmt" + "io/ioutil" "os" "path" "path/filepath" @@ -595,11 +596,11 @@ func (pr *PullRequest) testPatch(e Engine) (err error) { } // NewPullRequest creates new pull request with labels for repository. -func NewPullRequest(repo *Repository, pull *Issue, labelIDs []int64, uuids []string, pr *PullRequest, patch []byte) (err error) { +func NewPullRequest(repo *Repository, pull *Issue, labelIDs []int64, uuids []string, pr *PullRequest, patchFileSize int64, patchFileName string) (err error) { // Retry several times in case INSERT fails due to duplicate key for (repo_id, index); see #7887 i := 0 for { - if err = newPullRequestAttempt(repo, pull, labelIDs, uuids, pr, patch); err == nil { + if err = newPullRequestAttempt(repo, pull, labelIDs, uuids, pr, patchFileSize, patchFileName); err == nil { return nil } if !IsErrNewIssueInsert(err) { @@ -613,7 +614,7 @@ func NewPullRequest(repo *Repository, pull *Issue, labelIDs []int64, uuids []str return fmt.Errorf("NewPullRequest: too many errors attempting to insert the new issue. Last error was: %v", err) } -func newPullRequestAttempt(repo *Repository, pull *Issue, labelIDs []int64, uuids []string, pr *PullRequest, patch []byte) (err error) { +func newPullRequestAttempt(repo *Repository, pull *Issue, labelIDs []int64, uuids []string, pr *PullRequest, patchFileSize int64, patchFileName string) (err error) { sess := x.NewSession() defer sess.Close() if err = sess.Begin(); err != nil { @@ -636,8 +637,8 @@ func newPullRequestAttempt(repo *Repository, pull *Issue, labelIDs []int64, uuid pr.Index = pull.Index pr.BaseRepo = repo pr.Status = PullRequestStatusChecking - if len(patch) > 0 { - if err = repo.savePatch(sess, pr.Index, patch); err != nil { + if patchFileSize > 0 { + if err = repo.savePatch(sess, pr.Index, patchFileName); err != nil { return fmt.Errorf("SavePatch: %v", err) } @@ -800,12 +801,23 @@ func (pr *PullRequest) UpdatePatch() (err error) { return fmt.Errorf("Update: %v", err) } - patch, err := headGitRepo.GetPatch(pr.MergeBase, pr.HeadBranch) + tmpPatchFile, err := ioutil.TempFile("", "patch") if err != nil { - return fmt.Errorf("GetPatch: %v", err) + log.Error("Unable to create temporary patch file! Error: %v", err) + return fmt.Errorf("Unable to create temporary patch file! Error: %v", err) + } + defer func() { + _ = os.Remove(tmpPatchFile.Name()) + }() + + if err := headGitRepo.GetPatch(pr.MergeBase, pr.HeadBranch, tmpPatchFile); err != nil { + tmpPatchFile.Close() + log.Error("Unable to get patch file from %s to %s in %s/%s Error: %v", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.MustOwner().Name, pr.BaseRepo.Name, err) + return fmt.Errorf("Unable to get patch file from %s to %s in %s/%s Error: %v", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.MustOwner().Name, pr.BaseRepo.Name, err) } - if err = pr.BaseRepo.SavePatch(pr.Index, patch); err != nil { + tmpPatchFile.Close() + if err = pr.BaseRepo.SavePatch(pr.Index, tmpPatchFile.Name()); err != nil { return fmt.Errorf("BaseRepo.SavePatch: %v", err) } diff --git a/models/repo.go b/models/repo.go index e809bafa309f1..d694b6aaeedbc 100644 --- a/models/repo.go +++ b/models/repo.go @@ -15,6 +15,7 @@ import ( // Needed for jpeg support _ "image/jpeg" "image/png" + "io" "io/ioutil" "net/url" "os" @@ -901,11 +902,11 @@ func (repo *Repository) patchPath(e Engine, index int64) (string, error) { } // SavePatch saves patch data to corresponding location by given issue ID. -func (repo *Repository) SavePatch(index int64, patch []byte) error { - return repo.savePatch(x, index, patch) +func (repo *Repository) SavePatch(index int64, name string) error { + return repo.savePatch(x, index, name) } -func (repo *Repository) savePatch(e Engine, index int64, patch []byte) error { +func (repo *Repository) savePatch(e Engine, index int64, name string) error { patchPath, err := repo.patchPath(e, index) if err != nil { return fmt.Errorf("PatchPath: %v", err) @@ -916,10 +917,21 @@ func (repo *Repository) savePatch(e Engine, index int64, patch []byte) error { return fmt.Errorf("Failed to create dir %s: %v", dir, err) } - if err = ioutil.WriteFile(patchPath, patch, 0644); err != nil { - return fmt.Errorf("WriteFile: %v", err) + inputFile, err := os.Open(name) + if err != nil { + return fmt.Errorf("Couldn't open temporary patch file: %s", err) + } + outputFile, err := os.Create(patchPath) + if err != nil { + inputFile.Close() + return fmt.Errorf("Couldn't open destination patch file: %s", err) + } + defer outputFile.Close() + _, err = io.Copy(outputFile, inputFile) + inputFile.Close() + if err != nil { + return fmt.Errorf("Writing to patch file failed: %s", err) } - return nil } diff --git a/modules/git/repo_compare.go b/modules/git/repo_compare.go index 677201c5e0358..383af0a8c4a5b 100644 --- a/modules/git/repo_compare.go +++ b/modules/git/repo_compare.go @@ -95,8 +95,8 @@ func (repo *Repository) GetCompareInfo(basePath, baseBranch, headBranch string) } // GetPatch generates and returns patch data between given revisions. -func (repo *Repository) GetPatch(base, head string) ([]byte, error) { - return NewCommand("diff", "-p", "--binary", base, head).RunInDirBytes(repo.Path) +func (repo *Repository) GetPatch(base, head string, w io.Writer) error { + return NewCommand("diff", "-p", "--binary", base, head).RunInDirPipeline(repo.Path, w, nil) } // GetFormatPatch generates and returns format-patch data between given revisions. diff --git a/routers/api/v1/repo/pull.go b/routers/api/v1/repo/pull.go index 9abcaa0496a15..c0c22ce5d4ac6 100644 --- a/routers/api/v1/repo/pull.go +++ b/routers/api/v1/repo/pull.go @@ -6,7 +6,9 @@ package repo import ( "fmt" + "io/ioutil" "net/http" + "os" "strings" "time" @@ -244,12 +246,29 @@ func CreatePullRequest(ctx *context.APIContext, form api.CreatePullRequestOption milestoneID = milestone.ID } - patch, err := headGitRepo.GetPatch(compareInfo.MergeBase, headBranch) + tmpPatchFile, err := ioutil.TempFile("", "patch") if err != nil { + ctx.Error(500, "CreateTemporaryFile", err) + return + } + defer func() { + _ = os.Remove(tmpPatchFile.Name()) + }() + + if err := headGitRepo.GetPatch(compareInfo.MergeBase, headBranch, tmpPatchFile); err != nil { + tmpPatchFile.Close() ctx.Error(500, "GetPatch", err) return } + stat, err := tmpPatchFile.Stat() + if err != nil { + tmpPatchFile.Close() + ctx.Error(500, "StatPatch", err) + return + } + + tmpPatchFile.Close() var deadlineUnix timeutil.TimeStamp if form.Deadline != nil { deadlineUnix = timeutil.TimeStamp(form.Deadline.Unix()) @@ -306,7 +325,7 @@ func CreatePullRequest(ctx *context.APIContext, form api.CreatePullRequestOption } } - if err := pull_service.NewPullRequest(repo, prIssue, labelIDs, []string{}, pr, patch, assigneeIDs); err != nil { + if err := pull_service.NewPullRequest(repo, prIssue, labelIDs, []string{}, pr, stat.Size(), tmpPatchFile.Name(), assigneeIDs); err != nil { if models.IsErrUserDoesNotHaveAccessToRepo(err) { ctx.Error(400, "UserDoesNotHaveAccessToRepo", err) return diff --git a/routers/repo/pull.go b/routers/repo/pull.go index 78406de8acdc8..8221e49fc9e8e 100644 --- a/routers/repo/pull.go +++ b/routers/repo/pull.go @@ -12,6 +12,8 @@ import ( "fmt" "html" "io" + "io/ioutil" + "os" "path" "strings" @@ -785,12 +787,29 @@ func CompareAndPullRequestPost(ctx *context.Context, form auth.CreateIssueForm) return } - patch, err := headGitRepo.GetPatch(prInfo.MergeBase, headBranch) + tmpPatchFile, err := ioutil.TempFile("", "patch") if err != nil { + ctx.ServerError("CreateTemporaryFile", err) + return + } + defer func() { + _ = os.Remove(tmpPatchFile.Name()) + }() + + if err := headGitRepo.GetPatch(prInfo.MergeBase, headBranch, tmpPatchFile); err != nil { + tmpPatchFile.Close() ctx.ServerError("GetPatch", err) return } + stat, err := tmpPatchFile.Stat() + if err != nil { + tmpPatchFile.Close() + ctx.ServerError("StatPatch", err) + return + } + tmpPatchFile.Close() + pullIssue := &models.Issue{ RepoID: repo.ID, Title: form.Title, @@ -813,7 +832,7 @@ func CompareAndPullRequestPost(ctx *context.Context, form auth.CreateIssueForm) // FIXME: check error in the case two people send pull request at almost same time, give nice error prompt // instead of 500. - if err := pull_service.NewPullRequest(repo, pullIssue, labelIDs, attachments, pullRequest, patch, assigneeIDs); err != nil { + if err := pull_service.NewPullRequest(repo, pullIssue, labelIDs, attachments, pullRequest, stat.Size(), tmpPatchFile.Name(), assigneeIDs); err != nil { if models.IsErrUserDoesNotHaveAccessToRepo(err) { ctx.Error(400, "UserDoesNotHaveAccessToRepo", err.Error()) return diff --git a/services/pull/pull.go b/services/pull/pull.go index 2650dacc116da..648ecf8649098 100644 --- a/services/pull/pull.go +++ b/services/pull/pull.go @@ -15,8 +15,8 @@ import ( ) // NewPullRequest creates new pull request with labels for repository. -func NewPullRequest(repo *models.Repository, pull *models.Issue, labelIDs []int64, uuids []string, pr *models.PullRequest, patch []byte, assigneeIDs []int64) error { - if err := models.NewPullRequest(repo, pull, labelIDs, uuids, pr, patch); err != nil { +func NewPullRequest(repo *models.Repository, pull *models.Issue, labelIDs []int64, uuids []string, pr *models.PullRequest, patchFileSize int64, patchFileName string, assigneeIDs []int64) error { + if err := models.NewPullRequest(repo, pull, labelIDs, uuids, pr, patchFileSize, patchFileName); err != nil { return err }