diff --git a/chat_stream.go b/chat_stream.go index 842835e15..9378c7124 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream( emptyMessagesLimit: c.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, - errAccumulator: newErrorAccumulator(), + errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, }, } diff --git a/chat_stream_test.go b/chat_stream_test.go index afcb86d5e..77d373c6a 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,7 +1,7 @@ -package openai_test +package openai //nolint:testpackage // testing private field import ( - . "github.com/sashabaranov/go-openai" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { + var err error + server := test.NewTestServer() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error", 200) + }) + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + + ctx := context.Background() + + stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) + checks.NoError(t, err) + + stream.errAccumulator = &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + } + + _, err = stream.Recv() + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error()) +} + // Helper funcs. func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { diff --git a/error_accumulator.go b/error_accumulator.go deleted file mode 100644 index 568afdbcd..000000000 --- a/error_accumulator.go +++ /dev/null @@ -1,53 +0,0 @@ -package openai - -import ( - "bytes" - "fmt" - "io" - - utils "github.com/sashabaranov/go-openai/internal" -) - -type errorAccumulator interface { - write(p []byte) error - unmarshalError() *ErrorResponse -} - -type errorBuffer interface { - io.Writer - Len() int - Bytes() []byte -} - -type defaultErrorAccumulator struct { - buffer errorBuffer - unmarshaler utils.Unmarshaler -} - -func newErrorAccumulator() errorAccumulator { - return &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, - } -} - -func (e *defaultErrorAccumulator) write(p []byte) error { - _, err := e.buffer.Write(p) - if err != nil { - return fmt.Errorf("error accumulator write error, %w", err) - } - return nil -} - -func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) { - if e.buffer.Len() == 0 { - return - } - - err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp) - if err != nil { - errResp = nil - } - - return -} diff --git a/error_accumulator_test.go b/error_accumulator_test.go deleted file mode 100644 index 821eb21b4..000000000 --- a/error_accumulator_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "bytes" - "context" - "errors" - "net/http" - "testing" - - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" -) - -var ( - errTestUnmarshalerFailed = errors.New("test unmarshaler failed") - errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") -) - -type ( - failingUnMarshaller struct{} - failingErrorBuffer struct{} -) - -func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) { - return 0, errTestErrorAccumulatorWriteFailed -} - -func (b *failingErrorBuffer) Len() int { - return 0 -} - -func (b *failingErrorBuffer) Bytes() []byte { - return []byte{} -} - -func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { - return errTestUnmarshalerFailed -} - -func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &failingUnMarshaller{}, - } - - respErr := accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil with empty buffer: %v", respErr) - } - - err := accumulator.write([]byte("{")) - if err != nil { - t.Fatalf("%+v", err) - } - - respErr = accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) - } -} - -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, - } - err := accumulator.write([]byte("{")) - if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) - } -} - -func TestErrorAccumulatorWriteErrors(t *testing.T) { - var err error - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", 200) - }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - ctx := context.Background() - - stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - checks.NoError(t, err) - - stream.errAccumulator = &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, - } - - _, err = stream.Recv() - checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) -} diff --git a/internal/error_accumulator.go b/internal/error_accumulator.go new file mode 100644 index 000000000..3d3e805fe --- /dev/null +++ b/internal/error_accumulator.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "fmt" + "io" +) + +type ErrorAccumulator interface { + Write(p []byte) error + Bytes() []byte +} + +type errorBuffer interface { + io.Writer + Len() int + Bytes() []byte +} + +type DefaultErrorAccumulator struct { + Buffer errorBuffer +} + +func NewErrorAccumulator() ErrorAccumulator { + return &DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } +} + +func (e *DefaultErrorAccumulator) Write(p []byte) error { + _, err := e.Buffer.Write(p) + if err != nil { + return fmt.Errorf("error accumulator write error, %w", err) + } + return nil +} + +func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { + if e.Buffer.Len() == 0 { + return + } + errBytes = e.Buffer.Bytes() + return +} diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go new file mode 100644 index 000000000..d48f28177 --- /dev/null +++ b/internal/error_accumulator_test.go @@ -0,0 +1,41 @@ +package openai_test + +import ( + "bytes" + "errors" + "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" +) + +func TestErrorAccumulatorBytes(t *testing.T) { + accumulator := &utils.DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } + + errBytes := accumulator.Bytes() + if len(errBytes) != 0 { + t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) + } + + err := accumulator.Write([]byte("{}")) + if err != nil { + t.Fatalf("%+v", err) + } + + errBytes = accumulator.Bytes() + if len(errBytes) == 0 { + t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) + } +} + +func TestErrorByteWriteErrors(t *testing.T) { + accumulator := &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + } + err := accumulator.Write([]byte("{")) + if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("Did not return error when write failed: %v", err) + } +} diff --git a/internal/test/failer.go b/internal/test/failer.go new file mode 100644 index 000000000..10ca64e34 --- /dev/null +++ b/internal/test/failer.go @@ -0,0 +1,21 @@ +package test + +import "errors" + +var ( + ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") +) + +type FailingErrorBuffer struct{} + +func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) { + return 0, ErrTestErrorAccumulatorWriteFailed +} + +func (b *FailingErrorBuffer) Len() int { + return 0 +} + +func (b *FailingErrorBuffer) Bytes() []byte { + return []byte{} +} diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 8461e5374..0e63ae82f 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -3,6 +3,7 @@ package test import ( "github.com/sashabaranov/go-openai/internal/test/checks" + "net/http" "os" "testing" ) @@ -27,3 +28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { return path, func() { os.RemoveAll(path) } } + +// TokenRoundTripper is a struct that implements the RoundTripper +// interface, specifically to handle the authentication token by adding a token +// to the request header. We need this because the API requires that each +// request include a valid API token in the headers for authentication and +// authorization. +type TokenRoundTripper struct { + Token string + Fallback http.RoundTripper +} + +// RoundTrip takes an *http.Request as input and returns an +// *http.Response and an error. +// +// It is expected to use the provided request to create a connection to an HTTP +// server and return the response, or an error if one occurred. The returned +// Response should have its Body closed. If the RoundTrip method returns an +// error, the Client's Get, Head, Post, and PostForm methods return the same +// error. +func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+t.Token) + return t.Fallback.RoundTrip(req) +} diff --git a/stream.go b/stream.go index b9e784acf..d4e352314 100644 --- a/stream.go +++ b/stream.go @@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream( emptyMessagesLimit: c.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, - errAccumulator: newErrorAccumulator(), + errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, }, } diff --git a/stream_reader.go b/stream_reader.go index 5eb6df7b8..a9940b0ae 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -20,7 +20,7 @@ type streamReader[T streamable] struct { reader *bufio.Reader response *http.Response - errAccumulator errorAccumulator + errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler } @@ -35,7 +35,7 @@ func (stream *streamReader[T]) Recv() (response T, err error) { waitForData: line, err := stream.reader.ReadBytes('\n') if err != nil { - respErr := stream.errAccumulator.unmarshalError() + respErr := stream.unmarshalError() if respErr != nil { err = fmt.Errorf("error, %w", respErr.Error) } @@ -45,7 +45,7 @@ waitForData: var headerData = []byte("data: ") line = bytes.TrimSpace(line) if !bytes.HasPrefix(line, headerData) { - if writeErr := stream.errAccumulator.write(line); writeErr != nil { + if writeErr := stream.errAccumulator.Write(line); writeErr != nil { err = writeErr return } @@ -69,6 +69,20 @@ waitForData: return } +func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { + errBytes := stream.errAccumulator.Bytes() + if len(errBytes) == 0 { + return + } + + err := stream.unmarshaler.Unmarshal(errBytes, &errResp) + if err != nil { + errResp = nil + } + + return +} + func (stream *streamReader[T]) Close() { stream.response.Body.Close() } diff --git a/stream_reader_test.go b/stream_reader_test.go new file mode 100644 index 000000000..0e45c0b73 --- /dev/null +++ b/stream_reader_test.go @@ -0,0 +1,53 @@ +package openai //nolint:testpackage // testing private field + +import ( + "bufio" + "bytes" + "errors" + "testing" + + utils "github.com/sashabaranov/go-openai/internal" +) + +var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") + +type failingUnMarshaller struct{} + +func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { + return errTestUnmarshalerFailed +} + +func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &failingUnMarshaller{}, + } + + respErr := stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil with empty buffer: %v", respErr) + } + + err := stream.errAccumulator.Write([]byte("{")) + if err != nil { + t.Fatalf("%+v", err) + } + + respErr = stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) + } +} + +func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + emptyMessagesLimit: 3, + reader: bufio.NewReader(bytes.NewReader([]byte("\n\n\n\n"))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + } + _, err := stream.Recv() + if !errors.Is(err, ErrTooManyEmptyStreamMessages) { + t.Fatalf("Did not return error when recv failed: %v", err) + } +} diff --git a/stream_test.go b/stream_test.go index a5c591fde..589fc9e26 100644 --- a/stream_test.go +++ b/stream_test.go @@ -57,9 +57,9 @@ func TestCreateCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -142,9 +142,9 @@ func TestCreateCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -194,9 +194,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -217,29 +217,6 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } -// A "tokenRoundTripper" is a struct that implements the RoundTripper -// interface, specifically to handle the authentication token by adding a token -// to the request header. We need this because the API requires that each -// request include a valid API token in the headers for authentication and -// authorization. -type tokenRoundTripper struct { - token string - fallback http.RoundTripper -} - -// RoundTrip takes an *http.Request as input and returns an -// *http.Response and an error. -// -// It is expected to use the provided request to create a connection to an HTTP -// server and return the response, or an error if one occurred. The returned -// Response should have its Body closed. If the RoundTrip method returns an -// error, the Client's Get, Head, Post, and PostForm methods return the same -// error. -func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("Authorization", "Bearer "+t.token) - return t.fallback.RoundTrip(req) -} - // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {