diff --git a/src/errors/join.go b/src/errors/join.go index 1c486d591e35bf..95becb29476535 100644 --- a/src/errors/join.go +++ b/src/errors/join.go @@ -12,6 +12,11 @@ package errors // between each string. // // A non-nil error returned by Join implements the Unwrap() []error method. +// +// Calling Join on an error that was previously returned by Join wraps +// the error again. Consequently, calling Unwrap() []error on the result +// of Join(Join(err1, err2), err3) returns a slice with 2 items, of which +// the first one implements Unwrap() []error too. func Join(errs ...error) error { n := 0 for _, err := range errs { diff --git a/src/errors/join_test.go b/src/errors/join_test.go index 4828dc4d755fd6..0a44b23c9a0c10 100644 --- a/src/errors/join_test.go +++ b/src/errors/join_test.go @@ -70,3 +70,48 @@ func TestJoinErrorMethod(t *testing.T) { } } } + +func TestJoinWithJoinedError(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + + var err error + err = errors.Join(err, err1) + if err == nil { + t.Fatal("errors.Join(err, err1) = nil, want non-nil") + } + + gotErrs := err.(interface{ Unwrap() []error }).Unwrap() + if len(gotErrs) != 1 { + t.Fatalf("errors.Join(err, err1) returns errors with len=%v, want len==1", len(gotErrs)) + } + + err = errors.Join(err, err2) + if err == nil { + t.Fatal("errors.Join(err, err2) = nil, want non-nil") + } + + gotErrs = err.(interface{ Unwrap() []error }).Unwrap() + if len(gotErrs) != 2 { + t.Fatalf("errors.Join(err, err2) returns errors with len=%v, want len==2", len(gotErrs)) + } + + // Wraps the error again, so the resulting joined error will have len==1 + err = errors.Join(err, nil) + if err == nil { + t.Fatal("errors.Join(err, nil) = nil, want non-nil") + } + + gotErrs = err.(interface{ Unwrap() []error }).Unwrap() + if len(gotErrs) != 1 { + t.Fatalf("errors.Join(err, nil) returns errors with len=%v, want len==1", len(gotErrs)) + } + + if err.Error() != "err1\nerr2" { + t.Errorf("Join(err, nil).Error() = %q; want %q", err.Error(), "err1\nerr2") + } + + if _, ok := gotErrs[0].(interface{ Unwrap() []error }); !ok { + t.Error("first error returned by errors.Join(err, nil) is not a joined error") + } +}