Skip to content

Commit

Permalink
Differentiate between ID (with provider and attributes) and just the …
Browse files Browse the repository at this point in the history
…model ID (that we need to query LLM models)

Part of #407
  • Loading branch information
zimmski committed Feb 3, 2025
1 parent 07f31ec commit 3ec68e7
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 13 deletions.
25 changes: 17 additions & 8 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ import (

// Model represents a LLM model accessed via a provider.
type Model struct {
// id holds the full identifier, including the provider and attributes.
id string
// provider is the client to query the LLM model.
provider provider.Query
// modelID holds the identifier for the LLM modelID.
// modelID holds the identifier for the LLM model.
modelID string

// attributes holds query attributes.
Expand All @@ -41,6 +43,7 @@ type Model struct {
// NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider.
func NewModel(provider provider.Query, modelIDWithAttributes string) (llmModel *Model) {
llmModel = &Model{
id: modelIDWithAttributes,
provider: provider,

queryAttempts: 1,
Expand All @@ -53,6 +56,7 @@ func NewModel(provider provider.Query, modelIDWithAttributes string) (llmModel *
// NewModelWithMetaInformation returns a LLM model with meta information corresponding to the given identifier which is queried via the given provider.
func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model {
return &Model{
id: modelIdentifier,
provider: provider,
modelID: modelIdentifier,

Expand All @@ -62,6 +66,18 @@ func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string
}
}

var _ model.Model = (*Model)(nil)

// ID returns full identifier, including the provider and attributes.
func (m *Model) ID() (id string) {
return m.id
}

// ModelID returns the unique identifier of this model.
func (m *Model) ModelID() (modelID string) {
return m.modelID
}

// Attributes returns query attributes.
func (m *Model) Attributes() (attributes map[string]string) {
return m.attributes
Expand Down Expand Up @@ -241,13 +257,6 @@ func (ctx *llmMigrateSourceFilePromptContext) Format() (message string, err erro
return b.String(), nil
}

var _ model.Model = (*Model)(nil)

// ID returns the unique ID of this model.
func (m *Model) ID() (id string) {
return m.modelID
}

var _ model.CapabilityWriteTests = (*Model)(nil)

// WriteTests generates test files for the given implementation file in a repository.
Expand Down
4 changes: 3 additions & 1 deletion model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (

// Model defines a model that can be queried for generations.
type Model interface {
// ID returns the unique ID of this model.
// ID returns full identifier, including the provider and attributes.
ID() (id string)
// ModelID returns the unique identifier of this model.
ModelID() (modelID string)

// Attributes returns query attributes.
Attributes() (attributes map[string]string)
Expand Down
7 changes: 6 additions & 1 deletion model/symflower/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,16 @@ func NewModelSmartTemplateWithTimeout(timeout time.Duration) (model *Model) {

var _ model.Model = (*Model)(nil)

// ID returns the unique ID of this model.
// ID returns full identifier, including the provider and attributes.
func (m *Model) ID() (id string) {
return "symflower" + provider.ProviderModelSeparator + m.id
}

// ModelID returns the unique identifier of this model.
func (m *Model) ModelID() (modelID string) {
return "symflower" + provider.ProviderModelSeparator + m.id
}

// Attributes returns query attributes.
func (m *Model) Attributes() (attributes map[string]string) {
return nil
Expand Down
18 changes: 18 additions & 0 deletions model/testing/Model_mock_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion provider/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ var _ provider.Query = (*Provider)(nil)
// Query queries the provider with the given model name.
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
client := p.client()
modelIdentifier := strings.TrimPrefix(model.ID(), p.ID()+provider.ProviderModelSeparator)
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)

return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
}
Expand Down
2 changes: 1 addition & 1 deletion provider/openai-api/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var _ provider.Query = (*Provider)(nil)
// Query queries the provider with the given model name.
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
client := p.client()
modelIdentifier := strings.TrimPrefix(model.ID(), p.ID()+provider.ProviderModelSeparator)
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)

return QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
}
Expand Down
2 changes: 1 addition & 1 deletion provider/openrouter/openrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ var _ provider.Query = (*Provider)(nil)
// Query queries the provider with the given model name.
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
client := p.client()
modelIdentifier := strings.TrimPrefix(model.ID(), p.ID()+provider.ProviderModelSeparator)
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)

return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
}
Expand Down

0 comments on commit 3ec68e7

Please sign in to comment.