Merge remote-tracking branch 'upstream/main'
bentwnghk committed Jan 11, 2025
2 parents 40ee3b4 + c106893 commit 427812e
Showing 2 changed files with 2 additions and 138 deletions.
124 changes: 2 additions & 122 deletions providers/gemini/chat.go
@@ -1,13 +1,10 @@
package gemini

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "one-api/common"
    "one-api/common/config"
    "one-api/common/requester"
    "one-api/common/utils"
    "one-api/providers/base"
@@ -35,42 +32,7 @@ type OpenAIStreamHandler struct {

func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
    if p.UseOpenaiAPI {
        req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
        if errWithCode != nil {
            return nil, errWithCode
        }
        defer req.Body.Close()

        var openaiResponse types.ChatCompletionResponse
        response := &GeminiOpenaiChatResponse{}
        // Send the request
        _, errWithCode = p.Requester.SendRequest(req, response, false)
        if errWithCode != nil {
            return nil, errWithCode
        }

        openaiResponse = response.ChatCompletionResponse

        if response.Usage == nil || response.Usage.CompletionTokens == 0 {
            openaiResponse.Usage = &types.Usage{
                PromptTokens:     p.Usage.PromptTokens,
                CompletionTokens: 0,
                TotalTokens:      0,
            }
            // In that case the completion tokens have to be counted locally
            response.Usage.CompletionTokens = common.CountTokenText(response.GetContent(), request.Model)
            response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
        } else {
            openaiResponse.Usage = &types.Usage{
                PromptTokens:     response.Usage.PromptTokens,
                CompletionTokens: response.Usage.CompletionTokens,
                TotalTokens:      response.Usage.TotalTokens,
            }
        }

        *p.Usage = *openaiResponse.Usage

        return &openaiResponse, nil
        return p.OpenAIProvider.CreateChatCompletion(request)
    }

    geminiRequest, errWithCode := ConvertFromChatOpenai(request)
@@ -98,32 +60,7 @@ func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletio

    channel := p.GetChannel()
    if p.UseOpenaiAPI {
        streamOptions := request.StreamOptions
        request.StreamOptions = &types.StreamOptions{
            IncludeUsage: true,
        }

        req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
        if errWithCode != nil {
            return nil, errWithCode
        }
        defer req.Body.Close()

        // Restore the original configuration
        request.StreamOptions = streamOptions

        // Send the request
        resp, errWithCode := p.Requester.SendRequestRaw(req)
        if errWithCode != nil {
            return nil, errWithCode
        }

        chatHandler := OpenAIStreamHandler{
            Usage:     p.Usage,
            ModelName: request.Model,
        }

        return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
        return p.OpenAIProvider.CreateChatCompletionStream(request)
    }

    geminiRequest, errWithCode := ConvertFromChatOpenai(request)
@@ -493,60 +430,3 @@ func (p *GeminiProvider) pluginHandle(request *GeminiChatRequest) {
    })

}

func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
    // If rawLine does not have the "data:" prefix, return immediately
    if !strings.HasPrefix(string(*rawLine), "data:") {
        *rawLine = nil
        return
    }

    // Strip the prefix
    *rawLine = (*rawLine)[5:]
    *rawLine = bytes.TrimSpace(*rawLine)

    // If it equals [DONE], the stream is finished
    if string(*rawLine) == "[DONE]" {
        errChan <- io.EOF
        *rawLine = requester.StreamClosed
        return
    }

    var openaiResponse types.ChatCompletionStreamResponse
    var response GeminiOpenaiChatStreamResponse
    err := json.Unmarshal(*rawLine, &response)
    if err != nil {
        errChan <- common.ErrorToOpenAIError(err)
        return
    }

    openaiResponse = response.ChatCompletionStreamResponse

    if response.Usage != nil {
        if response.Usage.CompletionTokens > 0 {
            openaiResponse.Usage = &types.Usage{
                PromptTokens:     response.Usage.PromptTokens,
                CompletionTokens: response.Usage.CompletionTokens,
                TotalTokens:      response.Usage.TotalTokens,
            }

            *h.Usage = *openaiResponse.Usage
        }

        if len(response.Choices) == 0 {
            *rawLine = nil
            return
        }
    } else {
        if h.Usage.TotalTokens == 0 {
            h.Usage.TotalTokens = h.Usage.PromptTokens
        }
        countTokenText := common.CountTokenText(openaiResponse.GetResponseText(), h.ModelName)
        h.Usage.CompletionTokens += countTokenText
        h.Usage.TotalTokens += countTokenText
    }

    // Marshal the response back to a string
    responseBody, _ := json.Marshal(openaiResponse)
    dataChan <- string(responseBody)
}
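Aside: the two one-line replacements above work because GeminiProvider embeds the OpenAI provider, so the OpenAI-compatible path no longer needs its own request building and usage accounting here. A minimal, runnable sketch of that embed-and-delegate pattern — all type and field names below are simplified stand-ins, not the actual one-api definitions:

package main

import "fmt"

// Illustrative stand-ins for the one-api request/response types.
type ChatRequest struct{ Model string }
type ChatResponse struct{ Content string }

// OpenAIProvider owns the OpenAI-compatible chat path: building the
// request, sending it, and filling in usage.
type OpenAIProvider struct{}

func (p *OpenAIProvider) CreateChatCompletion(req *ChatRequest) (*ChatResponse, error) {
    // A real implementation would perform the HTTP round trip here.
    return &ChatResponse{Content: "openai-compatible result"}, nil
}

// GeminiProvider embeds OpenAIProvider; when the channel is configured
// for Gemini's OpenAI-compatible endpoint, it simply delegates — which
// is what lets the duplicated logic above be deleted.
type GeminiProvider struct {
    OpenAIProvider
    UseOpenaiAPI bool
}

func (p *GeminiProvider) CreateChatCompletion(req *ChatRequest) (*ChatResponse, error) {
    if p.UseOpenaiAPI {
        return p.OpenAIProvider.CreateChatCompletion(req)
    }
    // ... the native Gemini conversion path would go here ...
    return &ChatResponse{Content: "native gemini result"}, nil
}

func main() {
    p := &GeminiProvider{UseOpenaiAPI: true}
    resp, _ := p.CreateChatCompletion(&ChatRequest{Model: "gemini-pro"})
    fmt.Println(resp.Content)
}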
16 changes: 0 additions & 16 deletions providers/gemini/type.go
@@ -412,22 +412,6 @@ func (e *GeminiErrorWithStatusCode) ToOpenAiError() *types.OpenAIErrorWithStatus
    }
}

type GeminiOpenaiUsage struct {
    PromptTokens     int `json:"promptTokens"`
    CompletionTokens int `json:"completionTokens"`
    TotalTokens      int `json:"totalTokens"`
}

type GeminiOpenaiChatResponse struct {
    types.ChatCompletionResponse
    Usage *GeminiOpenaiUsage `json:"usage,omitempty"`
}

type GeminiOpenaiChatStreamResponse struct {
    types.ChatCompletionStreamResponse
    Usage *GeminiOpenaiUsage `json:"usage,omitempty"`
}

type GeminiErrors []*GeminiErrorResponse

func (e *GeminiErrors) Error() *GeminiErrorResponse {
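Aside: judging by their json tags, the deleted types existed to decode the camelCase usage keys (promptTokens, completionTokens) returned by Gemini's OpenAI-compatible endpoint, by shadowing the embedded snake_case Usage field. A runnable sketch of that shadowing technique, again with simplified stand-in types rather than the real one-api ones:

package main

import (
    "encoding/json"
    "fmt"
)

// Simplified stand-in for the shared OpenAI-style usage type.
type Usage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

type ChatCompletionResponse struct {
    ID    string `json:"id"`
    Usage *Usage `json:"usage,omitempty"`
}

// Usage variant with the camelCase keys seen in the deleted
// GeminiOpenaiUsage type.
type CamelUsage struct {
    PromptTokens     int `json:"promptTokens"`
    CompletionTokens int `json:"completionTokens"`
    TotalTokens      int `json:"totalTokens"`
}

// The outer Usage field sits at a shallower depth than the embedded
// one, so encoding/json lets it win: the camelCase payload decodes
// here, mirroring the deleted GeminiOpenaiChatResponse.
type GeminiChatResponse struct {
    ChatCompletionResponse
    Usage *CamelUsage `json:"usage,omitempty"`
}

func main() {
    raw := []byte(`{"id":"x","usage":{"promptTokens":3,"completionTokens":5,"totalTokens":8}}`)
    var r GeminiChatResponse
    if err := json.Unmarshal(raw, &r); err != nil {
        panic(err)
    }
    fmt.Println(r.Usage.TotalTokens) // prints 8
}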
