From f575905e4f0bccac7c74192f38290cf54464a2db Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Fri, 12 Jul 2024 16:44:21 +0100 Subject: [PATCH] Check if the testdata repository is valid before running the evaluation, so it is checked just once Part of #263 --- evaluate/evaluate.go | 4 + evaluate/task/repository.go | 13 ++ evaluate/task/task-code-repair.go | 53 ++++++- evaluate/task/task-code-repair_test.go | 205 +++++++++++++++++++++++++ evaluate/task/task.go | 21 +-- evaluate/task/testing/task.go | 38 +++++ task/task.go | 3 + 7 files changed, 322 insertions(+), 15 deletions(-) diff --git a/evaluate/evaluate.go b/evaluate/evaluate.go index cda1034a1..5d63a8c77 100644 --- a/evaluate/evaluate.go +++ b/evaluate/evaluate.go @@ -93,6 +93,8 @@ func Evaluate(ctx *Context) (assessments *report.AssessmentStore, totalScore uin temporaryRepository, cleanup, err := evaluatetask.TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath) if err != nil { ctx.Log.Panicf("ERROR: unable to create temporary repository path: %+v", err) + } else if err = temporaryRepository.Validate(ctx.Log, language); err != nil { + ctx.Log.Panicf("ERROR: malformed repository %q: %+v", temporaryRepository.Name(), err) } defer cleanup() @@ -197,6 +199,8 @@ func Evaluate(ctx *Context) (assessments *report.AssessmentStore, totalScore uin temporaryRepository, cleanup, err := evaluatetask.TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath) if err != nil { ctx.Log.Panicf("ERROR: unable to create temporary repository path: %s", err) + } else if err = temporaryRepository.Validate(ctx.Log, l); err != nil { + ctx.Log.Panicf("ERROR: malformed repository %q: %+v", temporaryRepository.Name(), err) } defer cleanup() diff --git a/evaluate/task/repository.go b/evaluate/task/repository.go index 208bca36d..aadbc7b8e 100644 --- a/evaluate/task/repository.go +++ b/evaluate/task/repository.go @@ -12,6 +12,7 @@ import ( "github.com/zimmski/osutil" "github.com/zimmski/osutil/bytesutil" + "github.com/symflower/eval-dev-quality/language" "github.com/symflower/eval-dev-quality/log" "github.com/symflower/eval-dev-quality/task" "github.com/symflower/eval-dev-quality/util" @@ -87,6 +88,18 @@ func (r *Repository) SupportedTasks() (tasks []task.Identifier) { return r.Tasks } +// Validate checks it the repository is well-formed. +func (r *Repository) Validate(logger *log.Logger, language language.Language) (err error) { + for _, taskIdentifier := range r.SupportedTasks() { + switch taskIdentifier { + case IdentifierCodeRepair: + return validateCodeRepairRepository(logger, r.DataPath(), language) + } + } + + return nil +} + // Reset resets a repository back to its "initial" commit. func (r *Repository) Reset(logger *log.Logger) (err error) { out, err := util.CommandWithResult(context.Background(), logger, &util.Command{ diff --git a/evaluate/task/task-code-repair.go b/evaluate/task/task-code-repair.go index 2541f2186..2d1f58f1e 100644 --- a/evaluate/task/task-code-repair.go +++ b/evaluate/task/task-code-repair.go @@ -8,6 +8,7 @@ import ( pkgerrors "github.com/pkg/errors" "github.com/symflower/eval-dev-quality/evaluate/metrics" + "github.com/symflower/eval-dev-quality/language" "github.com/symflower/eval-dev-quality/log" "github.com/symflower/eval-dev-quality/model" evaltask "github.com/symflower/eval-dev-quality/task" @@ -121,10 +122,60 @@ func (t *TaskCodeRepair) unpackCodeRepairPackage(ctx evaltask.Context, fileLogge return "", nil, pkgerrors.Errorf("package %q in repository %q must contain source files with compilation errors", packagePath, ctx.Repository.Name()) } - sourceFilePath, err = packageHasSourceAndTestFile(fileLogger, ctx.Repository.Name(), packagePath, ctx.Language) + sourceFilePath, err = packageSourceFile(fileLogger, packagePath, ctx.Language) if err != nil { return "", nil, err } return sourceFilePath, mistakes, nil } + +// validateCodeRepairRepository checks if the repository for the "code-repair" task is well-formed. +func validateCodeRepairRepository(logger *log.Logger, repositoryPath string, language language.Language) (err error) { + logger.Printf("validating repository %q", repositoryPath) + + files, err := os.ReadDir(repositoryPath) + if err != nil { + return pkgerrors.WithStack(err) + } + + var packagePaths []string + var otherFiles []string + for _, file := range files { + if file.Name() == "repository.json" { + continue + } else if file.IsDir() { + packagePaths = append(packagePaths, filepath.Join(repositoryPath, file.Name())) + } else { + otherFiles = append(otherFiles, file.Name()) + } + } + + if len(otherFiles) > 0 { + return pkgerrors.Errorf("the code repair repository %q must contain only packages, but found %+v", repositoryPath, otherFiles) + } + + for _, packagePath := range packagePaths { + files, err := language.Files(logger, packagePath) + if err != nil { + return pkgerrors.WithStack(err) + } + + sourceFiles := []string{} + testFiles := []string{} + for _, file := range files { + if strings.HasSuffix(file, language.DefaultTestFileSuffix()) { + testFiles = append(testFiles, file) + } else if strings.HasSuffix(file, language.DefaultFileExtension()) { + sourceFiles = append(sourceFiles, file) + } + } + if len(sourceFiles) != 1 { + return pkgerrors.Errorf("the code repair package %q in repository %q must contain exactly one %s source file, but found %+v", packagePath, repositoryPath, language.Name(), sourceFiles) + } else if len(testFiles) != 1 { + return pkgerrors.Errorf("the code repair repository %q must contain exactly one %s test file, but found %+v", repositoryPath, language.Name(), testFiles) + } + } + + return nil +} diff --git a/evaluate/task/task-code-repair_test.go b/evaluate/task/task-code-repair_test.go index a31c89b2f..ebab8eb62 100644 --- a/evaluate/task/task-code-repair_test.go +++ b/evaluate/task/task-code-repair_test.go @@ -1,6 +1,8 @@ package task import ( + "fmt" + "os" "path/filepath" "testing" @@ -267,3 +269,206 @@ func TestTaskCodeRepairRun(t *testing.T) { } }) } + +func TestValidateCodeRepairRepository(t *testing.T) { + validate := func(t *testing.T, tc *tasktesting.TestCaseValidateRepository) { + t.Run(tc.Name, func(t *testing.T) { + tc.Validate(t, validateCodeRepairRepository) + }) + } + + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Repository root path contains source files", + + Before: func(repositoryPath string) { + someFile, err := os.Create(filepath.Join(repositoryPath, "someFile.go")) + require.NoError(t, err) + someFile.Close() + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("golang", "mistakes"), + Language: &golang.Language{}, + + ExpectedErrorContains: "must contain only packages, but found [someFile.go]", + }) + t.Run("Go", func(t *testing.T) { + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package does not contain source file", + + Before: func(repositoryPath string) { + require.NoError(t, os.MkdirAll(filepath.Join(repositoryPath, "somePackage"), 0700)) + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("golang", "mistakes"), + Language: &golang.Language{}, + + ExpectedErrorContains: "must contain exactly one Go source file, but found []", + }) + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package contains multiple source files", + + Before: func(repositoryPath string) { + somePackage := filepath.Join(repositoryPath, "somePackage") + require.NoError(t, os.MkdirAll(somePackage, 0700)) + + fileA, err := os.Create(filepath.Join(somePackage, "fileA.go")) + require.NoError(t, err) + fileA.Close() + + fileB, err := os.Create(filepath.Join(somePackage, "fileB.go")) + require.NoError(t, err) + fileB.Close() + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("golang", "mistakes"), + Language: &golang.Language{}, + + ExpectedErrorContains: "must contain exactly one Go source file, but found [fileA.go fileB.go]", + }) + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package does not contain test file", + + Before: func(repositoryPath string) { + somePackage := filepath.Join(repositoryPath, "somePackage") + require.NoError(t, os.MkdirAll(somePackage, 0700)) + + file, err := os.Create(filepath.Join(somePackage, "someFile.go")) + require.NoError(t, err) + defer file.Close() + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("golang", "mistakes"), + Language: &golang.Language{}, + + ExpectedErrorContains: "must contain exactly one Go test file, but found []", + }) + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package contains multiple test files", + + Before: func(repositoryPath string) { + somePackage := filepath.Join(repositoryPath, "somePackage") + require.NoError(t, os.MkdirAll(somePackage, 0700)) + + fileA, err := os.Create(filepath.Join(somePackage, "fileA.go")) + require.NoError(t, err) + fileA.Close() + + fileATest, err := os.Create(filepath.Join(somePackage, "fileA_test.go")) + require.NoError(t, err) + fileATest.Close() + + fileBTest, err := os.Create(filepath.Join(somePackage, "fileB_test.go")) + require.NoError(t, err) + fileBTest.Close() + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("golang", "mistakes"), + Language: &golang.Language{}, + + ExpectedErrorContains: "must contain exactly one Go test file, but found [fileA_test.go fileB_test.go]", + }) + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Well-formed", + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("golang", "mistakes"), + Language: &golang.Language{}, + }) + }) + t.Run("Java", func(t *testing.T) { + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package does not contain source file", + + Before: func(repositoryPath string) { + require.NoError(t, os.MkdirAll(filepath.Join(repositoryPath, "somePackage"), 0700)) + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("java", "mistakes"), + Language: &java.Language{}, + + ExpectedErrorContains: "must contain exactly one Java source file, but found []", + }) + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package contains multiple source files", + + Before: func(repositoryPath string) { + somePackage := filepath.Join(repositoryPath, "somePackage", "src", "main", "java", "com", "eval") + require.NoError(t, os.MkdirAll(somePackage, 0700)) + + fileA, err := os.Create(filepath.Join(somePackage, "FileA.java")) + require.NoError(t, err) + fileA.Close() + + fileB, err := os.Create(filepath.Join(somePackage, "FileB.java")) + require.NoError(t, err) + fileB.Close() + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("java", "mistakes"), + Language: &java.Language{}, + + ExpectedErrorContains: fmt.Sprintf("must contain exactly one Java source file, but found [%s %s]", filepath.Join("src", "main", "java", "com", "eval", "FileA.java"), filepath.Join("src", "main", "java", "com", "eval", "FileB.java")), + }) + + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package does not contain test file", + + Before: func(repositoryPath string) { + somePackage := filepath.Join(repositoryPath, "somePackage", "src", "main", "java", "com", "eval") + require.NoError(t, os.MkdirAll(somePackage, 0700)) + + fileA, err := os.Create(filepath.Join(somePackage, "FileA.java")) + require.NoError(t, err) + fileA.Close() + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("java", "mistakes"), + Language: &java.Language{}, + + ExpectedErrorContains: "must contain exactly one Java test file, but found []", + }) + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Package contains multiple test files", + + Before: func(repositoryPath string) { + sourcePackage := filepath.Join(repositoryPath, "somePackage", "src", "main", "java", "com", "eval") + require.NoError(t, os.MkdirAll(sourcePackage, 0700)) + testPackage := filepath.Join(repositoryPath, "somePackage", "src", "test", "java", "com", "eval") + require.NoError(t, os.MkdirAll(testPackage, 0700)) + + fileA, err := os.Create(filepath.Join(sourcePackage, "FileA.java")) + require.NoError(t, err) + fileA.Close() + + fileATest, err := os.Create(filepath.Join(testPackage, "FileATest.java")) + require.NoError(t, err) + fileATest.Close() + + fileBTest, err := os.Create(filepath.Join(testPackage, "FileBTest.java")) + require.NoError(t, err) + fileBTest.Close() + }, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("java", "mistakes"), + Language: &java.Language{}, + + ExpectedErrorContains: fmt.Sprintf("must contain exactly one Java test file, but found [%s %s]", filepath.Join("src", "test", "java", "com", "eval", "FileATest.java"), filepath.Join("src", "test", "java", "com", "eval", "FileBTest.java")), + }) + validate(t, &tasktesting.TestCaseValidateRepository{ + Name: "Well-formed", + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("java", "mistakes"), + Language: &java.Language{}, + }) + }) +} diff --git a/evaluate/task/task.go b/evaluate/task/task.go index 09d546493..c6c63363a 100644 --- a/evaluate/task/task.go +++ b/evaluate/task/task.go @@ -78,27 +78,20 @@ func (t *taskLogger) finalize(problems []error) { t.Logger.Printf("Evaluated model %q on task %q using language %q and repository %q: encountered %d problems: %+v", t.ctx.Model.ID(), t.task.Identifier(), t.ctx.Language.ID(), t.ctx.Repository.Name(), len(problems), problems) } -// packageHasSourceAndTestFile checks if a package as a source file and the corresponding test file for the given language, and returns the source file path. -func packageHasSourceAndTestFile(log *log.Logger, repositoryName string, packagePath string, language language.Language) (sourceFilePath string, err error) { +// packageSourceFile returns the source file of a package. +func packageSourceFile(log *log.Logger, packagePath string, language language.Language) (sourceFilePath string, err error) { filePaths, err := language.Files(log, packagePath) if err != nil { return "", pkgerrors.WithStack(err) - } else if len(filePaths) != 2 { - return "", pkgerrors.Errorf("package %q in repository %q must only contain an implementation file and the corresponding test file, but found %#v", packagePath, repositoryName, filePaths) } - var hasTestFile bool + for _, file := range filePaths { if strings.HasSuffix(file, language.DefaultTestFileSuffix()) { - hasTestFile = true - } else if filepath.Ext(file) == language.DefaultFileExtension() { - sourceFilePath = file + continue + } else if filepath.Ext(file) == language.DefaultFileExtension() { // We can assume there is only one source file because the package structure was previously verified. + return file, nil } } - if sourceFilePath == "" { - return "", pkgerrors.Errorf("package %q in repository %q does not contain a source file", packagePath, repositoryName) - } else if !hasTestFile { - return "", pkgerrors.Errorf("package %q in repository %q does not contain a test file", packagePath, repositoryName) - } - return sourceFilePath, nil + return sourceFilePath, pkgerrors.WithStack(pkgerrors.Errorf("could not find any %s source file in package %q", language.Name(), packagePath)) } diff --git a/evaluate/task/testing/task.go b/evaluate/task/testing/task.go index ee1a0e38c..cb3ba55ed 100644 --- a/evaluate/task/testing/task.go +++ b/evaluate/task/testing/task.go @@ -97,3 +97,41 @@ func (tc *TestCaseTask) Validate(t *testing.T, createRepository createRepository tc.ValidateLog(t, logOutput.String()) } } + +type TestCaseValidateRepository struct { + Name string + + Before func(repositoryPath string) + + TestdataPath string + RepositoryPath string + Language language.Language + + ExpectedErrorContains string +} + +type validateRepositoryForTask func(logger *log.Logger, repositoryPath string, language language.Language) (err error) + +func (tc *TestCaseValidateRepository) Validate(t *testing.T, validateRepositoryForTask validateRepositoryForTask) { + logOutput, logger := log.Buffer() + defer func() { + if t.Failed() { + t.Logf("Logging output: %s", logOutput.String()) + } + }() + + temporaryDirectory := t.TempDir() + repositoryPath := filepath.Join(temporaryDirectory, "testdata", tc.RepositoryPath) + require.NoError(t, osutil.CopyTree(filepath.Join(tc.TestdataPath, tc.RepositoryPath), repositoryPath)) + + if tc.Before != nil { + tc.Before(repositoryPath) + } + + actualErr := validateRepositoryForTask(logger, repositoryPath, tc.Language) + if len(tc.ExpectedErrorContains) > 0 { + assert.ErrorContains(t, actualErr, tc.ExpectedErrorContains) + } else { + require.NoError(t, actualErr) + } +} diff --git a/task/task.go b/task/task.go index 0166f33a4..da3c04c57 100644 --- a/task/task.go +++ b/task/task.go @@ -54,6 +54,9 @@ type Repository interface { // SupportedTasks returns the list of task identifiers the repository supports. SupportedTasks() (tasks []Identifier) + // Validate checks it the repository is well-formed. + Validate(logger *log.Logger, language language.Language) (err error) + // Reset resets the repository to its initial state. Reset(logger *log.Logger) (err error) }