Skip to content

Commit

Permalink
refactor, Introduce model method for model ID without provider prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
zimmski committed Feb 3, 2025
1 parent 3ec68e7 commit 7554962
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 17 deletions.
12 changes: 11 additions & 1 deletion model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,21 @@ func (m *Model) ID() (id string) {
return m.id
}

// ModelID returns the unique identifier of this model.
// ModelID returns the unique identifier of this model with its provider.
func (m *Model) ModelID() (modelID string) {
return m.modelID
}

// ModelIDWithoutProvider returns the unique identifier of this model without its provider.
func (m *Model) ModelIDWithoutProvider() (modelID string) {
_, modelID, ok := strings.Cut(m.modelID, provider.ProviderModelSeparator)
if !ok {
panic(m.modelID)
}

return modelID
}

// Attributes returns query attributes.
func (m *Model) Attributes() (attributes map[string]string) {
return m.attributes
Expand Down
4 changes: 3 additions & 1 deletion model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
type Model interface {
// ID returns full identifier, including the provider and attributes.
ID() (id string)
// ModelID returns the unique identifier of this model.
// ModelID returns the unique identifier of this model with its provider.
ModelID() (modelID string)
// ModelIDWithoutProvider returns the unique identifier of this model without its provider.
ModelIDWithoutProvider() (modelID string)

// Attributes returns query attributes.
Attributes() (attributes map[string]string)
Expand Down
7 changes: 6 additions & 1 deletion model/symflower/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ func (m *Model) ID() (id string) {
return "symflower" + provider.ProviderModelSeparator + m.id
}

// ModelID returns the unique identifier of this model.
// ModelID returns the unique identifier of this model with its provider.
func (m *Model) ModelID() (modelID string) {
return "symflower" + provider.ProviderModelSeparator + m.id
}

// ModelIDWithoutProvider returns the unique identifier of this model without its provider.
func (m *Model) ModelIDWithoutProvider() (modelID string) {
return m.id
}

// Attributes returns query attributes.
func (m *Model) Attributes() (attributes map[string]string) {
return nil
Expand Down
18 changes: 18 additions & 0 deletions model/testing/Model_mock_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 1 addition & 4 deletions provider/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@ var _ provider.Query = (*Provider)(nil)

// Query queries the provider with the given model name.
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
client := p.client()
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)

return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
return openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
}

// client returns a new client with the current configuration.
Expand Down
6 changes: 1 addition & 5 deletions provider/openai-api/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package openaiapi

import (
"context"
"strings"

"github.com/sashabaranov/go-openai"

Expand Down Expand Up @@ -61,10 +60,7 @@ var _ provider.Query = (*Provider)(nil)

// Query queries the provider with the given model name.
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
client := p.client()
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)

return QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
return QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
}

// client returns a new client with the current configuration.
Expand Down
6 changes: 1 addition & 5 deletions provider/openrouter/openrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/avast/retry-go"
Expand Down Expand Up @@ -139,10 +138,7 @@ var _ provider.Query = (*Provider)(nil)

// Query queries the provider with the given model name.
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
client := p.client()
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)

return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
return openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
}

// client returns a new client with the current configuration.
Expand Down

0 comments on commit 7554962

Please sign in to comment.