diff --git a/github/pulls.go b/github/pulls.go index 51c6ccb576c..9a161917dda 100644 --- a/github/pulls.go +++ b/github/pulls.go @@ -189,12 +189,33 @@ func (s *PullRequestsService) Create(owner string, repo string, pull *NewPullReq return p, resp, err } +type pullRequestUpdate struct { + Title *string `json:"title,omitempty"` + Body *string `json:"body,omitempty"` + State *string `json:"state,omitempty"` + Base *string `json:"base,omitempty"` +} + // Edit a pull request. // +// The following fields are editable: Title, Body, State, and Base.Ref. +// Base.Ref updates the base branch of the pull request. +// // GitHub API docs: https://developer.github.com/v3/pulls/#update-a-pull-request func (s *PullRequestsService) Edit(owner string, repo string, number int, pull *PullRequest) (*PullRequest, *Response, error) { u := fmt.Sprintf("repos/%v/%v/pulls/%d", owner, repo, number) - req, err := s.client.NewRequest("PATCH", u, pull) + + update := new(pullRequestUpdate) + if pull != nil { + update.Title = pull.Title + update.Body = pull.Body + update.State = pull.State + if pull.Base != nil { + update.Base = pull.Base.Ref + } + } + + req, err := s.client.NewRequest("PATCH", u, update) if err != nil { return nil, nil, err } diff --git a/github/pulls_test.go b/github/pulls_test.go index f0a4854e090..f1d78bdc7ed 100644 --- a/github/pulls_test.go +++ b/github/pulls_test.go @@ -8,6 +8,7 @@ package github import ( "encoding/json" "fmt" + "io" "net/http" "reflect" "strings" @@ -235,28 +236,58 @@ func TestPullRequestsService_Edit(t *testing.T) { setup() defer teardown() - input := &PullRequest{Title: String("t")} + tests := []struct { + input *PullRequest + sendResponse string - mux.HandleFunc("/repos/o/r/pulls/1", func(w http.ResponseWriter, r *http.Request) { - v := new(PullRequest) - json.NewDecoder(r.Body).Decode(v) + wantUpdate string + want *PullRequest + }{ + { + input: &PullRequest{Title: String("t")}, + sendResponse: `{"number":1}`, + wantUpdate: `{"title":"t"}`, + want: &PullRequest{Number: Int(1)}, + }, + { + // nil request + sendResponse: `{}`, + wantUpdate: `{}`, + want: &PullRequest{}, + }, + { + // base update + input: &PullRequest{Base: &PullRequestBranch{Ref: String("master")}}, + sendResponse: `{"number":1,"base":{"ref":"master"}}`, + wantUpdate: `{"base":"master"}`, + want: &PullRequest{ + Number: Int(1), + Base: &PullRequestBranch{Ref: String("master")}, + }, + }, + } - testMethod(t, r, "PATCH") - if !reflect.DeepEqual(v, input) { - t.Errorf("Request body = %+v, want %+v", v, input) - } + for i, tt := range tests { + madeRequest := false + mux.HandleFunc(fmt.Sprintf("/repos/o/r/pulls/%v", i), func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "PATCH") + testBody(t, r, tt.wantUpdate+"\n") + io.WriteString(w, tt.sendResponse) + madeRequest = true + }) - fmt.Fprint(w, `{"number":1}`) - }) + pull, _, err := client.PullRequests.Edit("o", "r", i, tt.input) + if err != nil { + t.Errorf("%d: PullRequests.Edit returned error: %v", i, err) + } - pull, _, err := client.PullRequests.Edit("o", "r", 1, input) - if err != nil { - t.Errorf("PullRequests.Edit returned error: %v", err) - } + if !reflect.DeepEqual(pull, tt.want) { + t.Errorf("%d: PullRequests.Edit returned %+v, want %+v", i, pull, tt.want) + } - want := &PullRequest{Number: Int(1)} - if !reflect.DeepEqual(pull, want) { - t.Errorf("PullRequests.Edit returned %+v, want %+v", pull, want) + if !madeRequest { + t.Errorf("%d: PullRequest.Edit did not make the expected request", i) + } } }