From c95b3146c6f0f312266fd7fa81c99aa663e74671 Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Wed, 18 Dec 2024 17:04:21 +0000 Subject: [PATCH] New task for code migration Part of #375 --- .mockery.yml | 1 + evaluate/task/migrate.go | 199 ++++++++++++ evaluate/task/migrate_test.go | 332 ++++++++++++++++++++ evaluate/task/repository.go | 2 + evaluate/task/task.go | 6 + model/capability.go | 6 + model/llm/llm.go | 87 +++++ model/llm/llm_test.go | 131 ++++++++ model/testing/CapabilityMigrate_mock_gen.go | 59 ++++ model/testing/helper.go | 22 ++ 10 files changed, 845 insertions(+) create mode 100644 evaluate/task/migrate.go create mode 100644 evaluate/task/migrate_test.go create mode 100644 model/testing/CapabilityMigrate_mock_gen.go diff --git a/.mockery.yml b/.mockery.yml index bbfede58..ec21dfee 100644 --- a/.mockery.yml +++ b/.mockery.yml @@ -15,6 +15,7 @@ packages: Model: CapabilityWriteTests: CapabilityRepairCode: + CapabilityMigrate: CapabilityTranspile: github.com/symflower/eval-dev-quality/provider: interfaces: diff --git a/evaluate/task/migrate.go b/evaluate/task/migrate.go new file mode 100644 index 00000000..5c6db3c7 --- /dev/null +++ b/evaluate/task/migrate.go @@ -0,0 +1,199 @@ +package task + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + pkgerrors "github.com/pkg/errors" + "github.com/symflower/eval-dev-quality/evaluate/metrics" + "github.com/symflower/eval-dev-quality/language" + "github.com/symflower/eval-dev-quality/log" + "github.com/symflower/eval-dev-quality/model" + evaltask "github.com/symflower/eval-dev-quality/task" + "github.com/zimmski/osutil" +) + +// Migrate holds the migration task. +type Migrate struct{} + +var _ evaltask.Task = (*Migrate)(nil) + +// ArgumentsMigrate holds extra arguments to be used in a query prompt. +type ArgumentsMigrate struct { + // TestFramework holds the test framework to use. + TestFramework string +} + +// Identifier returns the migration task identifier. +func (t *Migrate) Identifier() evaltask.Identifier { + return IdentifierMigrate +} + +// Run migrates code and runs the generated tests to check if the migration was successful. +func (t *Migrate) Run(ctx evaltask.Context) (repositoryAssessment map[evaltask.Identifier]metrics.Assessments, problems []error, err error) { + modelCapability, ok := ctx.Model.(model.CapabilityMigrate) + if !ok { + return nil, nil, pkgerrors.Wrap(evaltask.ErrTaskUnsupportedByModel, fmt.Sprintf("%q does not support %q", ctx.Model.ID(), string(t.Identifier()))) + } + + taskLogger, err := newTaskLogger(ctx, t) + if err != nil { + return nil, nil, err + } + defer func() { + taskLogger.finalize(problems) + }() + + dataPath := ctx.Repository.DataPath() + filePaths, err := ctx.Language.Files(taskLogger.Logger, dataPath) + if err != nil { + return nil, problems, pkgerrors.WithStack(err) + } + var testFilesPath []string + for _, filePath := range filePaths { + if strings.HasSuffix(filePath, ctx.Language.DefaultTestFileSuffix()) { + testFilesPath = append(testFilesPath, filePath) + } + } + + testFramework := ctx.Language.TestFramework() + if ctx.Repository.Configuration().Prompt.TestFramework != "" { + testFramework = ctx.Repository.Configuration().Prompt.TestFramework + } + + modelAssessment := metrics.NewAssessments() + withSymflowerFixAssessment := metrics.NewAssessments() + + var maximumReachableFiles uint64 + for _, testFilePath := range testFilesPath { + if ctx.Repository.Configuration().IsFilePathIgnored(testFilePath) { + taskLogger.Printf("Ignoring file %q (as configured by the repository)", testFilePath) + + continue + } + maximumReachableFiles++ + + if err := ctx.Repository.Reset(ctx.Logger); err != nil { + ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) + } + + // Remove all the other test files so when the tests are executed they don't influence the coverage metrics of the test file under test. + if err := clearRepositoryForMigration(ctx.Repository.DataPath(), ctx.Language, filePaths, testFilePath); err != nil { + return nil, nil, err + } + + modelContext := model.Context{ + Language: ctx.Language, + + RepositoryPath: ctx.Repository.DataPath(), + FilePath: testFilePath, + + Arguments: &ArgumentsMigrate{ + TestFramework: testFramework, + }, + + Logger: taskLogger.Logger, + } + runTask := func(ctx model.Context) (assessments metrics.Assessments, err error) { + return modelCapability.Migrate(ctx) + } + modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, runTask, modelContext, dataPath, testFilePath) + problems = append(problems, ps...) + if err != nil { + return nil, problems, err + } + + modelAssessment.Add(modelAssessmentFile) + withSymflowerFixAssessment.Add(withSymflowerFixAssessmentFile) + } + + modelAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles + withSymflowerFixAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles + + repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{ + IdentifierMigrate: modelAssessment, + IdentifierMigrateSymflowerFix: withSymflowerFixAssessment, + } + + return repositoryAssessment, problems, nil +} + +// clearRepositoryForMigration removes all files from the repository except the implementation and test file of the given file. +func clearRepositoryForMigration(repositoryPath string, language language.Language, allFilePaths []string, testFilePath string) (err error) { + testFileName := filepath.Base(testFilePath) + testFileName = strings.TrimSuffix(testFileName, language.DefaultTestFileSuffix()) + + for _, filePath := range allFilePaths { + if filePath == testFilePath { + continue + } + + fileName := filepath.Base(filePath) + fileName = strings.TrimSuffix(fileName, language.DefaultFileExtension()) + if fileName == testFileName { + continue + } + + if err := os.Remove(filepath.Join(repositoryPath, filePath)); err != nil { + return pkgerrors.WithStack(err) + } + } + + return nil +} + +// validateTranspileRepository checks if the repository for the "transpile" task is well-formed. +func validateMigrateRepository(logger *log.Logger, repositoryPath string, language language.Language) (err error) { + logger.Printf("validating repository %q", repositoryPath) + + filePaths, err := osutil.FilesRecursive(repositoryPath) + if err != nil { + return pkgerrors.WithStack(err) + } + + // Keep a mapping between implementation file paths and test file paths. + implementationFileNames := map[string]bool{} + testFileNames := map[string]bool{} + for _, filePath := range filePaths { + filePathExtension := filepath.Ext(filePath) + // Ignore build and configuration files. + if filePathExtension == ".xml" || filePathExtension == ".json" { + continue + } else if filePathExtension != language.DefaultFileExtension() { + return pkgerrors.Errorf("the repository %q must contain only %s files but found %q", repositoryPath, language.Name(), filePath) + } + + fileName := filepath.Base(filePath) + if strings.HasSuffix(filePath, language.DefaultTestFileSuffix()) { + fileName = strings.TrimSuffix(fileName, language.DefaultTestFileSuffix()) + testFileNames[fileName] = true + } else { + fileName = strings.TrimSuffix(fileName, language.DefaultFileExtension()) + implementationFileNames[fileName] = true + } + } + + if len(implementationFileNames) == 0 { + return pkgerrors.Errorf("the repository %q must contain implementation files but found none", repositoryPath) + } else if len(testFileNames) == 0 { + return pkgerrors.Errorf("the repository %q must contain test files but found none", repositoryPath) + } + + // Check if for each implementation file a test file exists. + for implementationFileName := range implementationFileNames { + if !testFileNames[implementationFileName] { + return pkgerrors.Errorf("the repository %q must contain a test file for each implementation file but found none for %q", repositoryPath, implementationFileName) + } + } + + // Check if for each test file an implementation file exists. + for testFileName := range testFileNames { + if !implementationFileNames[testFileName] { + return pkgerrors.Errorf("the repository %q must contain an implementation file for each test file but found none for %q", repositoryPath, testFileName) + } + } + + return nil +} diff --git a/evaluate/task/migrate_test.go b/evaluate/task/migrate_test.go new file mode 100644 index 00000000..d51b6e14 --- /dev/null +++ b/evaluate/task/migrate_test.go @@ -0,0 +1,332 @@ +package task + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zimmski/osutil" + "github.com/zimmski/osutil/bytesutil" + + "github.com/symflower/eval-dev-quality/evaluate/metrics" + metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing" + tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing" + "github.com/symflower/eval-dev-quality/language" + "github.com/symflower/eval-dev-quality/language/golang" + "github.com/symflower/eval-dev-quality/language/java" + "github.com/symflower/eval-dev-quality/log" + modeltesting "github.com/symflower/eval-dev-quality/model/testing" + evaltask "github.com/symflower/eval-dev-quality/task" +) + +func TestMigrateRun(t *testing.T) { + validate := func(t *testing.T, tc *tasktesting.TestCaseTask) { + t.Run(tc.Name, func(t *testing.T) { + task, err := ForIdentifier(IdentifierMigrate) + require.NoError(t, err) + tc.Task = task + + tc.Validate(t, + func(logger *log.Logger, testDataPath string, repositoryPathRelative string) (repository evaltask.Repository, cleanup func(), err error) { + return TemporaryRepository(logger, testDataPath, repositoryPathRelative) + }, + ) + }) + } + + t.Run("Java", func(t *testing.T) { + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "migrate") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "migrate"), repositoryPath)) + + modelMock := modeltesting.NewMockCapabilityMigrateNamed(t, "mocked-model") + + decrementTestFileContent := bytesutil.StringTrimIndentations(` + package com.eval; + + import org.junit.jupiter.api.Test; + import static org.junit.jupiter.api.Assertions.assertEquals; + + public class DecrementTest { + @Test + public void decrement() { + int i = 1; + int expected = 0; + int actual = Decrement.decrement(i); + + assertEquals(expected, actual); + } + } + `) + modelMock.RegisterGenerateSuccess(t, filepath.Join("src", "test", "java", "com", "eval", "DecrementTest.java"), decrementTestFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + incrementTestFileContent := bytesutil.StringTrimIndentations(` + package com.eval; + + import org.junit.jupiter.api.Test; + import static org.junit.jupiter.api.Assertions.assertEquals; + + public class IncrementTest { + @Test + public void increment() { + int i = 1; + int expected = 2; + int actual = Increment.increment(i); + + assertEquals(expected, actual); + } + } + `) + modelMock.RegisterGenerateSuccess(t, filepath.Join("src", "test", "java", "com", "eval", "IncrementTest.java"), incrementTestFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Plain", + + Model: modelMock, + Language: &java.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("java", "migrate"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierMigrate: metrics.Assessments{ + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyResponseNoError: 2, + metrics.AssessmentKeyCoverage: 40, + }, + IdentifierMigrateSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyResponseNoError: 2, + metrics.AssessmentKeyCoverage: 40, + }, + }, + ValidateLog: func(t *testing.T, data string) { + assert.Contains(t, data, "BUILD SUCCESS") + }, + }) + } + }) +} + +func TestClearRepositoryForMigration(t *testing.T) { + type testCase struct { + Name string + + Language language.Language + AllFilePaths []string + FilePath string + + ExpectedFilePaths []string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + repositoryPath := t.TempDir() + + for _, filePath := range tc.AllFilePaths { + require.NoError(t, osutil.MkdirAll(filepath.Join(repositoryPath, filepath.Dir(filePath)))) + require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, filePath), []byte(filePath), 0700)) + } + + require.NoError(t, clearRepositoryForMigration(repositoryPath, tc.Language, tc.AllFilePaths, tc.FilePath)) + + actualFilePaths, err := osutil.FilesRecursive(repositoryPath) + require.NoError(t, err) + for i, filePath := range actualFilePaths { + filePath, err := filepath.Rel(repositoryPath, filePath) + require.NoError(t, err) + actualFilePaths[i] = filePath + } + assert.Equal(t, tc.ExpectedFilePaths, actualFilePaths) + }) + } + + t.Run("Go", func(t *testing.T) { + validate(t, &testCase{ + Name: "Single", + + Language: &golang.Language{}, + AllFilePaths: []string{ + "file.go", + "file_test.go", + }, + FilePath: "file_test.go", + + ExpectedFilePaths: []string{ + "file.go", + "file_test.go", + }, + }) + validate(t, &testCase{ + Name: "Multiple", + + Language: &golang.Language{}, + AllFilePaths: []string{ + "fileA.go", + "fileA_test.go", + "fileB.go", + "fileB_test.go", + "fileC.go", + "fileC_test.go", + }, + FilePath: "fileB_test.go", + + ExpectedFilePaths: []string{ + "fileB.go", + "fileB_test.go", + }, + }) + }) + t.Run("Java", func(t *testing.T) { + validate(t, &testCase{ + Name: "Single", + + Language: &java.Language{}, + AllFilePaths: []string{ + filepath.Join("src", "main", "java", "com", "eval", "File.java"), + filepath.Join("src", "test", "java", "com", "eval", "FileTest.java"), + }, + FilePath: filepath.Join("src", "test", "java", "com", "eval", "FileTest.java"), + + ExpectedFilePaths: []string{ + filepath.Join("src", "main", "java", "com", "eval", "File.java"), + filepath.Join("src", "test", "java", "com", "eval", "FileTest.java"), + }, + }) + validate(t, &testCase{ + Name: "Multiple", + + Language: &java.Language{}, + AllFilePaths: []string{ + filepath.Join("src", "main", "java", "com", "eval", "FileA.java"), + filepath.Join("src", "main", "java", "com", "eval", "FileB.java"), + filepath.Join("src", "main", "java", "com", "eval", "FileC.java"), + filepath.Join("src", "test", "java", "com", "eval", "FileATest.java"), + filepath.Join("src", "test", "java", "com", "eval", "FileBTest.java"), + filepath.Join("src", "test", "java", "com", "eval", "FileCTest.java"), + }, + FilePath: filepath.Join("src", "test", "java", "com", "eval", "FileBTest.java"), + + ExpectedFilePaths: []string{ + filepath.Join("src", "main", "java", "com", "eval", "FileB.java"), + filepath.Join("src", "test", "java", "com", "eval", "FileBTest.java"), + }, + }) + }) +} + +func TestValidateMigrateRepository(t *testing.T) { + type testCase struct { + Name string + + Before func(repositoryPath string) + + ExpectedError func(t *testing.T, err error) + } + + validateJava := func(t *testing.T, tc *testCase) { + validateRepository := &tasktesting.TestCaseValidateRepository{ + Name: tc.Name, + + Before: tc.Before, + + TestdataPath: filepath.Join("..", "..", "testdata"), + RepositoryPath: filepath.Join("java", "migrate"), + Language: &java.Language{}, + + ExpectedError: tc.ExpectedError, + } + + validateRepository.Validate(t, validateMigrateRepository) + } + + removeFilesWithCondition := func(language language.Language, repositoryPath string, condition func(filePath string) bool) { + _, logger := log.Buffer() + + filePaths, err := language.Files(logger, repositoryPath) + require.NoError(t, err) + + for _, filePath := range filePaths { + if condition(filePath) { + require.NoError(t, os.Remove(filepath.Join(repositoryPath, filePath))) + } + } + } + + t.Run("Java", func(t *testing.T) { + validateJava(t, &testCase{ + Name: "Invalid language", + + Before: func(repositoryPath string) { + require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "file.go"), []byte(`content`), 0700)) + }, + + ExpectedError: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "must contain only Java files") + }, + }) + validateJava(t, &testCase{ + Name: "No test files", + + Before: func(repositoryPath string) { + javaLanguage := &java.Language{} + removeFilesWithCondition(javaLanguage, repositoryPath, func(filePath string) bool { + return strings.HasSuffix(filePath, javaLanguage.DefaultTestFileSuffix()) + }) + }, + + ExpectedError: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "must contain test files but found none") + }, + }) + validateJava(t, &testCase{ + Name: "No implementation files", + + Before: func(repositoryPath string) { + javaLanguage := &java.Language{} + removeFilesWithCondition(javaLanguage, repositoryPath, func(filePath string) bool { + return !strings.HasSuffix(filePath, javaLanguage.DefaultTestFileSuffix()) + }) + }, + + ExpectedError: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "must contain implementation files but found none") + }, + }) + validateJava(t, &testCase{ + Name: "Implementation files does not have a corresponding test file", + + Before: func(repositoryPath string) { + filePath := filepath.Join(repositoryPath, "src", "main", "java", "com", "eval", "File.java") + require.NoError(t, os.WriteFile(filePath, []byte(`content`), 0700)) + }, + + ExpectedError: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "must contain a test file for each implementation file but found none") + }, + }) + validateJava(t, &testCase{ + Name: "Test file does not have a corresponding implementation file", + + Before: func(repositoryPath string) { + filePath := filepath.Join(repositoryPath, "src", "test", "java", "com", "eval", "FileTest.java") + require.NoError(t, os.WriteFile(filePath, []byte(`content`), 0700)) + }, + + ExpectedError: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "must contain an implementation file for each test file but found none") + }, + }) + validateJava(t, &testCase{ + Name: "Valid", + + // The testdata repository is valid by default. + }) + }) +} diff --git a/evaluate/task/repository.go b/evaluate/task/repository.go index 9a07aed2..2835f0cb 100644 --- a/evaluate/task/repository.go +++ b/evaluate/task/repository.go @@ -59,6 +59,8 @@ func (r *Repository) Validate(logger *log.Logger, language language.Language) (e switch taskIdentifier { case IdentifierCodeRepair: return validateCodeRepairRepository(logger, r.DataPath(), language) + case IdentifierMigrate: + return validateMigrateRepository(logger, r.DataPath(), language) case IdentifierTranspile: return validateTranspileRepository(logger, r.DataPath(), language) case IdentifierWriteTests: diff --git a/evaluate/task/task.go b/evaluate/task/task.go index 0b28d5fa..97c5aa7f 100644 --- a/evaluate/task/task.go +++ b/evaluate/task/task.go @@ -41,6 +41,10 @@ var ( IdentifierWriteTestsSymflowerTemplateSymflowerFix = registerIdentifier("write-tests-symflower-template-symflower-fix") // IdentifierCodeRepair holds the identifier for the "code repair" task. IdentifierCodeRepair = registerIdentifier("code-repair") + // IdentifierMigrate holds the identifier for the "migrate" task. + IdentifierMigrate = registerIdentifier("migrate") + // IdentifierMigrateSymflowerFix holds the identifier for the "migrate" task. with the "symflower fix" applied. + IdentifierMigrateSymflowerFix = registerIdentifier("migrate-symflower-fix") // IdentifierTranspile holds the identifier for the "transpile" task. IdentifierTranspile = registerIdentifier("transpile") // IdentifierTranspileSymflowerFix holds the identifier for the "transpile" task with the "symflower fix" applied. @@ -54,6 +58,8 @@ func ForIdentifier(taskIdentifier evaltask.Identifier) (task evaltask.Task, err return &WriteTests{}, nil case IdentifierCodeRepair: return &CodeRepair{}, nil + case IdentifierMigrate: + return &Migrate{}, nil case IdentifierTranspile: return &Transpile{}, nil default: diff --git a/model/capability.go b/model/capability.go index c9a688fb..76b8a9ca 100644 --- a/model/capability.go +++ b/model/capability.go @@ -14,6 +14,12 @@ type CapabilityRepairCode interface { RepairCode(ctx Context) (assessments metrics.Assessments, err error) } +// CapabilityMigrate defines the capability of a model to migrate code. +type CapabilityMigrate interface { + // Migrate queries the model to migrate source code. + Migrate(ctx Context) (assessments metrics.Assessments, err error) +} + // CapabilityTranspile defines the capability of a model to transpile code. type CapabilityTranspile interface { // Transpile queries the model to transpile source code to another language. diff --git a/model/llm/llm.go b/model/llm/llm.go index 18a50542..e1bb7a81 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -198,6 +198,40 @@ func (ctx *llmTranspileSourceFilePromptContext) Format() (message string, err er return b.String(), nil } +// llmMigrateSourceFilePromptContext is the template context for a migration LLM prompt. +type llmMigrateSourceFilePromptContext struct { + // llmSourceFilePromptContext holds the context for a source file prompt. + llmSourceFilePromptContext + + // TestFramework defines the target test framework for migration. + TestFramework string +} + +// llmMigrateSourceFilePromptTemplate is the template for generating an LLM migration prompt. +var llmMigrateSourceFilePromptTemplate = template.Must(template.New("model-llm-migration-source-file-prompt").Parse(bytesutil.StringTrimIndentations(` + Given the following {{ .Language.Name }} test file "{{ .FilePath }}" with package "{{ .ImportPath }}", migrate the test file to {{ .TestFramework }} as the test framework. + The tests should produce 100 percent code coverage and must compile. + The response must contain only the test code in a fenced code block and nothing else. + + ` + "```" + `{{ .Language.ID }} + {{ .Code }} + ` + "```" + ` +`))) + +// Format returns the prompt to migrate a source file. +func (ctx *llmMigrateSourceFilePromptContext) Format() (message string, err error) { + // Use Linux paths even when running the evaluation on Windows to ensure consistency in prompting. + ctx.FilePath = filepath.ToSlash(ctx.FilePath) + ctx.Code = strings.TrimSpace(ctx.Code) + + var b strings.Builder + if err := llmMigrateSourceFilePromptTemplate.Execute(&b, ctx); err != nil { + return "", pkgerrors.WithStack(err) + } + + return b.String(), nil +} + var _ model.Model = (*Model)(nil) // ID returns the unique ID of this model. @@ -403,6 +437,59 @@ func (m *Model) Transpile(ctx model.Context) (assessment metrics.Assessments, er return assessment, nil } +var _ model.CapabilityMigrate = (*Model)(nil) + +// Migrate queries the model to migrate source code. +func (m *Model) Migrate(ctx model.Context) (assessment metrics.Assessments, err error) { + arguments, ok := ctx.Arguments.(*evaluatetask.ArgumentsMigrate) + if !ok { + return nil, pkgerrors.Errorf("unexpected type %T", ctx.Arguments) + } + + data, err := os.ReadFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath)) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + fileContent := strings.TrimSpace(string(data)) + + importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath) + + request, err := (&llmMigrateSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: ctx.Language, + + Code: fileContent, + FilePath: ctx.FilePath, + ImportPath: importPath, + }, + + TestFramework: arguments.TestFramework, + }).Format() + if err != nil { + return nil, err + } + + response, duration, err := m.query(ctx.Logger, request) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + + assessment, migrationFileContent, err := prompt.ParseResponse(response) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + assessment[metrics.AssessmentKeyProcessingTime] = uint64(duration.Milliseconds()) + assessment[metrics.AssessmentKeyResponseCharacterCount] = uint64(len(response)) + assessment[metrics.AssessmentKeyGenerateTestsForFileCharacterCount] = uint64(len(migrationFileContent)) + + err = os.WriteFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath), []byte(migrationFileContent), 0644) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + + return assessment, nil +} + var _ model.SetQueryAttempts = (*Model)(nil) // SetQueryAttempts sets the number of query attempts to perform when a model request errors in the process of solving a task. diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index 7240f53c..fc1cdf69 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -552,6 +552,40 @@ func TestFormatPromptContext(t *testing.T) { ` + "```" + `java package com.eval; + class Foobar { + static int foobar(int i) {} + } + ` + "```" + ` + `), + }) + validate(t, &testCase{ + Name: "Migrate", + + Context: &llmMigrateSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &java.Language{}, + + Code: bytesutil.StringTrimIndentations(` + package com.eval; + + class Foobar { + static int foobar(int i) {} + } + `), + FilePath: "Foobar.java", + ImportPath: "com.eval", + }, + TestFramework: "JUnit 5", + }, + + ExpectedMessage: bytesutil.StringTrimIndentations(` + Given the following Java test file "Foobar.java" with package "com.eval", migrate the test file to JUnit 5 as the test framework. + The tests should produce 100 percent code coverage and must compile. + The response must contain only the test code in a fenced code block and nothing else. + + ` + "```" + `java + package com.eval; + class Foobar { static int foobar(int i) {} } @@ -718,3 +752,100 @@ func TestModelTranspile(t *testing.T) { }) }) } + +func TestModelMigrate(t *testing.T) { + type testCase struct { + Name string + + SetupMock func(t *testing.T, mockedProvider *providertesting.MockQuery) + + Language language.Language + + RepositoryPath string + TestFilePath string + TestFramework string + + ExpectedAssessment metrics.Assessments + ExpectedMigratedFileContent string + } + + validate := func(t *testing.T, tc *testCase) { + logOutput, logger := log.Buffer() + defer func() { + if t.Failed() { + t.Log(logOutput.String()) + } + }() + + temporaryPath := t.TempDir() + repositoryPath := filepath.Join(temporaryPath, filepath.Base(tc.RepositoryPath)) + require.NoError(t, osutil.CopyTree(tc.RepositoryPath, repositoryPath)) + + modelID := "some-model" + mock := providertesting.NewMockQuery(t) + tc.SetupMock(t, mock) + llm := NewModel(mock, modelID) + + ctx := model.Context{ + Language: tc.Language, + + RepositoryPath: repositoryPath, + FilePath: tc.TestFilePath, + + Arguments: &evaluatetask.ArgumentsMigrate{ + TestFramework: tc.TestFramework, + }, + + Logger: logger, + } + + actualAssessment, actualError := llm.Migrate(ctx) + assert.NoError(t, actualError) + + assert.Equal(t, metricstesting.Clean(tc.ExpectedAssessment), metricstesting.Clean(actualAssessment)) + + actualMigratedFileContent, err := os.ReadFile(filepath.Join(repositoryPath, tc.TestFilePath)) + assert.NoError(t, err) + + assert.Equal(t, strings.TrimSpace(tc.ExpectedMigratedFileContent), string(actualMigratedFileContent)) + } + + migratedTestFile := bytesutil.StringTrimIndentations(` + package com.eval; + + import org.junit.jupiter.api.Test; + import static org.junit.jupiter.api.Assertions.assertEquals; + + public class IncrementTest { + @Test + public void increment() { + int i = 1; + int expected = 2; + int actual = Increment.increment(i); + + assertEquals(expected, actual); + } + } + `) + + validate(t, &testCase{ + Name: "Increment", + + SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { + mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return("```\n"+migratedTestFile+"```\n", nil) + }, + + Language: &java.Language{}, + + RepositoryPath: filepath.Join("..", "..", "testdata", "java", "migrate"), + TestFilePath: filepath.Join("src", "test", "java", "com", "eval", "IncrementTest.java"), + + ExpectedAssessment: metrics.Assessments{ + metrics.AssessmentKeyResponseNoExcess: 1, + metrics.AssessmentKeyResponseWithCode: 1, + metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 290, + metrics.AssessmentKeyResponseCharacterCount: 299, + }, + ExpectedMigratedFileContent: migratedTestFile, + }) +} diff --git a/model/testing/CapabilityMigrate_mock_gen.go b/model/testing/CapabilityMigrate_mock_gen.go new file mode 100644 index 00000000..f9691ddf --- /dev/null +++ b/model/testing/CapabilityMigrate_mock_gen.go @@ -0,0 +1,59 @@ +// Code generated by mockery v2.40.3. DO NOT EDIT. + +package modeltesting + +import ( + mock "github.com/stretchr/testify/mock" + metrics "github.com/symflower/eval-dev-quality/evaluate/metrics" + + model "github.com/symflower/eval-dev-quality/model" +) + +// MockCapabilityMigrate is an autogenerated mock type for the CapabilityMigrate type +type MockCapabilityMigrate struct { + mock.Mock +} + +// Migrate provides a mock function with given fields: ctx +func (_m *MockCapabilityMigrate) Migrate(ctx model.Context) (metrics.Assessments, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Migrate") + } + + var r0 metrics.Assessments + var r1 error + if rf, ok := ret.Get(0).(func(model.Context) (metrics.Assessments, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(model.Context) metrics.Assessments); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(metrics.Assessments) + } + } + + if rf, ok := ret.Get(1).(func(model.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewMockCapabilityMigrate creates a new instance of MockCapabilityMigrate. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCapabilityMigrate(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCapabilityMigrate { + mock := &MockCapabilityMigrate{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/model/testing/helper.go b/model/testing/helper.go index 8eb0f457..258da096 100644 --- a/model/testing/helper.go +++ b/model/testing/helper.go @@ -50,6 +50,14 @@ func (m *MockCapabilityRepairCode) RegisterGenerateSuccess(t *testing.T, filePat }) } +// RegisterGenerateSuccess registers a mock call for successful generation. +func (m *MockCapabilityMigrate) RegisterGenerateSuccess(t *testing.T, filePath string, fileContent string, assessment metrics.Assessments) *mock.Call { + return m.On("Migrate", mock.Anything).Return(assessment, nil).Run(func(args mock.Arguments) { + ctx, _ := args.Get(0).(model.Context) + require.NoError(t, os.WriteFile(filepath.Join(ctx.RepositoryPath, filePath), []byte(fileContent), 0600)) + }) +} + // RegisterGenerateError registers a mock call that errors on generation. func (m *MockCapabilityRepairCode) RegisterGenerateError(err error) *mock.Call { return m.On("RepairCode", mock.Anything).Return(nil, err) @@ -94,6 +102,20 @@ func NewMockCapabilityRepairCodeNamed(t *testing.T, id string) *MockModelCapabil } } +// MockModelCapabilityMigrate holds a mock implementing the "Model" and the "CapabilityMigrate" interface. +type MockModelCapabilityMigrate struct { + *MockModel + *MockCapabilityMigrate +} + +// NewMockCapabilityMigrateNamed returns a new named mocked model. +func NewMockCapabilityMigrateNamed(t *testing.T, id string) *MockModelCapabilityMigrate { + return &MockModelCapabilityMigrate{ + MockModel: NewMockModelNamed(t, id), + MockCapabilityMigrate: NewMockCapabilityMigrate(t), + } +} + // MockModelCapabilityTranspile holds a mock implementing the "Model" and the "CapabilityTranspile" interface. type MockModelCapabilityTranspile struct { *MockModel