Skip to content

Commit

Permalink
feat: doc2x parse pdf support (labring#5441)
Browse files Browse the repository at this point in the history
* feat: doc2x parse pdf support

* feat: pdf build with string builder

* feat: conv pdf html to md

* feat: relay retry request at least once for an authorized channel

* fix: lint

* feat: parse pdf model from query

* test: add table to md test cases
  • Loading branch information
zijiren233 authored Mar 6, 2025
1 parent 1285f89 commit 8c67cc6
Show file tree
Hide file tree
Showing 16 changed files with 806 additions and 109 deletions.
238 changes: 189 additions & 49 deletions service/aiproxy/controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"math/rand/v2"
"net/http"
"slices"
"time"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -37,6 +38,8 @@ func relayController(mode int) (RelayController, bool) {
case relaymode.AudioTranslation,
relaymode.AudioTranscription:
relayController = controller.RelaySTTHelper
case relaymode.ParsePdf:
relayController = controller.RelayParsePdfHelper
case relaymode.Rerank:
relayController = controller.RerankHelper
case relaymode.ChatCompletions,
Expand Down Expand Up @@ -80,15 +83,73 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
return err, shouldRetry(c, err.StatusCode)
}

// filterChannels returns the channels that are usable for a request: those
// whose Status is dbmodel.ChannelStatusEnabled and whose ID does not appear
// in ignoreChannel. Input order is preserved and the input slice is not
// modified. The result is always non-nil, possibly empty.
func filterChannels(channels []*dbmodel.Channel, ignoreChannel ...int) []*dbmodel.Channel {
	// At most len(channels) entries can survive, so allocate once up front
	// instead of letting append regrow the slice.
	filtered := make([]*dbmodel.Channel, 0, len(channels))
	for _, channel := range channels {
		if channel.Status != dbmodel.ChannelStatusEnabled ||
			slices.Contains(ignoreChannel, channel.ID) {
			continue
		}
		filtered = append(filtered, channel)
	}
	return filtered
}

// Sentinel errors reported by channel selection.
var (
// ErrChannelsNotFound is returned when the candidate channel list is empty.
ErrChannelsNotFound = errors.New("channels not found")
// ErrChannelsExhausted is returned when candidates exist but none is left
// after dropping disabled and ignored channels.
ErrChannelsExhausted = errors.New("channels exhausted")
)

// GetRandomChannel selects a channel for model from the cache's
// EnabledModel2channels table, skipping any channel whose ID is listed in
// ignoreChannel. It returns ErrChannelsNotFound when the model has no
// channels and ErrChannelsExhausted when every channel was filtered out.
func GetRandomChannel(c *dbmodel.ModelCaches, model string, ignoreChannel ...int) (*dbmodel.Channel, error) {
	candidates := c.EnabledModel2channels[model]
	return getRandomChannel(candidates, ignoreChannel...)
}

// getRandomChannel picks one channel from channels, skipping disabled
// channels and those whose ID is in ignoreChannel. Selection is random and
// weighted by each channel's GetPriority value.
//
// It returns ErrChannelsNotFound when channels is empty and
// ErrChannelsExhausted when filtering leaves no candidate.
//
//nolint:gosec
func getRandomChannel(channels []*dbmodel.Channel, ignoreChannel ...int) (*dbmodel.Channel, error) {
	if len(channels) == 0 {
		return nil, ErrChannelsNotFound
	}

	channels = filterChannels(channels, ignoreChannel...)
	switch len(channels) {
	case 0:
		return nil, ErrChannelsExhausted
	case 1:
		return channels[0], nil
	}

	// Sum in int64 so that many large priorities cannot wrap around.
	var totalWeight int64
	for _, ch := range channels {
		totalWeight += int64(ch.GetPriority())
	}

	// rand.Int64N panics for n <= 0, so any non-positive total (all zero,
	// or negative priorities) falls back to a uniform pick.
	if totalWeight <= 0 {
		return channels[rand.IntN(len(channels))], nil
	}

	r := rand.Int64N(totalWeight)
	for _, ch := range channels {
		r -= int64(ch.GetPriority())
		if r < 0 {
			return ch, nil
		}
	}

	// Only reachable when priorities are inconsistent (e.g. some negative);
	// still return a usable channel rather than failing the request.
	return channels[rand.IntN(len(channels))], nil
}

// getChannelWithFallback picks a random channel for model while skipping
// ignoreChannelIDs. If that fails only because every channel was excluded
// (the channels-exhausted error), it picks again without the ignore list.
// Any other error is returned unchanged.
func getChannelWithFallback(cache *dbmodel.ModelCaches, model string, ignoreChannelIDs ...int) (*dbmodel.Channel, error) {
channel, err := cache.GetRandomSatisfiedChannel(model, ignoreChannelIDs...)
channel, err := GetRandomChannel(cache, model, ignoreChannelIDs...)
if err == nil {
return channel, nil
}
if !errors.Is(err, dbmodel.ErrChannelsExhausted) {
if !errors.Is(err, ErrChannelsExhausted) {
return nil, err
}
return cache.GetRandomSatisfiedChannel(model)
return GetRandomChannel(cache, model)
}

func NewRelay(mode int) func(c *gin.Context) {
Expand All @@ -103,14 +164,50 @@ func NewRelay(mode int) func(c *gin.Context) {

// relay serves one request for the given relay mode. It selects an initial
// channel, performs the first attempt through relayController and, when that
// attempt fails with a retryable error, hands over to retryLoop.
//
// When no initial channel can be selected the client receives HTTP 503 with
// the "upstream_load_saturated" error code and the cause is logged.
func relay(c *gin.Context, mode int, relayController RelayController) {
	log := middleware.GetLogger(c)
	requestModel := middleware.GetOriginalModel(c)

	// Get initial channel
	channel, ignoreChannelIDs, err := getInitialChannel(c, requestModel, log)
	if err != nil || channel == nil {
		// The client only sees a generic 503, so keep the real cause in the log.
		if err != nil {
			log.Errorf("get %s initial channel failed: %+v", requestModel, err)
		} else {
			log.Errorf("get %s initial channel returned no channel", requestModel)
		}
		c.JSON(http.StatusServiceUnavailable, gin.H{
			"error": &model.Error{
				Message: "the upstream load is saturated, please try again later",
				Code:    "upstream_load_saturated",
				Type:    middleware.ErrorTypeAIPROXY,
			},
		})
		return
	}

	// First attempt. The local is not called "meta" so that it does not
	// shadow the imported meta package.
	relayMeta := middleware.NewMetaByContext(c, channel, requestModel, mode)
	bizErr, retry := RelayHelper(relayMeta, c, relayController)
	if handleRelayResult(c, bizErr, retry) {
		return
	}

	// Setup retry state
	state := initRetryState(channel, bizErr, ignoreChannelIDs)

	// Retry loop
	retryLoop(c, mode, requestModel, state, relayController, log)
}

// retryState carries the mutable bookkeeping shared by the retry helpers
// across attempts of a single request.
type retryState struct {
// retryTimes is the retry budget; it starts from config.GetRetryTimes and
// is incremented when a retried channel is added to the ignore list.
retryTimes int64
// lastCanContinueChannel is the most recent channel whose failure status
// passed channelCanContinue; it is reused once other channels run out.
lastCanContinueChannel *dbmodel.Channel
// ignoreChannelIDs lists channel IDs excluded from random selection.
ignoreChannelIDs []int
// exhausted is set once selection reported ErrChannelsExhausted and the
// loop switched to reusing lastCanContinueChannel.
exhausted bool
// bizErr is the error of the latest attempt (nil after a success).
bizErr *model.ErrorWithStatusCode
}

func getInitialChannel(c *gin.Context, requestModel string, log *log.Entry) (*dbmodel.Channel, []int, error) {
ids, err := monitor.GetBannedChannels(c.Request.Context(), requestModel)
if err != nil {
log.Errorf("get %s auto banned channels failed: %+v", requestModel, err)
}
log.Debugf("%s model banned channels: %+v", requestModel, ids)

ignoreChannelIDs := make([]int, 0, len(ids))
for _, id := range ids {
ignoreChannelIDs = append(ignoreChannelIDs, int(id))
Expand All @@ -119,85 +216,128 @@ func relay(c *gin.Context, mode int, relayController RelayController) {
mc := middleware.GetModelCaches(c)
channel, err := getChannelWithFallback(mc, requestModel, ignoreChannelIDs...)
if err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": &model.Error{
Message: "The upstream load is saturated, please try again later",
Code: "upstream_load_saturated",
Type: middleware.ErrorTypeAIPROXY,
},
})
return
return nil, nil, err
}

meta := middleware.NewMetaByContext(c, channel, requestModel, mode)
bizErr, retry := RelayHelper(meta, c, relayController)
return channel, ignoreChannelIDs, nil
}

// handleRelayResult reports whether the request is finished after an attempt.
// It returns true when bizErr is nil, or when retry is false; in the latter
// case the error message is passed through middleware.MessageWithRequestID
// and the error is written to the response with its status code.
// It returns false when the caller should go on to retry.
func handleRelayResult(c *gin.Context, bizErr *model.ErrorWithStatusCode, retry bool) bool {
if bizErr == nil {
return
return true
}
if !retry {
bizErr.Error.Message = middleware.MessageWithRequestID(c, bizErr.Error.Message)
c.JSON(bizErr.StatusCode, bizErr)
return
return true
}
return false
}

var lastCanContinueChannel *dbmodel.Channel
// initRetryState builds the retry bookkeeping after a failed first attempt.
// The retry budget comes from config.GetRetryTimes and bizErr is stored as
// the latest error. Depending on channelCanContinue(bizErr.StatusCode), the
// first channel is either appended to the ignore list or remembered as
// lastCanContinueChannel.
func initRetryState(channel *dbmodel.Channel, bizErr *model.ErrorWithStatusCode, ignoreChannelIDs []int) *retryState {
state := &retryState{
retryTimes: config.GetRetryTimes(),
ignoreChannelIDs: ignoreChannelIDs,
bizErr: bizErr,
}

retryTimes := config.GetRetryTimes()
if !channelCanContinue(bizErr.StatusCode) {
ignoreChannelIDs = append(ignoreChannelIDs, channel.ID)
state.ignoreChannelIDs = append(state.ignoreChannelIDs, channel.ID)
} else {
lastCanContinueChannel = channel
state.lastCanContinueChannel = channel
}

for i := retryTimes; i > 0; i-- {
newChannel, err := mc.GetRandomSatisfiedChannel(requestModel, ignoreChannelIDs...)
return state
}

// retryLoop re-sends the request while i is below state.retryTimes, which is
// re-read on every iteration and may grow during the loop. Each iteration
// obtains a channel from getRetryChannel, restores the request body through
// prepareRetry, relays the request, and lets handleRetryResult decide whether
// to stop. The loop also ends when no channel can be obtained or when
// prepareRetry fails. After the loop, an error still held in state.bizErr is
// written to the response.
func retryLoop(c *gin.Context, mode int, requestModel string, state *retryState, relayController RelayController, log *log.Entry) {
mc := middleware.GetModelCaches(c)

for i := 0; i < int(state.retryTimes); i++ {
newChannel, err := getRetryChannel(mc, requestModel, state)
if err != nil {
if !errors.Is(err, dbmodel.ErrChannelsExhausted) ||
lastCanContinueChannel == nil {
break
}
// use last can continue channel to retry
newChannel = lastCanContinueChannel
break
}

log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
newChannel.Name,
newChannel.Type,
newChannel.ID,
i-1,
state.retryTimes-int64(i),
)

requestBody, err := common.GetRequestBody(c.Request)
if err != nil {
log.Errorf("GetRequestBody failed: %+v", err)
if !prepareRetry(c, state.bizErr.StatusCode) {
break
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

if shouldDelay(bizErr.StatusCode) {
//nolint:gosec
// random wait 1-2 seconds
time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
meta := middleware.NewMetaByContext(c, newChannel, requestModel, mode)
bizErr, retry := RelayHelper(meta, c, relayController)

done := handleRetryResult(bizErr, retry, newChannel, state)
if done {
break
}
}

meta := middleware.NewMetaByContext(c, newChannel, requestModel, mode)
bizErr, retry = RelayHelper(meta, c, relayController)
if bizErr == nil {
return
if state.bizErr != nil {
state.bizErr.Error.Message = middleware.MessageWithRequestID(c, state.bizErr.Error.Message)
c.JSON(state.bizErr.StatusCode, state.bizErr)
}
}

// getRetryChannel returns the channel to use for the next retry. Once
// state.exhausted is set it always returns state.lastCanContinueChannel.
// Otherwise it picks a random channel excluding state.ignoreChannelIDs; if
// that fails with ErrChannelsExhausted and a lastCanContinueChannel exists,
// it marks the state exhausted and reuses that channel. Any other failure is
// returned to the caller.
func getRetryChannel(mc *dbmodel.ModelCaches, requestModel string, state *retryState) (*dbmodel.Channel, error) {
if state.exhausted {
return state.lastCanContinueChannel, nil
}

newChannel, err := GetRandomChannel(mc, requestModel, state.ignoreChannelIDs...)
if err != nil {
if !errors.Is(err, ErrChannelsExhausted) || state.lastCanContinueChannel == nil {
return nil, err
}
if !retry {
break
state.exhausted = true
return state.lastCanContinueChannel, nil
}

return newChannel, nil
}

// prepareRetry gets the request ready to be sent again. It re-reads the body
// via common.GetRequestBody and installs a fresh reader over it on c.Request.
// When shouldDelay(statusCode) holds it then sleeps for a random 1-2 seconds.
// It returns false, after logging, if the body cannot be obtained.
func prepareRetry(c *gin.Context, statusCode int) bool {
	body, err := common.GetRequestBody(c.Request)
	if err != nil {
		log.Errorf("get request body failed in prepare retry: %+v", err)
		return false
	}
	c.Request.Body = io.NopCloser(bytes.NewBuffer(body))

	if !shouldDelay(statusCode) {
		return true
	}

	// One fixed second plus up to one more second of random jitter.
	//nolint:gosec
	jitter := time.Duration(rand.Float64() * float64(time.Second))
	time.Sleep(time.Second + jitter)
	return true
}

// handleRetryResult stores bizErr in state as the latest error and reports
// whether the retry loop should stop. It returns true on success (bizErr nil)
// or when retry is false. Otherwise, once state.exhausted is set, it stops
// only if channelCanContinue rejects the status code. Before exhaustion, a
// channel that cannot continue is added to the ignore list and retryTimes is
// incremented; a channel that can continue is remembered as
// lastCanContinueChannel.
func handleRetryResult(bizErr *model.ErrorWithStatusCode, retry bool, newChannel *dbmodel.Channel, state *retryState) (done bool) {
state.bizErr = bizErr
if bizErr == nil || !retry {
return true
}

if state.exhausted {
if !channelCanContinue(bizErr.StatusCode) {
return true
}
} else {
if !channelCanContinue(bizErr.StatusCode) {
ignoreChannelIDs = append(ignoreChannelIDs, newChannel.ID)
state.ignoreChannelIDs = append(state.ignoreChannelIDs, newChannel.ID)
state.retryTimes++
} else {
lastCanContinueChannel = newChannel
state.lastCanContinueChannel = newChannel
}
}

if bizErr != nil {
bizErr.Error.Message = middleware.MessageWithRequestID(c, bizErr.Error.Message)
c.JSON(bizErr.StatusCode, bizErr)
}
return false
}

var shouldRetryStatusCodesMap = map[int]struct{}{
Expand Down
19 changes: 15 additions & 4 deletions service/aiproxy/middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/labring/sealos/service/aiproxy/common/rpmlimit"
"github.com/labring/sealos/service/aiproxy/model"
"github.com/labring/sealos/service/aiproxy/relay/meta"
"github.com/labring/sealos/service/aiproxy/relay/relaymode"
log "github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -178,7 +179,7 @@ func distribute(c *gin.Context, mode int) {
return
}

requestModel, err := getRequestModel(c)
requestModel, err := getRequestModel(c, mode)
if err != nil {
abortLogWithMessage(c, http.StatusBadRequest, err.Error())
return
Expand Down Expand Up @@ -261,15 +262,25 @@ type ModelRequest struct {
Model string `form:"model" json:"model"`
}

func getRequestModel(c *gin.Context) (string, error) {
func getRequestModel(c *gin.Context, mode int) (string, error) {
path := c.Request.URL.Path
switch {
case strings.HasPrefix(path, "/v1/audio/transcriptions"),
strings.HasPrefix(path, "/v1/audio/translations"):
case mode == relaymode.ParsePdf:
query := c.Request.URL.Query()
model := query.Get("model")
if model != "" {
return model, nil
}

fallthrough
case mode == relaymode.AudioTranscription,
mode == relaymode.AudioTranslation:
return c.Request.FormValue("model"), nil

case strings.HasPrefix(path, "/v1/engines") && strings.HasSuffix(path, "/embeddings"):
// /engines/:model/embeddings
return c.Param("model"), nil

default:
body, err := common.GetRequestBody(c.Request)
if err != nil {
Expand Down
Loading

0 comments on commit 8c67cc6

Please sign in to comment.