diff --git a/internal/io/limitedwriter_test.go b/internal/io/limitedwriter_test.go index 59608c93..ac5f5c15 100644 --- a/internal/io/limitedwriter_test.go +++ b/internal/io/limitedwriter_test.go @@ -15,12 +15,12 @@ package io import ( "bytes" + "io" "testing" ) func TestLimitWriter(t *testing.T) { limit := int64(10) - longString := "1234567891011" tests := []struct { input string @@ -46,13 +46,22 @@ func TestLimitWriter(t *testing.T) { if buf.String() != tt.expected { t.Errorf("expected buffer %q, got %q", tt.expected, buf.String()) } + } +} - n, err = lw.Write([]byte(longString)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if n == len(longString) { - t.Errorf("should not write more than the limit") - } +func TestLimitWriterFailed(t *testing.T) { + limit := int64(10) + longString := "1234567891011" + + var buf bytes.Buffer + lw := LimitWriter(&buf, limit) + _, err := lw.Write([]byte(longString)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = lw.Write([]byte(longString)) + expectedErr := io.ErrShortWrite + if err != expectedErr { + t.Errorf("expected error %v, got %v", expectedErr, err) } }