From 32ae3754a4bb8fdf777fc615029e3f9439d0f6b3 Mon Sep 17 00:00:00 2001 From: Junjie Gao Date: Mon, 2 Dec 2024 09:10:37 +0000 Subject: [PATCH] fix: update Signed-off-by: Junjie Gao --- internal/io/limitedwriter.go | 9 +++++++-- internal/io/limitedwriter_test.go | 6 +++--- plugin/plugin.go | 2 ++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/internal/io/limitedwriter.go b/internal/io/limitedwriter.go index 241be4b9..7a1990ba 100644 --- a/internal/io/limitedwriter.go +++ b/internal/io/limitedwriter.go @@ -16,7 +16,12 @@ package io -import "io" +import ( + "errors" + "io" +) + +var ErrLimitExceeded = errors.New("write limit exceeded") // LimitedWriter is a writer that writes to an underlying writer up to a limit. type LimitedWriter struct { @@ -36,7 +41,7 @@ func LimitWriter(w io.Writer, limit int64) *LimitedWriter { // Write writes p to the underlying writer up to the limit. func (l *LimitedWriter) Write(p []byte) (int, error) { if l.N <= 0 { - return 0, io.ErrShortWrite + return 0, ErrLimitExceeded } if int64(len(p)) > l.N { p = p[:l.N] diff --git a/internal/io/limitedwriter_test.go b/internal/io/limitedwriter_test.go index ac5f5c15..264886df 100644 --- a/internal/io/limitedwriter_test.go +++ b/internal/io/limitedwriter_test.go @@ -15,7 +15,7 @@ package io import ( "bytes" - "io" + "errors" "testing" ) @@ -60,8 +60,8 @@ func TestLimitWriterFailed(t *testing.T) { t.Fatalf("unexpected error: %v", err) } _, err = lw.Write([]byte(longString)) - expectedErr := io.ErrShortWrite - if err != expectedErr { + expectedErr := errors.New("write limit exceeded") + if err.Error() != expectedErr.Error() { t.Errorf("expected error %v, got %v", expectedErr, err) } } diff --git a/plugin/plugin.go b/plugin/plugin.go index 1ea674d7..6402fb96 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -222,6 +222,8 @@ func (c execCommander) Output(ctx context.Context, name string, command plugin.C var stdout, stderr bytes.Buffer cmd := exec.CommandContext(ctx, name, string(command)) cmd.Stdin = bytes.NewReader(req) + // The limit writer will be handled by the caller in run() by comparing the + // bytes written with the expected length of the bytes. cmd.Stderr = io.LimitWriter(&stderr, maxPluginOutputSize) cmd.Stdout = io.LimitWriter(&stdout, maxPluginOutputSize) err := cmd.Run()