Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to set reasoning_effort for models (e.g. OpenAI's o3-mini) #408

Merged
merged 11 commits into from
Feb 3, 2025
Merged
253 changes: 142 additions & 111 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ type Evaluate struct {

// Languages determines which language should be used for the evaluation, or empty if all languages should be used.
Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."`
// Models determines which models should be used for the evaluation, or empty if all models should be used.
Models []string `long:"model" description:"Evaluate with this model. By default all models are used."`
// ModelIDsWithProviderAndAttributes determines which models should be used for the evaluation, or empty if all models should be used.
ModelIDsWithProviderAndAttributes []string `long:"model" description:"Evaluate with this model. By default all models are used."`
// ProviderTokens holds all API tokens for the providers.
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
// ProviderUrls holds all custom inference endpoint urls for the providers.
Expand Down Expand Up @@ -123,7 +123,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
command.logger.Panicf("the configuration file is not supported in containerized runtimes")
}

if len(command.Models) > 0 || len(command.Repositories) > 0 {
if len(command.ModelIDsWithProviderAndAttributes) > 0 || len(command.Repositories) > 0 {
command.logger.Panicf("do not provide models and repositories when loading a configuration file")
}

Expand All @@ -139,7 +139,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
command.logger.Panicf("ERROR: %s", err)
}

command.Models = configuration.Models.Selected
command.ModelIDsWithProviderAndAttributes = configuration.Models.Selected
command.Repositories = configuration.Repositories.Selected
}

Expand Down Expand Up @@ -258,43 +258,13 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
// In a containerized runtime we check the availability of the testdata, repositories and models/providers inside the container.
if command.Runtime != "local" {
// Copy the models over.
for _, modelID := range command.Models {
for _, modelID := range command.ModelIDsWithProviderAndAttributes {
evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID))
}

return evaluationContext, evaluationConfiguration, func() {}
}

// Register custom OpenAI API providers and models.
{
customProviders := map[string]*openaiapi.Provider{}
for providerID, providerURL := range command.ProviderUrls {
if !strings.HasPrefix(providerID, "custom-") {
continue
}

p := openaiapi.NewProvider(providerID, providerURL)
provider.Register(p)
customProviders[providerID] = p
}
for _, model := range command.Models {
if !strings.HasPrefix(model, "custom-") {
continue
}

providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
if !ok {
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
}
modelProvider, ok := customProviders[providerID]
if !ok {
command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
}

modelProvider.AddModel(llm.NewModel(modelProvider, model))
}
}

// Ensure the "testdata" path exists and make it absolute.
{
if err := osutil.DirExists(command.TestdataPath); err != nil {
Expand Down Expand Up @@ -371,101 +341,162 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
evaluationContext.Languages[i] = languagesSelected[languageID]
}

// Gather models.
serviceShutdown := []func() (err error){}
// Gather models and initialize providers.
var serviceShutdown []func() (err error)
{
// Check which providers are needed for the evaluation.
providersSelected := map[string]provider.Provider{}
if len(command.Models) == 0 {
providersSelected = provider.Providers
// Gather providers.
providers := map[string]provider.Provider{}
if len(command.ModelIDsWithProviderAndAttributes) == 0 {
for providerID, provider := range provider.Providers {
providers[providerID] = provider
command.logger.Info("selected provider", "provider", providerID)
}
} else {
for _, model := range command.Models {
p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0]
// Register custom providers.
for providerID, providerURL := range command.ProviderUrls {
if !strings.HasPrefix(providerID, "custom-") {
command.logger.Panicf("ERROR: cannot set URL of %q because it is not a custom provider", providerID)
}

if _, ok := providersSelected[p]; ok {
continue
p := openaiapi.NewProvider(providerID, providerURL)
provider.Register(p)
providers[providerID] = p
command.logger.Info("selected provider", "provider", providerID)
}

// Add remaining providers from models.
for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes {
providerID, _, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
if !ok {
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
}

if provider, ok := provider.Providers[p]; !ok {
command.logger.Panicf("Provider %q does not exist", p)
} else {
providersSelected[provider.ID()] = provider
p, ok := provider.Providers[providerID]
if !ok {
command.logger.Panicf("ERROR: unknown provider %q for model %q", providerID, modelIDsWithProviderAndAttributes)
}
if _, ok := providers[providerID]; !ok {
providers[providerID] = p
command.logger.Info("selected provider", "provider", providerID)
}
}
}

zimmski marked this conversation as resolved.
Show resolved Hide resolved
models := map[string]model.Model{}
modelsSelected := map[string]model.Model{}
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
for _, p := range providersSelected {
command.logger.Info("querying provider models", "provider", p.ID())
// Initialize providers.
{
providerIDsSorted := maps.Keys(providers)
sort.Strings(providerIDsSorted)
for _, providerID := range providerIDsSorted {
p := providers[providerID]

if t, ok := p.(provider.InjectToken); ok {
token, ok := command.ProviderTokens[p.ID()]
if ok {
t.SetToken(token)
command.logger.Info("initializing provider", "provider", providerID)
if t, ok := p.(provider.InjectToken); ok {
if token, ok := command.ProviderTokens[p.ID()]; ok {
command.logger.Info("set token of provider", "provider", providerID)
t.SetToken(token)
}
}
}
if err := p.Available(command.logger); err != nil {
command.logger.Warn("skipping unavailable provider", "provider", p.ID(), "error", err)
command.logger.Info("checking availability for provider", "provider", providerID)
if err := p.Available(command.logger); err != nil {
command.logger.Info("skipping provider because it is not available", "error", err, "provider", providerID)
delete(providers, providerID)

continue
continue
}
if service, ok := p.(provider.Service); ok {
command.logger.Info("starting services for provider", "provider", p.ID())
shutdown, err := service.Start(command.logger)
if err != nil {
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
}
serviceShutdown = append(serviceShutdown, shutdown)
}
}
}

// Start services of providers.
if service, ok := p.(provider.Service); ok {
command.logger.Info("starting services for provider", "provider", p.ID())
shutdown, err := service.Start(command.logger)
// Gather models.
models := map[string]model.Model{}
{
addAllModels := len(command.ModelIDsWithProviderAndAttributes) == 0
for _, p := range providers {
ms, err := p.Models()
if err != nil {
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
}
serviceShutdown = append(serviceShutdown, shutdown)
}
for _, m := range ms {
models[m.ID()] = m
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())

// Check if a provider has the ability to pull models and do so if necessary.
if puller, ok := p.(provider.Puller); ok {
command.logger.Info("pulling available models for provider", "provider", p.ID())
for _, modelID := range command.Models {
if strings.HasPrefix(modelID, p.ID()) {
if err := puller.Pull(command.logger, modelID); err != nil {
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
}
if addAllModels {
command.ModelIDsWithProviderAndAttributes = append(command.ModelIDsWithProviderAndAttributes, m.ID())
}
}
}
}
modelIDs := maps.Keys(models)
sort.Strings(modelIDs)
sort.Strings(command.ModelIDsWithProviderAndAttributes)

ms, err := p.Models()
if err != nil {
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
// Check and initialize models.
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes {
command.logger.Info("selecting model", "model", modelIDsWithProviderAndAttributes)

providerID, modelIDsWithAttributes, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
zimmski marked this conversation as resolved.
Show resolved Hide resolved
if !ok {
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
}

for _, m := range ms {
models[m.ID()] = m
evaluationContext.ProviderForModel[m] = p
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
modelID, _ := model.ParseModelID(modelIDsWithAttributes)

p, ok := providers[providerID]
if !ok {
command.logger.Panicf("ERROR: cannot find provider %q", providerID)
}
}
modelIDs := maps.Keys(models)
sort.Strings(modelIDs)
if len(command.Models) == 0 {
command.Models = modelIDs
} else {
for _, modelID := range command.Models {
if _, ok := models[modelID]; !ok {
command.logger.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
if puller, ok := p.(provider.Puller); ok {
command.logger.Info("pulling model", "model", modelID)
if err := puller.Pull(command.logger, modelID); err != nil {
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
}

// TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner, we should not rebuild every time.
if _, ok := models[modelIDsWithProviderAndAttributes]; !ok {
ms, err := p.Models()
if err != nil {
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
}
for _, m := range ms {
if _, ok := models[m.ID()]; ok {
continue
}

models[m.ID()] = m
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
}
modelIDs = maps.Keys(models)
sort.Strings(modelIDs)
}
}
}
sort.Strings(command.Models)
for _, modelID := range command.Models {
modelsSelected[modelID] = models[modelID]
}

// Make the resolved selected models available in the command.
evaluationContext.Models = make([]model.Model, len(command.Models))
for i, modelID := range command.Models {
evaluationContext.Models[i] = modelsSelected[modelID]
evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelID)
var m model.Model
if strings.HasPrefix(providerID, "custom-") {
pc, ok := p.(*openaiapi.Provider)
if !ok {
command.logger.Panicf("ERROR: %q is not a custom provider", providerID)
}

m = llm.NewModel(pc, modelIDsWithProviderAndAttributes)
pc.AddModel(m)
} else {
var ok bool
m, ok = models[modelIDsWithProviderAndAttributes]
if !ok {
command.logger.Panicf("ERROR: model %q does not exist for provider %q. Valid models are: %s", modelIDsWithProviderAndAttributes, providerID, strings.Join(modelIDs, ", "))
}
}
evaluationContext.Models = append(evaluationContext.Models, m)
evaluationContext.ProviderForModel[m] = p
evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelIDsWithProviderAndAttributes)
}
}

Expand Down Expand Up @@ -613,7 +644,7 @@ func (command *Evaluate) evaluateDocker(ctx *evaluate.Context) (err error) {
"-e", "SYMFLOWER_INTERNAL_LICENSE_FILE",
"-e", "SYMFLOWER_LICENSE_KEY",
"-v", volumeName + ":/app/evaluation",
"--rm", // automatically remove container after it finished
"--rm", // Automatically remove container after it finished.
command.RuntimeImage,
}

Expand Down Expand Up @@ -706,7 +737,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
// Define a regex to replace all non alphanumeric characters and "-".
kubeNameRegex := regexp.MustCompile(`[^a-zA-Z0-9-]+`)

jobTmpl, err := template.ParseFiles(filepath.Join("conf", "kube", "job.yml"))
kubernetesJobTemplate, err := template.ParseFiles(filepath.Join("conf", "kube", "job.yml"))
if err != nil {
return pkgerrors.Wrap(err, "could not create kubernetes job template")
}
Expand Down Expand Up @@ -735,7 +766,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
"kubectl",
"apply",
"-f",
"-", // apply STDIN
"-", // Apply STDIN.
}

// Commands for the evaluation to run inside the container.
Expand Down Expand Up @@ -763,14 +794,14 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
}

parallel.Execute(func() {
var tmplData bytes.Buffer
if err := jobTmpl.Execute(&tmplData, data); err != nil {
var kubernetesJobData bytes.Buffer
if err := kubernetesJobTemplate.Execute(&kubernetesJobData, data); err != nil {
command.logger.Panicf("ERROR: %s", err)
}

commandOutput, err := util.CommandWithResult(context.Background(), command.logger, &util.Command{
Command: kubeCommand,
Stdin: tmplData.String(),
Stdin: kubernetesJobData.String(),
})
if err != nil {
command.logger.Error("kubernetes evaluation failed", "error", pkgerrors.WithMessage(pkgerrors.WithStack(err), commandOutput))
Expand Down Expand Up @@ -830,7 +861,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {

var storageTemplateData bytes.Buffer
if err := storageTemplate.Execute(&storageTemplateData, data); err != nil {
return pkgerrors.Wrap(err, "could not execute storate template")
return pkgerrors.Wrap(err, "could not execute storage template")
}

// Create the storage access pod.
Expand All @@ -839,7 +870,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
"kubectl",
"apply",
"-f",
"-", // apply STDIN
"-", // Apply STDIN.
},
Stdin: storageTemplateData.String(),
})
Expand Down
Loading
Loading