diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go index 7f4ee14a..8d6d2bb5 100644 --- a/cmd/eval-dev-quality/cmd/evaluate.go +++ b/cmd/eval-dev-quality/cmd/evaluate.go @@ -53,8 +53,8 @@ type Evaluate struct { // Languages determines which language should be used for the evaluation, or empty if all languages should be used. Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."` - // Models determines which models should be used for the evaluation, or empty if all models should be used. - Models []string `long:"model" description:"Evaluate with this model. By default all models are used."` + // ModelIDsWithProviderAndAttributes determines which models should be used for the evaluation, or empty if all models should be used. + ModelIDsWithProviderAndAttributes []string `long:"model" description:"Evaluate with this model. By default all models are used."` // ProviderTokens holds all API tokens for the providers. ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","` // ProviderUrls holds all custom inference endpoint urls for the providers. @@ -123,7 +123,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. command.logger.Panicf("the configuration file is not supported in containerized runtimes") } - if len(command.Models) > 0 || len(command.Repositories) > 0 { + if len(command.ModelIDsWithProviderAndAttributes) > 0 || len(command.Repositories) > 0 { command.logger.Panicf("do not provide models and repositories when loading a configuration file") } @@ -139,7 +139,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. command.logger.Panicf("ERROR: %s", err) } - command.Models = configuration.Models.Selected + command.ModelIDsWithProviderAndAttributes = configuration.Models.Selected command.Repositories = configuration.Repositories.Selected } @@ -258,43 +258,13 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. // In a containerized runtime we check the availability of the testdata, repositories and models/providers inside the container. if command.Runtime != "local" { // Copy the models over. - for _, modelID := range command.Models { + for _, modelID := range command.ModelIDsWithProviderAndAttributes { evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID)) } return evaluationContext, evaluationConfiguration, func() {} } - // Register custom OpenAI API providers and models. - { - customProviders := map[string]*openaiapi.Provider{} - for providerID, providerURL := range command.ProviderUrls { - if !strings.HasPrefix(providerID, "custom-") { - continue - } - - p := openaiapi.NewProvider(providerID, providerURL) - provider.Register(p) - customProviders[providerID] = p - } - for _, model := range command.Models { - if !strings.HasPrefix(model, "custom-") { - continue - } - - providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator) - if !ok { - command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator) - } - modelProvider, ok := customProviders[providerID] - if !ok { - command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model) - } - - modelProvider.AddModel(llm.NewModel(modelProvider, model)) - } - } - // Ensure the "testdata" path exists and make it absolute. { if err := osutil.DirExists(command.TestdataPath); err != nil { @@ -371,101 +341,162 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. evaluationContext.Languages[i] = languagesSelected[languageID] } - // Gather models. - serviceShutdown := []func() (err error){} + // Gather models and initialize providers. + var serviceShutdown []func() (err error) { - // Check which providers are needed for the evaluation. - providersSelected := map[string]provider.Provider{} - if len(command.Models) == 0 { - providersSelected = provider.Providers + // Gather providers. + providers := map[string]provider.Provider{} + if len(command.ModelIDsWithProviderAndAttributes) == 0 { + for providerID, provider := range provider.Providers { + providers[providerID] = provider + command.logger.Info("selected provider", "provider", providerID) + } } else { - for _, model := range command.Models { - p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0] + // Register custom providers. + for providerID, providerURL := range command.ProviderUrls { + if !strings.HasPrefix(providerID, "custom-") { + command.logger.Panicf("ERROR: cannot set URL of %q because it is not a custom provider", providerID) + } - if _, ok := providersSelected[p]; ok { - continue + p := openaiapi.NewProvider(providerID, providerURL) + provider.Register(p) + providers[providerID] = p + command.logger.Info("selected provider", "provider", providerID) + } + + // Add remaining providers from models. + for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes { + providerID, _, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator) + if !ok { + command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator) } - if provider, ok := provider.Providers[p]; !ok { - command.logger.Panicf("Provider %q does not exist", p) - } else { - providersSelected[provider.ID()] = provider + p, ok := provider.Providers[providerID] + if !ok { + command.logger.Panicf("ERROR: unknown provider %q for model %q", providerID, modelIDsWithProviderAndAttributes) + } + if _, ok := providers[providerID]; !ok { + providers[providerID] = p + command.logger.Info("selected provider", "provider", providerID) } } } - models := map[string]model.Model{} - modelsSelected := map[string]model.Model{} - evaluationContext.ProviderForModel = map[model.Model]provider.Provider{} - for _, p := range providersSelected { - command.logger.Info("querying provider models", "provider", p.ID()) + // Initialize providers. + { + providerIDsSorted := maps.Keys(providers) + sort.Strings(providerIDsSorted) + for _, providerID := range providerIDsSorted { + p := providers[providerID] - if t, ok := p.(provider.InjectToken); ok { - token, ok := command.ProviderTokens[p.ID()] - if ok { - t.SetToken(token) + command.logger.Info("initializing provider", "provider", providerID) + if t, ok := p.(provider.InjectToken); ok { + if token, ok := command.ProviderTokens[p.ID()]; ok { + command.logger.Info("set token of provider", "provider", providerID) + t.SetToken(token) + } } - } - if err := p.Available(command.logger); err != nil { - command.logger.Warn("skipping unavailable provider", "provider", p.ID(), "error", err) + command.logger.Info("checking availability for provider", "provider", providerID) + if err := p.Available(command.logger); err != nil { + command.logger.Info("skipping provider because it is not available", "error", err, "provider", providerID) + delete(providers, providerID) - continue + continue + } + if service, ok := p.(provider.Service); ok { + command.logger.Info("starting services for provider", "provider", p.ID()) + shutdown, err := service.Start(command.logger) + if err != nil { + command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err) + } + serviceShutdown = append(serviceShutdown, shutdown) + } } + } - // Start services of providers. - if service, ok := p.(provider.Service); ok { - command.logger.Info("starting services for provider", "provider", p.ID()) - shutdown, err := service.Start(command.logger) + // Gather models. + models := map[string]model.Model{} + { + addAllModels := len(command.ModelIDsWithProviderAndAttributes) == 0 + for _, p := range providers { + ms, err := p.Models() if err != nil { - command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err) + command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err) } - serviceShutdown = append(serviceShutdown, shutdown) - } + for _, m := range ms { + models[m.ID()] = m + evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID()) - // Check if a provider has the ability to pull models and do so if necessary. - if puller, ok := p.(provider.Puller); ok { - command.logger.Info("pulling available models for provider", "provider", p.ID()) - for _, modelID := range command.Models { - if strings.HasPrefix(modelID, p.ID()) { - if err := puller.Pull(command.logger, modelID); err != nil { - command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err) - } + if addAllModels { + command.ModelIDsWithProviderAndAttributes = append(command.ModelIDsWithProviderAndAttributes, m.ID()) } } } + } + modelIDs := maps.Keys(models) + sort.Strings(modelIDs) + sort.Strings(command.ModelIDsWithProviderAndAttributes) - ms, err := p.Models() - if err != nil { - command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err) + // Check and initialize models. + evaluationContext.ProviderForModel = map[model.Model]provider.Provider{} + for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes { + command.logger.Info("selecting model", "model", modelIDsWithProviderAndAttributes) + + providerID, modelIDsWithAttributes, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator) + if !ok { + command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator) } - for _, m := range ms { - models[m.ID()] = m - evaluationContext.ProviderForModel[m] = p - evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID()) + modelID, _ := model.ParseModelID(modelIDsWithAttributes) + + p, ok := providers[providerID] + if !ok { + command.logger.Panicf("ERROR: cannot find provider %q", providerID) } - } - modelIDs := maps.Keys(models) - sort.Strings(modelIDs) - if len(command.Models) == 0 { - command.Models = modelIDs - } else { - for _, modelID := range command.Models { - if _, ok := models[modelID]; !ok { - command.logger.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", ")) + if puller, ok := p.(provider.Puller); ok { + command.logger.Info("pulling model", "model", modelID) + if err := puller.Pull(command.logger, modelID); err != nil { + command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err) + } + + // TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner, we should not rebuild every time. + if _, ok := models[modelIDsWithProviderAndAttributes]; !ok { + ms, err := p.Models() + if err != nil { + command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err) + } + for _, m := range ms { + if _, ok := models[m.ID()]; ok { + continue + } + + models[m.ID()] = m + evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID()) + } + modelIDs = maps.Keys(models) + sort.Strings(modelIDs) } } - } - sort.Strings(command.Models) - for _, modelID := range command.Models { - modelsSelected[modelID] = models[modelID] - } - // Make the resolved selected models available in the command. - evaluationContext.Models = make([]model.Model, len(command.Models)) - for i, modelID := range command.Models { - evaluationContext.Models[i] = modelsSelected[modelID] - evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelID) + var m model.Model + if strings.HasPrefix(providerID, "custom-") { + pc, ok := p.(*openaiapi.Provider) + if !ok { + command.logger.Panicf("ERROR: %q is not a custom provider", providerID) + } + + m = llm.NewModel(pc, modelIDsWithProviderAndAttributes) + pc.AddModel(m) + } else { + var ok bool + m, ok = models[modelIDsWithProviderAndAttributes] + if !ok { + command.logger.Panicf("ERROR: model %q does not exist for provider %q. Valid models are: %s", modelIDsWithProviderAndAttributes, providerID, strings.Join(modelIDs, ", ")) + } + } + evaluationContext.Models = append(evaluationContext.Models, m) + evaluationContext.ProviderForModel[m] = p + evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelIDsWithProviderAndAttributes) } } @@ -613,7 +644,7 @@ func (command *Evaluate) evaluateDocker(ctx *evaluate.Context) (err error) { "-e", "SYMFLOWER_INTERNAL_LICENSE_FILE", "-e", "SYMFLOWER_LICENSE_KEY", "-v", volumeName + ":/app/evaluation", - "--rm", // automatically remove container after it finished + "--rm", // Automatically remove container after it finished. command.RuntimeImage, } @@ -706,7 +737,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) { // Define a regex to replace all non alphanumeric characters and "-". kubeNameRegex := regexp.MustCompile(`[^a-zA-Z0-9-]+`) - jobTmpl, err := template.ParseFiles(filepath.Join("conf", "kube", "job.yml")) + kubernetesJobTemplate, err := template.ParseFiles(filepath.Join("conf", "kube", "job.yml")) if err != nil { return pkgerrors.Wrap(err, "could not create kubernetes job template") } @@ -735,7 +766,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) { "kubectl", "apply", "-f", - "-", // apply STDIN + "-", // Apply STDIN. } // Commands for the evaluation to run inside the container. @@ -763,14 +794,14 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) { } parallel.Execute(func() { - var tmplData bytes.Buffer - if err := jobTmpl.Execute(&tmplData, data); err != nil { + var kubernetesJobData bytes.Buffer + if err := kubernetesJobTemplate.Execute(&kubernetesJobData, data); err != nil { command.logger.Panicf("ERROR: %s", err) } commandOutput, err := util.CommandWithResult(context.Background(), command.logger, &util.Command{ Command: kubeCommand, - Stdin: tmplData.String(), + Stdin: kubernetesJobData.String(), }) if err != nil { command.logger.Error("kubernetes evaluation failed", "error", pkgerrors.WithMessage(pkgerrors.WithStack(err), commandOutput)) @@ -830,7 +861,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) { var storageTemplateData bytes.Buffer if err := storageTemplate.Execute(&storageTemplateData, data); err != nil { - return pkgerrors.Wrap(err, "could not execute storate template") + return pkgerrors.Wrap(err, "could not execute storage template") } // Create the storage access pod. @@ -839,7 +870,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) { "kubectl", "apply", "-f", - "-", // apply STDIN + "-", // Apply STDIN. }, Stdin: storageTemplateData.String(), }) diff --git a/cmd/eval-dev-quality/cmd/evaluate_test.go b/cmd/eval-dev-quality/cmd/evaluate_test.go index 5fee5290..f08424a4 100644 --- a/cmd/eval-dev-quality/cmd/evaluate_test.go +++ b/cmd/eval-dev-quality/cmd/evaluate_test.go @@ -462,10 +462,10 @@ func TestEvaluateExecute(t *testing.T) { ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){ filepath.Join("result-directory", "evaluation.log"): func(t *testing.T, filePath string, data string) { - assert.Contains(t, data, "\"msg\":\"skipping unavailable provider\",\"provider\":\"openrouter\"") + assert.Contains(t, data, `"msg":"skipping provider because it is not available","error":"missing access token","provider":"openrouter"`) }, }, - ExpectedPanicContains: "ERROR: model openrouter/auto does not exist", + ExpectedPanicContains: `ERROR: cannot find provider "openrouter"`, }) }) t.Run("Ollama", func(t *testing.T) { @@ -1271,13 +1271,17 @@ func TestEvaluateInitialize(t *testing.T) { // makeValidCommand is a helper to abstract all the default values that have to be set to make a command valid. makeValidCommand := func(modify func(command *Evaluate)) *Evaluate { c := &Evaluate{ + ModelIDsWithProviderAndAttributes: []string{"symflower/smart-template"}, + QueryAttempts: 1, + + ResultPath: filepath.Join("$TEMP_PATH", "result-directory"), + TestdataPath: filepath.Join("..", "..", "..", "testdata"), + ExecutionTimeout: 1, - Parallel: 1, - QueryAttempts: 1, - ResultPath: filepath.Join("$TEMP_PATH", "result-directory"), Runs: 1, - Runtime: "local", - TestdataPath: filepath.Join("..", "..", "..", "testdata"), + + Runtime: "local", + Parallel: 1, } if modify != nil { @@ -1324,12 +1328,16 @@ func TestEvaluateInitialize(t *testing.T) { Name: "Selecting no model defaults to all", Command: makeValidCommand(func(command *Evaluate) { - command.Models = []string{} + command.ModelIDsWithProviderAndAttributes = []string{} + command.ProviderTokens = map[string]string{ + "openrouter": "fake-token", + } }), // Could also select arbitrary Ollama or new Openrouter models so sanity check that at least symflower is there. ValidateCommand: func(t *testing.T, command *Evaluate) { - assert.Contains(t, command.Models, "symflower/symbolic-execution") + assert.Contains(t, command.ModelIDsWithProviderAndAttributes, "symflower/smart-template") + assert.Contains(t, command.ModelIDsWithProviderAndAttributes, "symflower/symbolic-execution") }, ValidateContext: func(t *testing.T, context *evaluate.Context) { modelIDs := make([]string, len(context.Models)) @@ -1448,12 +1456,13 @@ func TestEvaluateInitialize(t *testing.T) { Command: makeValidCommand(func(command *Evaluate) { command.Configuration = "config.json" + command.ModelIDsWithProviderAndAttributes = nil }), ValidateCommand: func(t *testing.T, command *Evaluate) { assert.Equal(t, []string{ "symflower/symbolic-execution", - }, command.Models) + }, command.ModelIDsWithProviderAndAttributes) assert.Equal(t, []string{ filepath.Join("golang", "plain"), filepath.Join("java", "plain"), @@ -1513,7 +1522,7 @@ func TestEvaluateInitialize(t *testing.T) { } validate(t, &testCase{ - Name: "Parallel parameter hast to be greater then zero", + Name: "Parallel parameter has to be greater then zero", Command: makeValidCommand(func(command *Evaluate) { command.Runtime = "docker" diff --git a/cmd/eval-dev-quality/cmd/init_test.go b/cmd/eval-dev-quality/cmd/init_test.go new file mode 100644 index 00000000..cf660720 --- /dev/null +++ b/cmd/eval-dev-quality/cmd/init_test.go @@ -0,0 +1,13 @@ +package cmd + +import "os" + +func init() { + // Unset environment variables that are often found as defaults in the terminal configuration. + if err := os.Unsetenv("PROVIDER_TOKEN"); err != nil { + panic(err) + } + if err := os.Unsetenv("PROVIDER_URL"); err != nil { + panic(err) + } +} diff --git a/evaluate/evaluate_test.go b/evaluate/evaluate_test.go index f4eef243..40ff8cf0 100644 --- a/evaluate/evaluate_test.go +++ b/evaluate/evaluate_test.go @@ -244,7 +244,7 @@ func TestEvaluate(t *testing.T) { Before: func(t *testing.T, logger *log.Logger, resultPath string) { // Set up mocks, when test is running. - mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("", ErrEmptyResponseFromModel) + mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel) }, After: func(t *testing.T, logger *log.Logger, resultPath string) { mockedQuery.AssertNumberOfCalls(t, "Query", 2) @@ -325,10 +325,10 @@ func TestEvaluate(t *testing.T) { Before: func(t *testing.T, logger *log.Logger, resultPath string) { // Set up mocks, when test is running. - mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("", ErrEmptyResponseFromModel).Once() - mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds. - mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("", ErrEmptyResponseFromModel).Once() - mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds. + mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel).Once() + mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds. + mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel).Once() + mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds. }, After: func(t *testing.T, logger *log.Logger, resultPath string) { mockedQuery.AssertNumberOfCalls(t, "Query", 4) @@ -424,7 +424,7 @@ func TestEvaluate(t *testing.T) { Before: func(t *testing.T, logger *log.Logger, resultPath string) { // Set up mocks, when test is running. - mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("model-response", nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds. + mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds. }, After: func(t *testing.T, logger *log.Logger, resultPath string) { mockedQuery.AssertNumberOfCalls(t, "Query", 2) diff --git a/go.mod b/go.mod index 66c74aec..10af809f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jessevdk/go-flags v1.5.1-0.20210607101731-3927b71304df github.com/kr/pretty v0.3.1 github.com/pkg/errors v0.9.1 - github.com/sashabaranov/go-openai v1.20.4 + github.com/sashabaranov/go-openai v1.36.2-0.20250131190529-45aa99607be0 github.com/stretchr/testify v1.9.0 github.com/symflower/lockfile v0.0.0-20240419143922-aa3b60940c84 github.com/zimmski/osutil v1.3.0 diff --git a/go.sum b/go.sum index 7d99515a..8877228d 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,8 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/sashabaranov/go-openai v1.20.4 h1:095xQ/fAtRa0+Rj21sezVJABgKfGPNbyx/sAN/hJUmg= github.com/sashabaranov/go-openai v1.20.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.36.2-0.20250131190529-45aa99607be0 h1:WlepprDHs6tUt2/ihnvTL7DvXpO6IOllVP+oPWn2U0k= +github.com/sashabaranov/go-openai v1.36.2-0.20250131190529-45aa99607be0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/schollz/progressbar/v3 v3.14.2 h1:EducH6uNLIWsr560zSV1KrTeUb/wZGAHqyMFIEa99ks= github.com/schollz/progressbar/v3 v3.14.2/go.mod h1:aQAZQnhF4JGFtRJiw/eobaXpsqpVQAftEQ+hLGXaRc4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/model/llm/llm.go b/model/llm/llm.go index 84c6cb2d..1f99d2cf 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -24,11 +24,15 @@ import ( // Model represents a LLM model accessed via a provider. type Model struct { + // id holds the full identifier, including the provider and attributes. + id string // provider is the client to query the LLM model. provider provider.Query - // model holds the identifier for the LLM model. - model string + // modelID holds the identifier for the LLM model. + modelID string + // attributes holds query attributes. + attributes map[string]string // queryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task. queryAttempts uint @@ -37,20 +41,24 @@ type Model struct { } // NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider. -func NewModel(provider provider.Query, modelIdentifier string) *Model { - return &Model{ +func NewModel(provider provider.Query, modelIDWithAttributes string) (llmModel *Model) { + llmModel = &Model{ + id: modelIDWithAttributes, provider: provider, - model: modelIdentifier, queryAttempts: 1, } + llmModel.modelID, llmModel.attributes = model.ParseModelID(modelIDWithAttributes) + + return llmModel } // NewModelWithMetaInformation returns a LLM model with meta information corresponding to the given identifier which is queried via the given provider. func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model { return &Model{ + id: modelIdentifier, provider: provider, - model: modelIdentifier, + modelID: modelIdentifier, queryAttempts: 1, @@ -58,6 +66,33 @@ func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string } } +var _ model.Model = (*Model)(nil) + +// ID returns full identifier, including the provider and attributes. +func (m *Model) ID() (id string) { + return m.id +} + +// ModelID returns the unique identifier of this model with its provider. +func (m *Model) ModelID() (modelID string) { + return m.modelID +} + +// ModelIDWithoutProvider returns the unique identifier of this model without its provider. +func (m *Model) ModelIDWithoutProvider() (modelID string) { + _, modelID, ok := strings.Cut(m.modelID, provider.ProviderModelSeparator) + if !ok { + panic(m.modelID) + } + + return modelID +} + +// Attributes returns query attributes. +func (m *Model) Attributes() (attributes map[string]string) { + return m.attributes +} + // MetaInformation returns the meta information of a model. func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { return m.metaInformation @@ -232,13 +267,6 @@ func (ctx *llmMigrateSourceFilePromptContext) Format() (message string, err erro return b.String(), nil } -var _ model.Model = (*Model)(nil) - -// ID returns the unique ID of this model. -func (m *Model) ID() (id string) { - return m.model -} - var _ model.CapabilityWriteTests = (*Model)(nil) // WriteTests generates test files for the given implementation file in a repository. @@ -302,7 +330,7 @@ func (m *Model) query(logger *log.Logger, request string) (response string, dura id := uuid.NewString logger.Info("querying model", "model", m.ID(), "id", id, "prompt", string(bytesutil.PrefixLines([]byte(request), []byte("\t")))) start := time.Now() - response, err = m.provider.Query(context.Background(), m.model, request) + response, err = m.provider.Query(context.Background(), m, request) if err != nil { return err } diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index b4f46f97..bcea2b6b 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -100,7 +100,7 @@ func TestModelGenerateTestsForFile(t *testing.T) { Name: "Simple", SetupMock: func(mockedProvider *providertesting.MockQuery) { - mockedProvider.On("Query", mock.Anything, "model-id", promptMessage).Return(bytesutil.StringTrimIndentations(` + mockedProvider.On("Query", mock.Anything, mock.Anything, promptMessage).Return(bytesutil.StringTrimIndentations(` `+"```"+` package native @@ -191,7 +191,7 @@ func TestModelRepairSourceCodeFile(t *testing.T) { Name: "Opening bracket is missing", SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { - mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return(bytesutil.StringTrimIndentations(` + mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(bytesutil.StringTrimIndentations(` `+"```"+` package openingBracketMissing func openingBracketMissing(x int) int { @@ -240,7 +240,7 @@ func TestModelRepairSourceCodeFile(t *testing.T) { Name: "Opening bracket is missing", SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { - mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return(bytesutil.StringTrimIndentations(` + mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(bytesutil.StringTrimIndentations(` `+"```"+` package com.eval; public class OpeningBracketMissing { @@ -682,7 +682,7 @@ func TestModelTranspile(t *testing.T) { Name: "Binary search", SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { - mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return("```\n"+transpiledFileContent+"```\n", nil) + mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("```\n"+transpiledFileContent+"```\n", nil) }, Language: &golang.Language{}, @@ -729,10 +729,10 @@ func TestModelTranspile(t *testing.T) { } `) validate(t, &testCase{ - Name: "Binary Search", + Name: "Binary search", SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { - mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return("```\n"+transpiledFileContent+"```\n", nil) + mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("```\n"+transpiledFileContent+"```\n", nil) }, Language: &java.Language{}, @@ -832,7 +832,7 @@ func TestModelMigrate(t *testing.T) { Name: "Increment", SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { - mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return("```\n"+migratedTestFile+"```\n", nil) + mockedProvider.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("```\n"+migratedTestFile+"```\n", nil) }, Language: &java.Language{}, diff --git a/model/model.go b/model/model.go index fac8870d..15d0c2d8 100644 --- a/model/model.go +++ b/model/model.go @@ -1,22 +1,45 @@ package model import ( + "strings" + "github.com/symflower/eval-dev-quality/language" "github.com/symflower/eval-dev-quality/log" ) // Model defines a model that can be queried for generations. type Model interface { - // ID returns the unique ID of this model. + // ID returns full identifier, including the provider and attributes. ID() (id string) + // ModelID returns the unique identifier of this model with its provider. + ModelID() (modelID string) + // ModelIDWithoutProvider returns the unique identifier of this model without its provider. + ModelIDWithoutProvider() (modelID string) + + // Attributes returns query attributes. + Attributes() (attributes map[string]string) // MetaInformation returns the meta information of a model. MetaInformation() *MetaInformation } +// ParseModelID takes a packaged model ID with optional attributes and converts it into its model ID and optional attributes. +func ParseModelID(modelIDWithAttributes string) (modelID string, attributes map[string]string) { + ms := strings.Split(modelIDWithAttributes, "@") + if len(ms) > 1 { + attributes = map[string]string{} + for i := 1; i < len(ms); i++ { + as := strings.Split(ms[i], "=") + attributes[as[0]] = as[1] + } + } + + return ms[0], attributes +} + // MetaInformation holds a model. type MetaInformation struct { - // ID holds the model id. + // ID holds the model ID. ID string `json:"id"` // Name holds the model name. Name string `json:"name"` diff --git a/model/symflower/symflower.go b/model/symflower/symflower.go index aa60f23e..5093b91f 100644 --- a/model/symflower/symflower.go +++ b/model/symflower/symflower.go @@ -62,11 +62,26 @@ func NewModelSmartTemplateWithTimeout(timeout time.Duration) (model *Model) { var _ model.Model = (*Model)(nil) -// ID returns the unique ID of this model. +// ID returns full identifier, including the provider and attributes. func (m *Model) ID() (id string) { return "symflower" + provider.ProviderModelSeparator + m.id } +// ModelID returns the unique identifier of this model with its provider. +func (m *Model) ModelID() (modelID string) { + return "symflower" + provider.ProviderModelSeparator + m.id +} + +// ModelIDWithoutProvider returns the unique identifier of this model without its provider. +func (m *Model) ModelIDWithoutProvider() (modelID string) { + return m.id +} + +// Attributes returns query attributes. +func (m *Model) Attributes() (attributes map[string]string) { + return nil +} + // MetaInformation returns the meta information of a model. func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { return nil diff --git a/model/testing/Model_mock_gen.go b/model/testing/Model_mock_gen.go index 35f6f95c..7d07fbe3 100644 --- a/model/testing/Model_mock_gen.go +++ b/model/testing/Model_mock_gen.go @@ -12,6 +12,26 @@ type MockModel struct { mock.Mock } +// Attributes provides a mock function with given fields: +func (_m *MockModel) Attributes() map[string]string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Attributes") + } + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + // ID provides a mock function with given fields: func (_m *MockModel) ID() string { ret := _m.Called() @@ -50,6 +70,42 @@ func (_m *MockModel) MetaInformation() *model.MetaInformation { return r0 } +// ModelID provides a mock function with given fields: +func (_m *MockModel) ModelID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ModelID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// ModelIDWithoutProvider provides a mock function with given fields: +func (_m *MockModel) ModelIDWithoutProvider() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ModelIDWithoutProvider") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + // NewMockModel creates a new instance of MockModel. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockModel(t interface { diff --git a/provider/ollama/ollama.go b/provider/ollama/ollama.go index 84f0aff7..5441d3a8 100644 --- a/provider/ollama/ollama.go +++ b/provider/ollama/ollama.go @@ -81,11 +81,8 @@ func (p *Provider) Models() (models []model.Model, err error) { var _ provider.Query = (*Provider)(nil) // Query queries the provider with the given model name. -func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) { - client := p.client() - modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) - - return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText) +func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) { + return openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText) } // client returns a new client with the current configuration. diff --git a/provider/openai-api/openai.go b/provider/openai-api/openai.go index 18afcad6..222a7338 100644 --- a/provider/openai-api/openai.go +++ b/provider/openai-api/openai.go @@ -2,7 +2,6 @@ package openaiapi import ( "context" - "strings" "github.com/sashabaranov/go-openai" @@ -60,11 +59,8 @@ func (p *Provider) SetToken(token string) { var _ provider.Query = (*Provider)(nil) // Query queries the provider with the given model name. -func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) { - client := p.client() - modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) - - return QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText) +func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) { + return QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText) } // client returns a new client with the current configuration. diff --git a/provider/openai-api/query.go b/provider/openai-api/query.go index 5d2e1ee6..f637e9d8 100644 --- a/provider/openai-api/query.go +++ b/provider/openai-api/query.go @@ -9,18 +9,26 @@ import ( ) // QueryOpenAIAPIModel queries an OpenAI API model. -func QueryOpenAIAPIModel(ctx context.Context, client *openai.Client, modelIdentifier string, promptText string) (response string, err error) { - apiResponse, err := client.CreateChatCompletion( - ctx, - openai.ChatCompletionRequest{ - Model: modelIdentifier, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: promptText, - }, +func QueryOpenAIAPIModel(ctx context.Context, client *openai.Client, modelIdentifier string, attributes map[string]string, promptText string) (response string, err error) { + apiRequest := openai.ChatCompletionRequest{ + Model: modelIdentifier, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: promptText, }, }, + } + + if attributes != nil { + if reasoningEffort, ok := attributes["reasoning_effort"]; ok { + apiRequest.ReasoningEffort = reasoningEffort + } + } + + apiResponse, err := client.CreateChatCompletion( + ctx, + apiRequest, ) if err != nil { return "", pkgerrors.WithStack(err) diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index b34f1c95..03c19975 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "net/url" - "strings" "time" "github.com/avast/retry-go" @@ -138,11 +137,8 @@ func (p *Provider) SetToken(token string) { var _ provider.Query = (*Provider)(nil) // Query queries the provider with the given model name. -func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) { - client := p.client() - modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) - - return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText) +func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) { + return openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText) } // client returns a new client with the current configuration. diff --git a/provider/provider.go b/provider/provider.go index 06f51373..cc1d168f 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -45,7 +45,7 @@ type InjectToken interface { // Query is a provider that allows to query a model directly. type Query interface { // Query queries the provider with the given model name. - Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) + Query(ctx context.Context, model model.Model, promptText string) (response string, err error) } // Service is a provider that requires background services. diff --git a/provider/testing/Query_mock_gen.go b/provider/testing/Query_mock_gen.go index e0d9ed7b..0bf4526f 100644 --- a/provider/testing/Query_mock_gen.go +++ b/provider/testing/Query_mock_gen.go @@ -6,6 +6,7 @@ import ( context "context" mock "github.com/stretchr/testify/mock" + model "github.com/symflower/eval-dev-quality/model" ) // MockQuery is an autogenerated mock type for the Query type @@ -13,9 +14,9 @@ type MockQuery struct { mock.Mock } -// Query provides a mock function with given fields: ctx, modelIdentifier, promptText -func (_m *MockQuery) Query(ctx context.Context, modelIdentifier string, promptText string) (string, error) { - ret := _m.Called(ctx, modelIdentifier, promptText) +// Query provides a mock function with given fields: ctx, _a1, promptText +func (_m *MockQuery) Query(ctx context.Context, _a1 model.Model, promptText string) (string, error) { + ret := _m.Called(ctx, _a1, promptText) if len(ret) == 0 { panic("no return value specified for Query") @@ -23,17 +24,17 @@ func (_m *MockQuery) Query(ctx context.Context, modelIdentifier string, promptTe var r0 string var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (string, error)); ok { - return rf(ctx, modelIdentifier, promptText) + if rf, ok := ret.Get(0).(func(context.Context, model.Model, string) (string, error)); ok { + return rf(ctx, _a1, promptText) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) string); ok { - r0 = rf(ctx, modelIdentifier, promptText) + if rf, ok := ret.Get(0).(func(context.Context, model.Model, string) string); ok { + r0 = rf(ctx, _a1, promptText) } else { r0 = ret.Get(0).(string) } - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, modelIdentifier, promptText) + if rf, ok := ret.Get(1).(func(context.Context, model.Model, string) error); ok { + r1 = rf(ctx, _a1, promptText) } else { r1 = ret.Error(1) }