Merge pull request #342 from grafana/workaround-temperature-0
sd2k authored May 10, 2024
2 parents bb64d6c + 291497b commit edfb330
Showing 2 changed files with 72 additions and 1 deletion.
packages/grafana-llm-app/pkg/plugin/llm_provider.go (27 additions, 0 deletions)
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"math"
 	"math/rand"
 	"strings"
 
@@ -67,6 +68,32 @@ type ChatCompletionRequest struct {
 	Model Model `json:"model"`
 }
 
+// UnmarshalJSON implements json.Unmarshaler.
+// We have a custom implementation here to check whether temperature is being
+// explicitly set to `0` in the incoming request, because the `openai.ChatCompletionRequest`
+// struct has `omitempty` on the Temperature field and would omit it when marshaling.
+// If there is an explicit 0 value in the request, we set it to `math.SmallestNonzeroFloat32`,
+// a workaround mentioned in https://github.com/sashabaranov/go-openai/issues/9#issuecomment-894845206.
+func (c *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
+	// Create a wrapper type alias to avoid recursion, otherwise the
+	// subsequent call to UnmarshalJSON would call this method forever.
+	type Alias ChatCompletionRequest
+	var a Alias
+	if err := json.Unmarshal(data, &a); err != nil {
+		return err
+	}
+	// Also unmarshal to a map to check if temperature is being set explicitly in the request.
+	r := map[string]any{}
+	if err := json.Unmarshal(data, &r); err != nil {
+		return err
+	}
+	if t, ok := r["temperature"].(float64); ok && t == 0 {
+		a.ChatCompletionRequest.Temperature = math.SmallestNonzeroFloat32
+	}
+	*c = ChatCompletionRequest(a)
+	return nil
+}
+
 type ChatCompletionStreamResponse struct {
 	openai.ChatCompletionStreamResponse
 	// Random padding used to mitigate side channel attacks.
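For illustration, here is a self-contained sketch of the two ideas this change relies on: `omitempty` silently dropping an explicitly-zero temperature when the request is re-marshaled, and the type-alias trick that lets a custom UnmarshalJSON reuse the default decoder without recursing. The `request` type below is a hypothetical stand-in, not the plugin's actual struct:

package main

import (
	"encoding/json"
	"fmt"
	"math"
)

// Hypothetical stand-in for openai.ChatCompletionRequest: Temperature
// carries `omitempty`, so an explicit 0 disappears when marshaled.
type request struct {
	Model       string  `json:"model"`
	Temperature float32 `json:"temperature,omitempty"`
}

func (r *request) UnmarshalJSON(data []byte) error {
	// The alias has the same fields but none of the methods, so the
	// inner json.Unmarshal uses the default decoder instead of
	// recursing back into this method.
	type alias request
	var a alias
	if err := json.Unmarshal(data, &a); err != nil {
		return err
	}
	// A second pass into a map distinguishes "temperature":0 from a
	// missing temperature key.
	m := map[string]any{}
	if err := json.Unmarshal(data, &m); err != nil {
		return err
	}
	if t, ok := m["temperature"].(float64); ok && t == 0 {
		a.Temperature = math.SmallestNonzeroFloat32
	}
	*r = request(a)
	return nil
}

func main() {
	var r request
	if err := json.Unmarshal([]byte(`{"model":"base","temperature":0}`), &r); err != nil {
		panic(err)
	}
	out, _ := json.Marshal(r)
	// Without the workaround this would print {"model":"base"}; with it,
	// temperature survives re-marshaling as a tiny non-zero value.
	fmt.Println(string(out))
}

The float64 assertion works because encoding/json decodes every JSON number into float64 when the destination is map[string]any.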
packages/grafana-llm-app/pkg/plugin/llm_provider_test.go (45 additions, 1 deletion)
@@ -2,9 +2,11 @@ package plugin
 
 import (
 	"encoding/json"
+	"math"
 	"testing"
 
 	"github.com/sashabaranov/go-openai"
+	"github.com/stretchr/testify/assert"
 )
 
 func TestModelFromString(t *testing.T) {
@@ -83,7 +85,7 @@ func TestModelFromString(t *testing.T) {
 	}
 }
 
-func TestUnmarshalJSON(t *testing.T) {
+func TestModelUnmarshalJSON(t *testing.T) {
 	tests := []struct {
 		input    []byte
 		expected Model
@@ -164,6 +166,48 @@ func TestUnmarshalJSON(t *testing.T) {
 	}
 }
 
+func TestChatCompletionRequestUnmarshalJSON(t *testing.T) {
+	for _, tt := range []struct {
+		input    []byte
+		expected ChatCompletionRequest
+	}{
+		{
+			input: []byte(`{"model":"base"}`),
+			expected: ChatCompletionRequest{
+				Model: ModelBase,
+				ChatCompletionRequest: openai.ChatCompletionRequest{
+					Temperature: 0,
+				},
+			},
+		},
+		{
+			input: []byte(`{"model":"base", "temperature":0.5}`),
+			expected: ChatCompletionRequest{
+				Model: ModelBase,
+				ChatCompletionRequest: openai.ChatCompletionRequest{
+					Temperature: 0.5,
+				},
+			},
+		},
+		{
+			input: []byte(`{"model":"base", "temperature":0}`),
+			expected: ChatCompletionRequest{
+				Model: ModelBase,
+				ChatCompletionRequest: openai.ChatCompletionRequest{
+					Temperature: math.SmallestNonzeroFloat32,
+				},
+			},
+		},
+	} {
+		t.Run(string(tt.input), func(t *testing.T) {
+			var req ChatCompletionRequest
+			err := json.Unmarshal(tt.input, &req)
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, req)
+		})
+	}
+}
+
 func TestChatCompletionStreamResponseMarshalJSON(t *testing.T) {
 	resp := ChatCompletionStreamResponse{
 		ChatCompletionStreamResponse: openai.ChatCompletionStreamResponse{
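To exercise the new test locally, something like the following should work (assuming packages/grafana-llm-app is the Go module root, which the pkg/ layout above suggests):

go test ./pkg/plugin/ -run TestChatCompletionRequestUnmarshalJSON -v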
