fix: function_call can be a string or an object
j178 committed Jun 15, 2023
1 parent 0bd14f9 commit 3fe89c4
Showing 4 changed files with 241 additions and 118 deletions.
25 changes: 22 additions & 3 deletions chat.go
@@ -17,8 +17,9 @@ const (
 const chatCompletionsSuffix = "/chat/completions"
 
 var (
-    ErrChatCompletionInvalidModel       = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll
-    ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
+    ErrChatCompletionInvalidModel        = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll
+    ErrChatCompletionStreamNotSupported  = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
+    ErrChatCompletionInvalidFunctionCall = errors.New(`FunctionCall parameter only supports "none", "auto", or a map[string]string`) //nolint:lll
 )
 
 type ChatCompletionMessage struct {
@@ -55,7 +56,20 @@ type ChatCompletionRequest struct {
     LogitBias    map[string]int    `json:"logit_bias,omitempty"`
     User         string            `json:"user,omitempty"`
     Functions    []*FunctionDefine `json:"functions,omitempty"`
-    FunctionCall string            `json:"function_call,omitempty"`
+    FunctionCall any               `json:"function_call,omitempty"`
 }
 
+func checkFunctionCall(value any) bool {
+    if value == nil {
+        return true
+    }
+    if v, isString := value.(string); isString {
+        return v == "none" || v == "auto"
+    }
+    if v, isMap := value.(map[string]string); isMap {
+        return v["name"] != ""
+    }
+    return false
+}
+
 type FunctionDefine struct {
@@ -146,6 +160,11 @@ func (c *Client) CreateChatCompletion(
         return
     }
 
+    if !checkFunctionCall(request.FunctionCall) {
+        err = ErrChatCompletionInvalidFunctionCall
+        return
+    }
+
     req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
     if err != nil {
         return
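Note: with FunctionCall now typed as `any`, a request can select function-calling behavior either with a mode string or with a target-function object. Below is a minimal usage sketch against this client; the API key and the `get_weather` function name are illustrative placeholders, not part of this commit.

package main

import (
    "context"
    "fmt"

    openai "github.com/sashabaranov/go-openai"
)

func main() {
    client := openai.NewClient("your-api-key") // placeholder token

    req := openai.ChatCompletionRequest{
        Model: openai.GPT3Dot5Turbo,
        Messages: []openai.ChatCompletionMessage{
            {Role: openai.ChatMessageRoleUser, Content: "What's the weather in Boston?"},
        },
        FunctionCall: "auto", // mode string: "none" or "auto"
    }
    // Or force a specific function by name (the name is illustrative):
    req.FunctionCall = map[string]string{"name": "get_weather"}

    // Any other value, e.g. a map without "name", now fails fast with
    // ErrChatCompletionInvalidFunctionCall before a request is sent.
    resp, err := client.CreateChatCompletion(context.Background(), req)
    if err != nil {
        fmt.Println("error:", err)
        return
    }
    fmt.Println(resp.Choices[0].Message.Content)
}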
10 changes: 8 additions & 2 deletions chat_stream.go
@@ -8,8 +8,9 @@ import (
 )
 
 type ChatCompletionStreamChoiceDelta struct {
-    Content string `json:"content,omitempty"`
-    Role    string `json:"role,omitempty"`
+    Content      string        `json:"content,omitempty"`
+    Role         string        `json:"role,omitempty"`
+    FunctionCall *FunctionCall `json:"function_call,omitempty"`
 }
 
 type ChatCompletionStreamChoice struct {
@@ -46,6 +47,11 @@ func (c *Client) CreateChatCompletionStream(
         return
     }
 
+    if !checkFunctionCall(request.FunctionCall) {
+        err = ErrChatCompletionInvalidFunctionCall
+        return
+    }
+
     request.Stream = true
     req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
     if err != nil {
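Note: the new FunctionCall field on ChatCompletionStreamChoiceDelta means a streamed function call arrives incrementally, typically the name in an early delta and the argument JSON in fragments. A minimal accumulation sketch follows; it assumes the FunctionCall type exposes Name and Arguments string fields, matching the non-streaming message type.

package example

import (
    "context"
    "errors"
    "io"

    openai "github.com/sashabaranov/go-openai"
)

// collectFunctionCall drains a stream and stitches function-call deltas together.
func collectFunctionCall(ctx context.Context, client *openai.Client, req openai.ChatCompletionRequest) (name, args string, err error) {
    stream, err := client.CreateChatCompletionStream(ctx, req)
    if err != nil {
        // Includes ErrChatCompletionInvalidFunctionCall when req.FunctionCall is malformed.
        return "", "", err
    }
    defer stream.Close()

    for {
        resp, recvErr := stream.Recv()
        if errors.Is(recvErr, io.EOF) {
            return name, args, nil // args should now hold the complete argument JSON
        }
        if recvErr != nil {
            return name, args, recvErr
        }
        if len(resp.Choices) == 0 {
            continue
        }
        if fc := resp.Choices[0].Delta.FunctionCall; fc != nil {
            if fc.Name != "" {
                name = fc.Name // the name usually arrives once, in the first delta
            }
            args += fc.Arguments // argument JSON streams in fragments
        }
    }
}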
239 changes: 148 additions & 91 deletions chat_stream_test.go
@@ -1,6 +1,8 @@
 package openai_test
 
 import (
+    "fmt"
+
     . "github.com/sashabaranov/go-openai"
     "github.com/sashabaranov/go-openai/internal/test/checks"
 
@@ -37,39 +39,43 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
 func TestCreateChatCompletionStream(t *testing.T) {
     client, server, teardown := setupOpenAITestServer()
     defer teardown()
-    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
-        w.Header().Set("Content-Type", "text/event-stream")
+    server.RegisterHandler(
+        "/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+            w.Header().Set("Content-Type", "text/event-stream")
 
-        // Send test responses
-        dataBytes := []byte{}
-        dataBytes = append(dataBytes, []byte("event: message\n")...)
-        //nolint:lll
-        data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
-        dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+            // Send test responses
+            dataBytes := []byte{}
+            dataBytes = append(dataBytes, []byte("event: message\n")...)
+            //nolint:lll
+            data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
+            dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
 
-        dataBytes = append(dataBytes, []byte("event: message\n")...)
-        //nolint:lll
-        data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
-        dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+            dataBytes = append(dataBytes, []byte("event: message\n")...)
+            //nolint:lll
+            data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
+            dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
 
-        dataBytes = append(dataBytes, []byte("event: done\n")...)
-        dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+            dataBytes = append(dataBytes, []byte("event: done\n")...)
+            dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
 
-        _, err := w.Write(dataBytes)
-        checks.NoError(t, err, "Write error")
-    })
+            _, err := w.Write(dataBytes)
+            checks.NoError(t, err, "Write error")
+        },
+    )
 
-    stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
-        MaxTokens: 5,
-        Model:     GPT3Dot5Turbo,
-        Messages: []ChatCompletionMessage{
-            {
-                Role:    ChatMessageRoleUser,
-                Content: "Hello!",
-            },
-        },
-        Stream: true,
-    })
+    stream, err := client.CreateChatCompletionStream(
+        context.Background(), ChatCompletionRequest{
+            MaxTokens: 5,
+            Model:     GPT3Dot5Turbo,
+            Messages: []ChatCompletionMessage{
+                {
+                    Role:    ChatMessageRoleUser,
+                    Content: "Hello!",
+                },
+            },
+            Stream: true,
+        },
+    )
     checks.NoError(t, err, "CreateCompletionStream returned error")
     defer stream.Close()
 
@@ -128,43 +134,82 @@ func TestCreateChatCompletionStream(t *testing.T) {
     }
 }
 
+func TestChatCompletionsStreamWithFunctionCall(t *testing.T) {
+    config := DefaultConfig("whatever")
+    config.BaseURL = "http://localhost/v1"
+    client := NewClientWithConfig(config)
+    ctx := context.Background()
+
+    cases := []struct {
+        FunctionCall any
+        Pass         bool
+    }{
+        {"none", true},
+        {"auto", true},
+        {map[string]string{"name": "test"}, true},
+        {nil, true},
+        {"invalid", false},
+        {map[string]string{}, false},
+    }
+    for _, c := range cases {
+        req := ChatCompletionRequest{
+            FunctionCall: c.FunctionCall,
+        }
+        _, err := client.CreateChatCompletionStream(ctx, req)
+        if c.Pass {
+            checks.ErrorIsNot(t, err, ErrChatCompletionInvalidFunctionCall, "unexpected error")
+        } else {
+            checks.ErrorIs(
+                t,
+                err,
+                ErrChatCompletionInvalidFunctionCall,
+                fmt.Sprintf("should not pass for function call: %v", c.FunctionCall),
+            )
+        }
+    }
+}
+
 func TestCreateChatCompletionStreamError(t *testing.T) {
     client, server, teardown := setupOpenAITestServer()
     defer teardown()
-    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
-        w.Header().Set("Content-Type", "text/event-stream")
+    server.RegisterHandler(
+        "/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+            w.Header().Set("Content-Type", "text/event-stream")
 
-        // Send test responses
-        dataBytes := []byte{}
-        dataStr := []string{
-            `{`,
-            `"error": {`,
-            `"message": "Incorrect API key provided: sk-***************************************",`,
-            `"type": "invalid_request_error",`,
-            `"param": null,`,
-            `"code": "invalid_api_key"`,
-            `}`,
-            `}`,
-        }
-        for _, str := range dataStr {
-            dataBytes = append(dataBytes, []byte(str+"\n")...)
-        }
+            // Send test responses
+            dataBytes := []byte{}
+            dataStr := []string{
+                `{`,
+                `"error": {`,
+                `"message": "Incorrect API key provided: sk-***************************************",`,
+                `"type": "invalid_request_error",`,
+                `"param": null,`,
+                `"code": "invalid_api_key"`,
+                `}`,
+                `}`,
+            }
+            for _, str := range dataStr {
+                dataBytes = append(dataBytes, []byte(str+"\n")...)
+            }
 
-        _, err := w.Write(dataBytes)
-        checks.NoError(t, err, "Write error")
-    })
+            _, err := w.Write(dataBytes)
+            checks.NoError(t, err, "Write error")
+        },
+    )
 
-    stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
-        MaxTokens: 5,
-        Model:     GPT3Dot5Turbo,
-        Messages: []ChatCompletionMessage{
-            {
-                Role:    ChatMessageRoleUser,
-                Content: "Hello!",
-            },
-        },
-        Stream: true,
-    })
+    stream, err := client.CreateChatCompletionStream(
+        context.Background(), ChatCompletionRequest{
+            MaxTokens: 5,
+            Model:     GPT3Dot5Turbo,
+            Messages: []ChatCompletionMessage{
+                {
+                    Role:    ChatMessageRoleUser,
+                    Content: "Hello!",
+                },
+            },
+            Stream: true,
+        },
+    )
     checks.NoError(t, err, "CreateCompletionStream returned error")
     defer stream.Close()

@@ -181,31 +226,35 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
 func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
     client, server, teardown := setupOpenAITestServer()
     defer teardown()
-    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
-        w.Header().Set("Content-Type", "application/json")
-        w.WriteHeader(429)
+    server.RegisterHandler(
+        "/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+            w.Header().Set("Content-Type", "application/json")
+            w.WriteHeader(429)
 
-        // Send test responses
-        dataBytes := []byte(`{"error":{` +
-            `"message": "You are sending requests too quickly.",` +
-            `"type":"rate_limit_reached",` +
-            `"param":null,` +
-            `"code":"rate_limit_reached"}}`)
+            // Send test responses
+            dataBytes := []byte(`{"error":{` +
+                `"message": "You are sending requests too quickly.",` +
+                `"type":"rate_limit_reached",` +
+                `"param":null,` +
+                `"code":"rate_limit_reached"}}`)
 
-        _, err := w.Write(dataBytes)
-        checks.NoError(t, err, "Write error")
-    })
-    _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
-        MaxTokens: 5,
-        Model:     GPT3Dot5Turbo,
-        Messages: []ChatCompletionMessage{
-            {
-                Role:    ChatMessageRoleUser,
-                Content: "Hello!",
-            },
-        },
-        Stream: true,
-    })
+            _, err := w.Write(dataBytes)
+            checks.NoError(t, err, "Write error")
+        },
+    )
+    _, err := client.CreateChatCompletionStream(
+        context.Background(), ChatCompletionRequest{
+            MaxTokens: 5,
+            Model:     GPT3Dot5Turbo,
+            Messages: []ChatCompletionMessage{
+                {
+                    Role:    ChatMessageRoleUser,
+                    Content: "Hello!",
+                },
+            },
+            Stream: true,
+        },
+    )
     var apiErr *APIError
     if !errors.As(err, &apiErr) {
         t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
@@ -222,7 +271,8 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
 
     client, server, teardown := setupAzureTestServer()
     defer teardown()
-    server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions",
+    server.RegisterHandler(
+        "/openai/deployments/gpt-35-turbo/chat/completions",
         func(w http.ResponseWriter, r *http.Request) {
             w.Header().Set("Content-Type", "application/json")
             w.WriteHeader(http.StatusTooManyRequests)
@@ -231,26 +281,33 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
             _, err := w.Write(dataBytes)
 
             checks.NoError(t, err, "Write error")
-        })
+        },
+    )
 
     apiErr := &APIError{}
-    _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
-        MaxTokens: 5,
-        Model:     GPT3Dot5Turbo,
-        Messages: []ChatCompletionMessage{
-            {
-                Role:    ChatMessageRoleUser,
-                Content: "Hello!",
-            },
-        },
-        Stream: true,
-    })
+    _, err := client.CreateChatCompletionStream(
+        context.Background(), ChatCompletionRequest{
+            MaxTokens: 5,
+            Model:     GPT3Dot5Turbo,
+            Messages: []ChatCompletionMessage{
+                {
+                    Role:    ChatMessageRoleUser,
+                    Content: "Hello!",
+                },
+            },
+            Stream: true,
+        },
+    )
     if !errors.As(err, &apiErr) {
         t.Errorf("Did not return APIError: %+v\n", apiErr)
         return
     }
     if apiErr.HTTPStatusCode != http.StatusTooManyRequests {
-        t.Errorf("Did not return HTTPStatusCode got = %d, want = %d\n", apiErr.HTTPStatusCode, http.StatusTooManyRequests)
+        t.Errorf(
+            "Did not return HTTPStatusCode got = %d, want = %d\n",
+            apiErr.HTTPStatusCode,
+            http.StatusTooManyRequests,
+        )
         return
     }
     code, ok := apiErr.Code.(string)