Skip to content

Commit

Permalink
refactor, Move evaluation command setup/initialization into separate …
Browse files Browse the repository at this point in the history
…function for testability

Part of #163
  • Loading branch information
bauersimon committed Jun 13, 2024
1 parent f25eed3 commit d91200d
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 67 deletions.
173 changes: 106 additions & 67 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,14 @@ type Evaluate struct {

// Languages determines which language should be used for the evaluation, or empty if all languages should be used.
Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."`
// languages holds the resolved used languages.
languages []language.Language
// Models determines which models should be used for the evaluation, or empty if all models should be used.
Models []string `long:"model" description:"Evaluate with this model. By default all models are used."`
// models holds the resolved used models.
models []model.Model
// providerForModel holds the providers per model.
providerForModel map[model.Model]provider.Provider
// ProviderTokens holds all API tokens for the providers.
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
// ProviderUrls holds all custom inference endpoint urls for the providers.
Expand All @@ -68,6 +74,8 @@ type Evaluate struct {

// logger holds the logger of the command.
logger *log.Logger
// timestamp holds the timestamp of the command execution.
timestamp time.Time
}

var _ SetLogger = (*Evaluate)(nil)
Expand All @@ -77,28 +85,42 @@ func (command *Evaluate) SetLogger(logger *log.Logger) {
command.logger = logger
}

// Execute executes the command.
func (command *Evaluate) Execute(args []string) (err error) {
evaluationTimestamp := time.Now()
command.ResultPath = strings.ReplaceAll(command.ResultPath, "%datetime%", evaluationTimestamp.Format("2006-01-02-15:04:05")) // REMARK Use a datetime format with a dash, so directories can be easily marked because they are only one group.
command.ResultPath, err = util.UniqueDirectory(command.ResultPath)
if err != nil {
return err
// Initialize initializes the command according to the arguments.
func (command *Evaluate) Initialize(args []string) (cleanup func()) {
// Ensure the cleanup always runs in case there is a panic.
defer func() {
if r := recover(); r != nil {
if cleanup != nil {
cleanup()
}
panic(r)
}
}()

// Setup evaluation result directory.
command.timestamp = time.Now()
command.ResultPath = strings.ReplaceAll(command.ResultPath, "%datetime%", command.timestamp.Format("2006-01-02-15:04:05")) // REMARK Use a datetime format with a dash, so directories can be easily marked because they are only one group.
if uniqueResultPath, err := util.UniqueDirectory(command.ResultPath); err != nil {
command.logger.Panicf("ERROR: %s", err)
} else {
command.ResultPath = uniqueResultPath
}
command.logger.Printf("Writing results to %s", command.ResultPath)

// Initialize logging within result directory.
log, logClose, err := log.WithFile(command.logger, filepath.Join(command.ResultPath, "evaluation.log"))
if err != nil {
return err
command.logger.Panicf("ERROR: %s", err)
}
defer logClose()
cleanup = logClose
command.logger = log

// Check common options.
// Check and validate common options.
{
if command.InstallToolsPath == "" {
command.InstallToolsPath, err = tools.InstallPathDefault()
if err != nil {
log.Panicf("ERROR: %s", err)
command.logger.Panicf("ERROR: %s", err)
}
}

Expand All @@ -114,11 +136,11 @@ func (command *Evaluate) Execute(args []string) (err error) {
}

if command.QueryAttempts == 0 {
log.Panicf("number of configured query attempts must be greater than zero")
command.logger.Panicf("number of configured query attempts must be greater than zero")
}

if command.Runs == 0 {
log.Panicf("number of configured runs must be greater than zero")
command.logger.Panicf("number of configured runs must be greater than zero")
}
}

Expand All @@ -141,11 +163,11 @@ func (command *Evaluate) Execute(args []string) (err error) {

providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
if !ok {
log.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
}
modelProvider, ok := customProviders[providerID]
if !ok {
log.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
}

modelProvider.AddModel(llm.NewModel(modelProvider, model))
Expand All @@ -166,7 +188,7 @@ func (command *Evaluate) Execute(args []string) (err error) {
ls := maps.Keys(language.Languages)
sort.Strings(ls)

log.Panicf("ERROR: language %s does not exist. Valid languages are: %s", languageID, strings.Join(ls, ", "))
command.logger.Panicf("ERROR: language %s does not exist. Valid languages are: %s", languageID, strings.Join(ls, ", "))
}

languages[languageID] = l
Expand All @@ -176,41 +198,51 @@ func (command *Evaluate) Execute(args []string) (err error) {
for _, languageID := range command.Languages {
languagesSelected[languageID] = languages[languageID]
}
}

commandRepositories := map[string]bool{}
commandRepositoriesLanguages := map[string]bool{}
for _, r := range command.Repositories {
languageIDOfRepository := strings.SplitN(r, string(os.PathSeparator), 2)[0]
commandRepositoriesLanguages[languageIDOfRepository] = true
}

if _, ok := languagesSelected[languageIDOfRepository]; ok {
commandRepositories[r] = true
} else {
log.Printf("Excluded repository %s because its language %q is not enabled for this evaluation", r, languageIDOfRepository)
// Gather repositories and update language selection accordingly.
{
commandRepositories := map[string]bool{}
commandRepositoriesLanguages := map[string]bool{}
for _, r := range command.Repositories {
languageIDOfRepository := strings.SplitN(r, string(os.PathSeparator), 2)[0]
commandRepositoriesLanguages[languageIDOfRepository] = true

if _, ok := languagesSelected[languageIDOfRepository]; ok {
commandRepositories[r] = true
} else {
command.logger.Printf("Excluded repository %s because its language %q is not enabled for this evaluation", r, languageIDOfRepository)
}
}
}
for languageID := range languagesSelected {
if len(command.Repositories) == 0 || commandRepositoriesLanguages[languageID] {
commandRepositories[filepath.Join(languageID, evaluate.RepositoryPlainName)] = true
} else {
command.Languages = slices.DeleteFunc(command.Languages, func(l string) bool {
return l == languageID
})
delete(languagesSelected, languageID)
log.Printf("Excluded language %q because it is not part of the selected repositories", languageID)
for languageID := range languagesSelected {
if len(command.Repositories) == 0 || commandRepositoriesLanguages[languageID] {
commandRepositories[filepath.Join(languageID, evaluate.RepositoryPlainName)] = true
} else {
command.Languages = slices.DeleteFunc(command.Languages, func(l string) bool {
return l == languageID
})
delete(languagesSelected, languageID)
command.logger.Printf("Excluded language %q because it is not part of the selected repositories", languageID)
}
}
command.Repositories = maps.Keys(commandRepositories)
sort.Strings(command.Repositories)
}

// Make the resolved selected languages available in the command.
command.languages = make([]language.Language, len(command.Languages))
for i, languageID := range command.Languages {
command.languages[i] = languagesSelected[languageID]
}
command.Repositories = maps.Keys(commandRepositories)
sort.Strings(command.Repositories)

// Gather models.
modelsSelected := map[string]model.Model{}
providerForModel := map[model.Model]provider.Provider{}
command.providerForModel = map[model.Model]provider.Provider{}
{
models := map[string]model.Model{}
for _, p := range provider.Providers {
log.Printf("Checking provider %q for models", p.ID())
command.logger.Printf("Checking provider %q for models", p.ID())

if t, ok := p.(provider.InjectToken); ok {
token, ok := command.ProviderTokens[p.ID()]
Expand All @@ -219,33 +251,33 @@ func (command *Evaluate) Execute(args []string) (err error) {
}
}
if err := p.Available(log); err != nil {
log.Printf("Skipping unavailable provider %q cause: %s", p.ID(), err)
command.logger.Printf("Skipping unavailable provider %q cause: %s", p.ID(), err)

continue
}

// Start services of providers.
if service, ok := p.(provider.Service); ok {
log.Printf("Starting services for provider %q", p.ID())
command.logger.Printf("Starting services for provider %q", p.ID())
shutdown, err := service.Start(log)
if err != nil {
log.Panicf("ERROR: could not start services for provider %q: %s", p, err)
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
}
defer func() {
if err := shutdown(); err != nil {
log.Panicf("ERROR: could not shutdown services of provider %q: %s", p, err)
command.logger.Panicf("ERROR: could not shutdown services of provider %q: %s", p, err)
}
}()
}

ms, err := p.Models()
if err != nil {
log.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
}

for _, m := range ms {
models[m.ID()] = m
providerForModel[m] = p
command.providerForModel[m] = p
}
}
modelIDs := maps.Keys(models)
Expand All @@ -255,44 +287,51 @@ func (command *Evaluate) Execute(args []string) (err error) {
} else {
for _, modelID := range command.Models {
if _, ok := models[modelID]; !ok {
log.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
command.logger.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
}
}
}
sort.Strings(command.Models)
for _, modelID := range command.Models {
modelsSelected[modelID] = models[modelID]
}

// Make the resolved selected models available in the command.
command.models = make([]model.Model, len(command.Models))
for i, modelID := range command.Models {
command.models[i] = modelsSelected[modelID]
}
}

// Ensure the "testdata" path exists and make it absolute.
if err := osutil.DirExists(command.TestdataPath); err != nil {
log.Panicf("ERROR: testdata path %q cannot be accessed: %s", command.TestdataPath, err)
command.logger.Panicf("ERROR: testdata path %q cannot be accessed: %s", command.TestdataPath, err)
}
command.TestdataPath, err = filepath.Abs(command.TestdataPath)
if err != nil {
log.Panicf("ERROR: could not resolve testdata path %q to an absolute path: %s", command.TestdataPath, err)
command.logger.Panicf("ERROR: could not resolve testdata path %q to an absolute path: %s", command.TestdataPath, err)
}

return cleanup
}

// Execute executes the command.
func (command *Evaluate) Execute(args []string) (err error) {
cleanup := command.Initialize(args)
defer cleanup()

// Install required tools for the basic evaluation.
if err := tools.InstallEvaluation(log, command.InstallToolsPath); err != nil {
log.Panicf("ERROR: %s", err)
if err := tools.InstallEvaluation(command.logger, command.InstallToolsPath); err != nil {
command.logger.Panicf("ERROR: %s", err)
}

ls := make([]language.Language, len(command.Languages))
for i, languageID := range command.Languages {
ls[i] = languagesSelected[languageID]
}
ms := make([]model.Model, len(command.Models))
for i, modelID := range command.Models {
ms[i] = modelsSelected[modelID]
}
assessments, totalScore := evaluate.Evaluate(&evaluate.Context{
Log: log,
Log: command.logger,

Languages: ls,
Languages: command.languages,

Models: ms,
ProviderForModel: providerForModel,
Models: command.models,
ProviderForModel: command.providerForModel,
QueryAttempts: command.QueryAttempts,

RepositoryPaths: command.Repositories,
Expand All @@ -306,7 +345,7 @@ func (command *Evaluate) Execute(args []string) (err error) {

assessmentsPerModel := assessments.CollapseByModel()
if err := (report.Markdown{
DateTime: evaluationTimestamp,
DateTime: command.timestamp,
Version: evaluate.Version,

CSVPath: "./evaluation.csv",
Expand All @@ -317,17 +356,17 @@ func (command *Evaluate) Execute(args []string) (err error) {
AssessmentPerModel: assessmentsPerModel,
TotalScore: totalScore,
}).WriteToFile(filepath.Join(command.ResultPath, "README.md")); err != nil {
return err
command.logger.Panicf("ERROR: %s", err)
}

_ = assessmentsPerModel.WalkByScore(func(model model.Model, assessment metrics.Assessments, score uint64) (err error) {
log.Printf("Evaluation score for %q (%q): %s", model.ID(), assessment.Category(totalScore).ID, assessment)
command.logger.Printf("Evaluation score for %q (%q): %s", model.ID(), assessment.Category(totalScore).ID, assessment)

return nil
})

if err := writeCSVs(command.ResultPath, assessments); err != nil {
log.Panicf("ERROR: %s", err)
command.logger.Panicf("ERROR: %s", err)
}

return nil
Expand Down
Loading

0 comments on commit d91200d

Please sign in to comment.