Skip to content

Commit

Permalink
fix, Default to all repositories if none are selected in CLI
Browse files Browse the repository at this point in the history
Fixes #163
  • Loading branch information
bauersimon committed Jun 12, 2024
1 parent 4a8523a commit 57b2eb2
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 27 deletions.
74 changes: 47 additions & 27 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,15 @@ func (command *Evaluate) Initialize(args []string) (cleanup func()) {
}
}

// Ensure the "testdata" path exists and make it absolute.
if err := osutil.DirExists(command.TestdataPath); err != nil {
command.logger.Panicf("ERROR: testdata path %q cannot be accessed: %s", command.TestdataPath, err)
}
command.TestdataPath, err = filepath.Abs(command.TestdataPath)
if err != nil {
command.logger.Panicf("ERROR: could not resolve testdata path %q to an absolute path: %s", command.TestdataPath, err)
}

// Register custom OpenAI API providers and models.
{
customProviders := map[string]*openaiapi.Provider{}
Expand Down Expand Up @@ -205,25 +214,45 @@ func (command *Evaluate) Initialize(args []string) (cleanup func()) {
{
commandRepositories := map[string]bool{}
commandRepositoriesLanguages := map[string]bool{}
for _, r := range command.Repositories {
languageIDOfRepository := strings.SplitN(r, string(os.PathSeparator), 2)[0]
commandRepositoriesLanguages[languageIDOfRepository] = true

if _, ok := languagesSelected[languageIDOfRepository]; ok {
commandRepositories[r] = true
} else {
command.logger.Printf("Excluded repository %s because its language %q is not enabled for this evaluation", r, languageIDOfRepository)
if len(command.Repositories) == 0 {
for _, language := range command.Languages {
commandRepositoriesLanguages[language] = true
commandRepositories[filepath.Join(language, evaluate.RepositoryPlainName)] = true

languagePath := filepath.Join(command.TestdataPath, language)
languageRepositories, err := os.ReadDir(languagePath)
if err != nil {
command.logger.Panicf("ERROR: language path %q cannot be accessed: %s", languagePath, err)
}

for _, repository := range languageRepositories {
if !repository.IsDir() {
continue
}
commandRepositories[filepath.Join(language, repository.Name())] = true
}
}
}
for languageID := range languagesSelected {
if len(command.Repositories) == 0 || commandRepositoriesLanguages[languageID] {
commandRepositories[filepath.Join(languageID, evaluate.RepositoryPlainName)] = true
} else {
command.Languages = slices.DeleteFunc(command.Languages, func(l string) bool {
return l == languageID
})
delete(languagesSelected, languageID)
command.logger.Printf("Excluded language %q because it is not part of the selected repositories", languageID)
} else {
for _, r := range command.Repositories {
languageIDOfRepository := strings.SplitN(r, string(os.PathSeparator), 2)[0]
commandRepositoriesLanguages[languageIDOfRepository] = true

if _, ok := languagesSelected[languageIDOfRepository]; ok {
commandRepositories[r] = true
} else {
command.logger.Printf("Excluded repository %s because its language %q is not enabled for this evaluation", r, languageIDOfRepository)
}
}
for languageID := range languagesSelected {
if commandRepositoriesLanguages[languageID] {
commandRepositories[filepath.Join(languageID, evaluate.RepositoryPlainName)] = true
} else {
command.Languages = slices.DeleteFunc(command.Languages, func(l string) bool {
return l == languageID
})
delete(languagesSelected, languageID)
command.logger.Printf("Excluded language %q because it is not part of the selected repositories", languageID)
}
}
}
command.Repositories = maps.Keys(commandRepositories)
Expand Down Expand Up @@ -303,15 +332,6 @@ func (command *Evaluate) Initialize(args []string) (cleanup func()) {
}
}

// Ensure the "testdata" path exists and make it absolute.
if err := osutil.DirExists(command.TestdataPath); err != nil {
command.logger.Panicf("ERROR: testdata path %q cannot be accessed: %s", command.TestdataPath, err)
}
command.TestdataPath, err = filepath.Abs(command.TestdataPath)
if err != nil {
command.logger.Panicf("ERROR: could not resolve testdata path %q to an absolute path: %s", command.TestdataPath, err)
}

return nil
}

Expand Down
27 changes: 27 additions & 0 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ func TestEvaluateExecute(t *testing.T) {
Arguments: []string{
"--language", "golang",
"--model", "symflower/symbolic-execution",
"--repository", filepath.Join("golang", "plain"),
"--repository", filepath.Join("java", "plain"),
},

ExpectedOutputValidate: func(t *testing.T, output string, resultPath string) {
Expand Down Expand Up @@ -270,6 +272,8 @@ func TestEvaluateExecute(t *testing.T) {

Arguments: []string{
"--model", "symflower/symbolic-execution",
"--repository", filepath.Join("golang", "plain"),
"--repository", filepath.Join("java", "plain"),
},

ExpectedOutputValidate: func(t *testing.T, output string, resultPath string) {
Expand Down Expand Up @@ -720,6 +724,7 @@ func TestEvaluateExecute(t *testing.T) {
Arguments: []string{
"--language", "golang",
"--model", "symflower/symbolic-execution",
"--repository", filepath.Join("golang", "plain"),
},

ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
Expand All @@ -742,6 +747,7 @@ func TestEvaluateExecute(t *testing.T) {
Arguments: []string{
"--language", "golang",
"--model", "symflower/symbolic-execution",
"--repository", filepath.Join("golang", "plain"),
},

ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
Expand Down Expand Up @@ -892,4 +898,25 @@ func TestEvaluateInitialize(t *testing.T) {
}, command.Languages)
},
})
validate(t, &testCase{
Name: "Selecting no repository defaults to all",

Command: makeValidCommand(func(command *Evaluate) {
command.Repositories = []string{}
}),

ValidateCommand: func(t *testing.T, command *Evaluate) {
// Check if all Go repositories are indeed selected.
directories, err := os.ReadDir(filepath.Join("..", "..", "..", "testdata", "golang"))
require.NoError(t, err)
relativeRepositoryPath := make([]string, len(directories))
for i, directory := range directories {
relativeRepositoryPath[i] = filepath.Join("golang", directory.Name())
}

for _, golangRepository := range relativeRepositoryPath {
assert.Contains(t, command.Repositories, golangRepository)
}
},
})
}

0 comments on commit 57b2eb2

Please sign in to comment.