diff --git a/gomock/call.go b/gomock/call.go index 3102e659..276b2fda 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -212,14 +212,68 @@ func (c *Call) String() string { // Tests if the given call matches the expected call. // If yes, returns nil. If no, returns error with message explaining why it does not match. func (c *Call) matches(args []interface{}) error { - if len(args) != len(c.args) { - return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %s, want: %s", - c.origin, strconv.Itoa(len(args)), strconv.Itoa(len(c.args))) - } - for i, m := range c.args { - if !m.Matches(args[i]) { - return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", - c.origin, strconv.Itoa(i), args[i], m) + if !c.methodType.IsVariadic() { + if len(args) != len(c.args) { + return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + + for i, m := range c.args { + if !m.Matches(args[i]) { + return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i], m) + } + } + } else { + if len(c.args) < c.methodType.NumIn()-1 { + return fmt.Errorf("Expected call at %s has the wrong number of matchers. Got: %d, want: %d", + c.origin, len(c.args), c.methodType.NumIn()-1) + } + if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) { + return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + if len(args) < len(c.args)-1 { + return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d", + c.origin, len(args), len(c.args)-1) + } + + for i, m := range c.args { + if i == len(c.args)-1 { + // The last arg has a possibility of a variadic argument, so let it branch + + // sample: Foo(a int, b int, c ...int) + + if len(c.args) == len(args) { + if c.args[i].Matches(args[i]) { + // Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC) + // Got Foo(a, b) want Foo(matcherA, matcherB) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD) + break + } + } else if c.args[i].Matches(args[i:]) { + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher) + break + } + // Wrong number of matchers or not match. Fail. + // Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB) + return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i:], c.args[i]) + } + + if !m.Matches(args[i]) { + return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i], m) + } } } diff --git a/sample/user_test.go b/sample/user_test.go index c8c3864a..875e81f5 100644 --- a/sample/user_test.go +++ b/sample/user_test.go @@ -68,8 +68,7 @@ func TestVariadicFunction(t *testing.T) { defer ctrl.Finish() mockIndex := mock_user.NewMockIndex(ctrl) - m := mockIndex.EXPECT().Ellip("%d", 0, 1, 1, 2, 3) - m.Do(func(format string, nums ...int) { + mockIndex.EXPECT().Ellip("%d", 0, 1, 1, 2, 3).Do(func(format string, nums ...int) { sum := 0 for _, value := range nums { sum += value @@ -78,8 +77,48 @@ func TestVariadicFunction(t *testing.T) { t.Errorf("Expected 7, got %d", sum) } }) + mockIndex.EXPECT().Ellip("%d", gomock.Any()).Do(func(format string, nums ...int) { + sum := 0 + for _, value := range nums { + sum += value + } + if sum != 7 { + t.Errorf("Expected 7, got %d", sum) + } + }) + mockIndex.EXPECT().Ellip("%d", gomock.Any()).Do(func(format string, nums ...int) { + sum := 0 + for _, value := range nums { + sum += value + } + if sum != 0 { + t.Errorf("Expected 0, got %d", sum) + } + }) + mockIndex.EXPECT().Ellip("%d", gomock.Any()).Do(func(format string, nums ...int) { + sum := 0 + for _, value := range nums { + sum += value + } + if sum != 0 { + t.Errorf("Expected 0, got %d", sum) + } + }) + mockIndex.EXPECT().Ellip("%d").Do(func(format string, nums ...int) { + sum := 0 + for _, value := range nums { + sum += value + } + if sum != 0 { + t.Errorf("Expected 0, got %d", sum) + } + }) mockIndex.Ellip("%d", 0, 1, 1, 2, 3) + mockIndex.Ellip("%d", 0, 1, 1, 2, 3) + mockIndex.Ellip("%d", 0) + mockIndex.Ellip("%d") + mockIndex.Ellip("%d") } func TestGrabPointer(t *testing.T) {