Skip to content

Commit

Permalink
Merge pull request #367 from symflower/more-spring-examples
Browse files Browse the repository at this point in the history
Support Spring
  • Loading branch information
ruiAzevedo19 authored Dec 19, 2024
2 parents bef0084 + 9d574e5 commit 4ef6a1a
Show file tree
Hide file tree
Showing 32 changed files with 916 additions and 128 deletions.
7 changes: 3 additions & 4 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
"request": "launch",
"mode": "auto",
"program": "cmd/eval-dev-quality",
"args": [
"${input:args}",
]
"args": "${input:args}",
"cwd": "${workspaceFolder}"
},
],
"inputs": [
Expand Down Expand Up @@ -54,7 +53,7 @@
"command": "memento.promptString",
"args": {
"id": "args",
"description": "Arguments? (Make sure to use absolute paths!)",
"description": "Arguments?",
"default": "",
},
},
Expand Down
39 changes: 39 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,45 @@ Each repository can contain a configuration file `repository.json` in its root d

For the evaluation of the repository only the specified tasks are executed. If no `repository.json` file exists, all tasks are executed.

Depending on the task, it can be beneficial to exclude parts of the repository from explicit evaluation. To give a concrete example: Spring controller tests can never be executed on their own but need a supporting [`Application` class](https://docs.spring.io/spring-boot/reference/testing/spring-boot-applications.html#testing.spring-boot-applications.using-main). But [such a file](testdata/java/spring-plain/src/main/java/com/example/Application.java) should never be used itself to prompt models for tests. Therefore, it can be excluded through the `repository.json` configuration:

```json
{
"tasks": ["write-tests"],
"ignore": ["src/main/java/com/example/Application.java"]
}
```

This `ignore` setting is currently only applicable for the test generation task `write-tests`.

It is possible to configure some model prompt parameters through `repository.json`:

```json
{
"tasks": ["write-tests"],
"prompt": {
"test-framework": "JUnit 5 for Spring Boot" // Overwrite the default test framework in the prompt.
}
}
```

This `prompt.test-framework` setting is currently only applicable for the test generation task `write-tests`.

When task results are validated, some repositories might require custom logic. For example: generating tests for a Spring Boot project requires ensuring that the tests used an actual Spring context (i.e. Spring Boot was initialized when the tests were executed). Therefore, the `repository.json` supports adding rudimentary custom validation:

```json
{
"tasks": ["write-tests"],
"validation": {
"execution": {
"stdout": "Initializing Spring" // Ensure the string "Initializing Spring" is contained in the execution output.
}
}
}
```

This `validation.execution.stdout` setting is currently only applicable for the test generation task `write-tests`.

## Tasks

### Task: Test Generation
Expand Down
3 changes: 2 additions & 1 deletion cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
_ "github.com/symflower/eval-dev-quality/provider/openrouter" // Register provider.
_ "github.com/symflower/eval-dev-quality/provider/symflower" // Register provider.
"github.com/symflower/eval-dev-quality/task"
"github.com/symflower/eval-dev-quality/tools"
"github.com/symflower/eval-dev-quality/util"
)
Expand Down Expand Up @@ -317,7 +318,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
command.logger.Panicf("ERROR: %s", err)
}
for _, r := range repositories {
config, err := evaltask.LoadRepositoryConfiguration(filepath.Join(command.TestdataPath, r))
config, err := task.LoadRepositoryConfiguration(filepath.Join(command.TestdataPath, r), evaltask.AllIdentifiers)
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
Expand Down
11 changes: 4 additions & 7 deletions evaluate/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package evaluate
import (
"os"
"path/filepath"
"strings"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/evaluate/report"
Expand All @@ -13,6 +12,7 @@ import (
evalmodel "github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
evaltask "github.com/symflower/eval-dev-quality/task"
"github.com/symflower/eval-dev-quality/util"
)

// Context holds an evaluation context.
Expand Down Expand Up @@ -134,7 +134,7 @@ func Evaluate(ctx *Context) (assessments *report.AssessmentStore) {
r.SetQueryAttempts(ctx.QueryAttempts)
}

for _, taskIdentifier := range temporaryRepository.SupportedTasks() {
for _, taskIdentifier := range temporaryRepository.Configuration().Tasks {
task, err := evaluatetask.ForIdentifier(taskIdentifier)
if err != nil {
logger.Fatal(err)
Expand Down Expand Up @@ -191,10 +191,7 @@ func Evaluate(ctx *Context) (assessments *report.AssessmentStore) {
}
}

repositoriesLookup := make(map[string]bool, len(ctx.RepositoryPaths))
for _, repositoryPath := range ctx.RepositoryPaths {
repositoriesLookup[repositoryPath] = true
}
repositoriesLookup := util.Set(ctx.RepositoryPaths)

// Evaluating models and languages.
ctx.Log.Printf("Evaluating models and languages")
Expand All @@ -207,7 +204,7 @@ func Evaluate(ctx *Context) (assessments *report.AssessmentStore) {
}
for _, repositoryPath := range relativeRepositoryPaths {
// Do not include "plain" repositories in this step of the evaluation, because they have been checked with the common check before.
if !repositoriesLookup[repositoryPath] || strings.HasSuffix(repositoryPath, RepositoryPlainName) {
if !repositoriesLookup[repositoryPath] || filepath.Base(repositoryPath) == RepositoryPlainName {
continue
}

Expand Down
135 changes: 135 additions & 0 deletions evaluate/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1393,4 +1393,139 @@ func TestEvaluate(t *testing.T) {
},
})
}
{
// Setup provider and model mocking.
languageGolang := &golang.Language{}
mockedModelID := "testing-provider/testing-model"
mockedModel := modeltesting.NewMockCapabilityWriteTestsNamed(t, mockedModelID)

repositoryPathPlain := filepath.Join("golang", "plain")
repositoryPathSomePlain := filepath.Join("golang", "some-plain")
temporaryTestdataPath := t.TempDir()
require.NoError(t, osutil.CopyTree(filepath.Join("..", "testdata", repositoryPathPlain), filepath.Join(temporaryTestdataPath, repositoryPathSomePlain)))
require.NoError(t, osutil.CopyTree(filepath.Join("..", "testdata", repositoryPathPlain), filepath.Join(temporaryTestdataPath, repositoryPathPlain)))

validate(t, &testCase{
Name: "Repository with -plain suffix",

Before: func(t *testing.T, logger *log.Logger, resultPath string) {
mockedModel.RegisterGenerateSuccess(t, testFiles["plain"].Path, testFiles["plain"].Content, metricstesting.AssessmentsWithProcessingTime)
},

Context: &Context{
Languages: []language.Language{
languageGolang,
},

Models: []evalmodel.Model{
mockedModel,
},

RepositoryPaths: []string{
repositoryPathSomePlain,
},

TestdataPath: temporaryTestdataPath,

Runs: 1,
},

ExpectedAssessments: []*metricstesting.AssessmentTuple{
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathPlain,
Task: evaluatetask.IdentifierWriteTests,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathPlain,
Task: evaluatetask.IdentifierWriteTestsSymflowerFix,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathPlain,
Task: evaluatetask.IdentifierWriteTestsSymflowerTemplate,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathPlain,
Task: evaluatetask.IdentifierWriteTestsSymflowerTemplateSymflowerFix,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathSomePlain,
Task: evaluatetask.IdentifierWriteTests,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathSomePlain,
Task: evaluatetask.IdentifierWriteTestsSymflowerFix,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathSomePlain,
Task: evaluatetask.IdentifierWriteTestsSymflowerTemplate,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
&metricstesting.AssessmentTuple{
Model: mockedModel,
Language: languageGolang,
RepositoryPath: repositoryPathSomePlain,
Task: evaluatetask.IdentifierWriteTestsSymflowerTemplateSymflowerFix,
Assessment: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
},
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
"evaluation.log": nil,
filepath.Join(string(evaluatetask.IdentifierWriteTests), log.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain", "evaluation.log"): nil,
filepath.Join(string(evaluatetask.IdentifierWriteTests), log.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "some-plain", "evaluation.log"): nil,
"evaluation.csv": nil,
},
})
}
}
68 changes: 9 additions & 59 deletions evaluate/task/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package task

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
Expand All @@ -18,55 +17,9 @@ import (
"github.com/symflower/eval-dev-quality/util"
)

// RepositoryConfiguration holds the configuration of a repository.
type RepositoryConfiguration struct {
	// Tasks holds the identifiers of the tasks that should be executed for the repository.
	Tasks []task.Identifier
}

// LoadRepositoryConfiguration loads a repository configuration from the given path.
// If the path is not a valid file, it is assumed to be a repository directory and the default configuration file name is appended. If no configuration file exists, the default configuration containing all task identifiers is returned. A configuration that cannot be read, parsed or validated results in an error.
func LoadRepositoryConfiguration(path string) (config *RepositoryConfiguration, err error) {
	if osutil.FileExists(path) != nil { // If we don't get a valid file, assume it is a repository directory and target the default configuration file name.
		path = filepath.Join(path, RepositoryConfigurationFileName)
	}

	data, err := os.ReadFile(path)
	if errors.Is(err, os.ErrNotExist) {
		// Set default configuration.
		return &RepositoryConfiguration{
			Tasks: AllIdentifiers,
		}, nil
	} else if err != nil {
		return nil, pkgerrors.Wrap(err, path)
	}

	// Decode into the struct pointer itself and not into a pointer to that pointer: with "&config" a JSON "null" document resets "config" to nil and the validation below would dereference a nil pointer.
	config = &RepositoryConfiguration{}
	if err := json.Unmarshal(data, config); err != nil {
		return nil, pkgerrors.Wrap(err, path)
	}
	if err := config.validate(); err != nil {
		return nil, err
	}

	return config, nil
}

// validate reports an error if the configuration lists no tasks at all or lists a task identifier that is not contained in the identifier lookup.
func (rc *RepositoryConfiguration) validate() (err error) {
	if len(rc.Tasks) == 0 {
		return pkgerrors.Errorf("empty list of tasks in configuration")
	}

	for _, identifier := range rc.Tasks {
		if known := LookupIdentifier[identifier]; !known {
			return pkgerrors.Errorf("task identifier %q unknown", identifier)
		}
	}

	return nil
}

// Repository holds data about a repository.
type Repository struct {
RepositoryConfiguration
task.RepositoryConfiguration

// name holds the name of the repository.
name string
Expand All @@ -76,14 +29,11 @@ type Repository struct {

var _ task.Repository = (*Repository)(nil)

// RepositoryConfigurationFileName holds the file name for a repository configuration.
const RepositoryConfigurationFileName = "repository.json"

// loadConfiguration loads the configuration from the dedicated configuration file.
func (r *Repository) loadConfiguration() (err error) {
configurationFilePath := filepath.Join(r.dataPath, RepositoryConfigurationFileName)
configurationFilePath := filepath.Join(r.dataPath, task.RepositoryConfigurationFileName)

configuration, err := LoadRepositoryConfiguration(configurationFilePath)
configuration, err := task.LoadRepositoryConfiguration(configurationFilePath, AllIdentifiers)
if err != nil {
return err
}
Expand All @@ -103,14 +53,9 @@ func (r *Repository) DataPath() (dataPath string) {
return r.dataPath
}

// SupportedTasks returns the task identifiers that are configured for the repository.
func (r *Repository) SupportedTasks() (tasks []task.Identifier) {
	tasks = r.Tasks

	return tasks
}

// Validate checks if the repository is well-formed.
func (r *Repository) Validate(logger *log.Logger, language language.Language) (err error) {
for _, taskIdentifier := range r.SupportedTasks() {
for _, taskIdentifier := range r.RepositoryConfiguration.Tasks {
switch taskIdentifier {
case IdentifierCodeRepair:
return validateCodeRepairRepository(logger, r.DataPath(), language)
Expand Down Expand Up @@ -163,6 +108,11 @@ func (r *Repository) Reset(logger *log.Logger) (err error) {
return nil
}

// Configuration returns a pointer to the configuration embedded in the repository, so modifications made through the returned value are visible on the repository itself.
func (r *Repository) Configuration() *task.RepositoryConfiguration {
	configuration := &r.RepositoryConfiguration

	return configuration
}

// TemporaryRepository creates a temporary repository and initializes a git repo in it.
func TemporaryRepository(logger *log.Logger, testDataPath string, repositoryPathRelative string) (repository *Repository, cleanup func(), err error) {
repositoryPathAbsolute := filepath.Join(testDataPath, repositoryPathRelative)
Expand Down
1 change: 1 addition & 0 deletions evaluate/task/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func symflowerTemplate(logger *log.Logger, repositoryPath string, language langu
"--language", language.ID(),
"--workspace", repositoryPath,
"--test-style", "basic",
"--code-disable-fetch-dependencies",
filePath,
},

Expand Down
Loading

0 comments on commit 4ef6a1a

Please sign in to comment.