diff --git a/xerrors/errors.go b/xerrors/errors.go index 8b0b4a9..3632258 100644 --- a/xerrors/errors.go +++ b/xerrors/errors.go @@ -2,6 +2,7 @@ package xerrors import ( "errors" + "slices" ) // IsTimeout reports whether the provided error indicates a timeout condition. @@ -10,11 +11,14 @@ func IsTimeout(err error) bool { return false } - e, ok := err.(interface{ Timeout() bool }) - if ok && e.Timeout() { + if te, ok := err.(interface{ Timeout() bool }); ok && te.Timeout() { return true } + if joined, ok := err.(interface{ Unwrap() []error }); ok { + return slices.ContainsFunc(joined.Unwrap(), IsTimeout) + } + inner := errors.Unwrap(err) return inner != nil && IsTimeout(inner) diff --git a/xerrors/errors_test.go b/xerrors/errors_test.go index f63edb2..a598449 100644 --- a/xerrors/errors_test.go +++ b/xerrors/errors_test.go @@ -2,6 +2,7 @@ package xerrors_test import ( "context" + "errors" "fmt" "testing" @@ -24,6 +25,8 @@ func TestIsTimeout(t *testing.T) { {name: "context timeout", err: context.DeadlineExceeded, want: true}, {name: "wrapped context timeout", err: fmt.Errorf("bad stuff: %w", context.DeadlineExceeded), want: true}, {name: "custom timeout", err: customTimeoutError{}, want: true}, + {name: "joined errors", err: errors.Join(errors.New("oops"), context.DeadlineExceeded), want: true}, + {name: "joined errors - no timeout", err: errors.Join(errors.New("oops"), errors.New("ididitagain")), want: false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {