Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: function_call can be a string or an object #374

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ const (
const chatCompletionsSuffix = "/chat/completions"

var (
ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
ErrChatCompletionInvalidFunctionCall = errors.New(`FunctionCall parameter only supports "none", "auto", or a map[string]string`) //nolint:lll
)

type ChatCompletionMessage struct {
Expand Down Expand Up @@ -55,7 +56,20 @@ type ChatCompletionRequest struct {
LogitBias map[string]int `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
Functions []*FunctionDefine `json:"functions,omitempty"`
FunctionCall string `json:"function_call,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
}

func checkFunctionCall(value any) bool {
if value == nil {
return true
}
if v, isString := value.(string); isString {
return v == "none" || v == "auto"
}
if v, isMap := value.(map[string]string); isMap {
return v["name"] != ""
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding a structure as a usable type?
For example:

type Foo struct {
	Name string `json:"name"`
}

return false
}

type FunctionDefine struct {
Expand Down Expand Up @@ -146,6 +160,11 @@ func (c *Client) CreateChatCompletion(
return
}

if !checkFunctionCall(request.FunctionCall) {
err = ErrChatCompletionInvalidFunctionCall
return
}

req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
if err != nil {
return
Expand Down
10 changes: 8 additions & 2 deletions chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import (
)

type ChatCompletionStreamChoiceDelta struct {
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
}

type ChatCompletionStreamChoice struct {
Expand Down Expand Up @@ -46,6 +47,11 @@ func (c *Client) CreateChatCompletionStream(
return
}

if !checkFunctionCall(request.FunctionCall) {
err = ErrChatCompletionInvalidFunctionCall
return
}

request.Stream = true
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
if err != nil {
Expand Down
37 changes: 37 additions & 0 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package openai_test

import (
"fmt"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -128,6 +130,41 @@ func TestCreateChatCompletionStream(t *testing.T) {
}
}

func TestChatCompletionsStreamWithFunctionCall(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

cases := []struct {
FunctionCall any
Pass bool
}{
{"none", true},
{"auto", true},
{map[string]string{"name": "test"}, true},
{nil, true},
{"invalid", false},
{map[string]string{}, false},
}
for _, c := range cases {
req := ChatCompletionRequest{
FunctionCall: c.FunctionCall,
}
_, err := client.CreateChatCompletionStream(ctx, req)
if c.Pass {
checks.ErrorIsNot(t, err, ErrChatCompletionInvalidFunctionCall, "unexpected error")
} else {
checks.ErrorIs(
t,
err,
ErrChatCompletionInvalidFunctionCall,
fmt.Sprintf("should not pass for function call: %v", c.FunctionCall),
)
}
}
}

func TestCreateChatCompletionStreamError(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
Expand Down
35 changes: 35 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,41 @@ func TestChatCompletionsWithStream(t *testing.T) {
checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error")
}

func TestChatCompletionsWithFunctionCall(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

cases := []struct {
FunctionCall any
Pass bool
}{
{"none", true},
{"auto", true},
{map[string]string{"name": "test"}, true},
{nil, true},
{"invalid", false},
{map[string]string{}, false},
}
for _, c := range cases {
req := ChatCompletionRequest{
FunctionCall: c.FunctionCall,
}
_, err := client.CreateChatCompletion(ctx, req)
if c.Pass {
checks.ErrorIsNot(t, err, ErrChatCompletionInvalidFunctionCall, "unexpected error")
} else {
checks.ErrorIs(
t,
err,
ErrChatCompletionInvalidFunctionCall,
fmt.Sprintf("should not pass for function call: %v", c.FunctionCall),
)
}
}
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
Expand Down