Categorization #48

Merged: 7 commits, merged on Apr 19, 2024
27 changes: 14 additions & 13 deletions cmd/eval-dev-quality/cmd/evaluate.go
@@ -51,16 +51,21 @@ func (command *Evaluate) Execute(args []string) (err error) {
defer logClose()

// Gather languages.
+languages := map[string]language.Language{}
if len(command.Languages) == 0 {
command.Languages = maps.Keys(language.Languages)
+languages = language.Languages
} else {
for _, languageID := range command.Languages {
-if _, ok := language.Languages[languageID]; !ok {
+l, ok := language.Languages[languageID]
+if !ok {
ls := maps.Keys(language.Languages)
sort.Strings(ls)

log.Fatalf("ERROR: language %s does not exist. Valid languages are: %s", languageID, strings.Join(ls, ", "))
}

+languages[languageID] = l
}
}
sort.Strings(command.Languages)
@@ -113,21 +118,17 @@ func (command *Evaluate) Execute(args []string) (err error) {

// Check that models and languages can be evaluated by executing the "plain" repositories.
log.Printf("Checking that models and languages can be used for evaluation")
-assessmentsPerModel := map[string]metrics.Assessments{}
+// Ensure we report metrics for every model even if they are excluded.
+assessments := report.NewAssessmentPerModelPerLanguagePerRepository(maps.Values(models), maps.Values(languages), append(command.Repositories, repositoryPlainName))
problemsPerModel := map[string][]error{}
{
-// Ensure we report metrics for every model even if they are excluded.
-for _, modelID := range command.Models {
-assessmentsPerModel[modelID] = metrics.NewAssessments()
-}

for _, languageID := range command.Languages {
for _, modelID := range command.Models {
model := models[modelID]
-language := language.Languages[languageID]
+language := languages[languageID]

assessment, ps, err := evaluate.EvaluateRepository(command.ResultPath, model, language, command.TestdataPath, filepath.Join(language.ID(), repositoryPlainName))
-assessmentsPerModel[modelID].Add(assessment)
+assessments[model][language][repositoryPlainName].Add(assessment)
if err != nil {
ps = append(ps, err)
}
Expand Down Expand Up @@ -164,10 +165,10 @@ func (command *Evaluate) Execute(args []string) (err error) {
}

model := models[modelID]
-language := language.Languages[languageID]
+language := languages[languageID]

assessment, ps, err := evaluate.EvaluateRepository(command.ResultPath, model, language, command.TestdataPath, filepath.Join(languageID, repository.Name()))
-assessmentsPerModel[model.ID()].Add(assessment)
+assessments[model][language][repository.Name()].Add(assessment)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
if err != nil {
log.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repository.Name(), err)
@@ -176,13 +177,13 @@ func (command *Evaluate) Execute(args []string) (err error) {
}
}

-_ = metrics.WalkByScore(assessmentsPerModel, func(model string, assessment metrics.Assessments, score uint) error {
+_ = metrics.WalkByScore(assessments.Collapse(), func(model string, assessment metrics.Assessments, score uint) error {
log.Printf("Evaluation score for %q: %s", model, assessment)

return nil
})

-csv, err := report.FormatCSV(assessmentsPerModel)
+csv, err := report.FormatCSV(assessments)
if err != nil {
log.Fatalf("ERROR: could not create result summary: %s", err)
}
76 changes: 76 additions & 0 deletions evaluate/report/collection.go
@@ -0,0 +1,76 @@
package report

import (
"cmp"
"slices"
"sort"

"golang.org/x/exp/maps"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/model"
)

// AssessmentPerModelPerLanguagePerRepository holds a collection of assessments, grouped per model, per language and per repository.
type AssessmentPerModelPerLanguagePerRepository map[model.Model]map[language.Language]map[string]metrics.Assessments
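
// For illustration only (not part of this change): given a collection "c", an assessment for a
// hypothetical model "someModel" and language "someLanguage" in a repository "some-repository"
// is addressed as c[someModel][someLanguage]["some-repository"].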

// NewAssessmentPerModelPerLanguagePerRepository returns a new AssessmentPerModelPerLanguagePerRepository initialized with an empty assessment for each combination.
func NewAssessmentPerModelPerLanguagePerRepository(models []model.Model, languages []language.Language, repositories []string) AssessmentPerModelPerLanguagePerRepository {
	a := AssessmentPerModelPerLanguagePerRepository{}
	for _, m := range models {
		if _, ok := a[m]; !ok {
			a[m] = map[language.Language]map[string]metrics.Assessments{}
		}
		for _, l := range languages {
			if _, ok := a[m][l]; !ok {
				a[m][l] = map[string]metrics.Assessments{}
			}
			for _, r := range repositories {
				a[m][l][r] = metrics.NewAssessments()
			}
		}
	}

	return a
}

// Walk walks over all entries in a deterministic order: models sorted by ID, then languages by ID, then repositories by name. Returning an error from the callback stops the walk and returns that error.
func (a AssessmentPerModelPerLanguagePerRepository) Walk(function func(m model.Model, l language.Language, r string, a metrics.Assessments) error) error {
	models := maps.Keys(a)
	slices.SortStableFunc(models, func(a, b model.Model) int {
		return cmp.Compare(a.ID(), b.ID())
	})
	for _, m := range models {
		languages := maps.Keys(a[m])
		slices.SortStableFunc(languages, func(a, b language.Language) int {
			return cmp.Compare(a.ID(), b.ID())
		})
		for _, l := range languages {
			repositories := maps.Keys(a[m][l])
			sort.Strings(repositories)
			for _, r := range repositories {
				if err := function(m, l, r, a[m][l][r]); err != nil {
					return err
				}
			}
		}
	}

	return nil
}

// Collapse returns all assessments aggregated per model ID.
func (a AssessmentPerModelPerLanguagePerRepository) Collapse() map[string]metrics.Assessments {
	perModel := make(map[string]metrics.Assessments, len(a))
	for _, m := range maps.Keys(a) {
		perModel[m.ID()] = metrics.NewAssessments()
	}
	_ = a.Walk(func(m model.Model, l language.Language, r string, a metrics.Assessments) error {
		perModel[m.ID()].Add(a)

		return nil
	})

	return perModel
}
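
Taken together, a value of this collection is created once, filled during evaluation, and then collapsed or formatted. A minimal sketch of that flow (someModel, someLanguage, someAssessment and the repository name "plain" are placeholders, not identifiers defined in this diff):

	collection := report.NewAssessmentPerModelPerLanguagePerRepository([]model.Model{someModel}, []language.Language{someLanguage}, []string{"plain"})
	collection[someModel][someLanguage]["plain"].Add(someAssessment) // Accumulate newly gathered metrics into the preinitialized entry.
	perModelID := collection.Collapse()                              // Aggregate per model ID, e.g. for the score log in "evaluate.go".
	csvData, err := report.FormatCSV(collection)                     // One CSV row per model, language and repository, see "csv.go" below.
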
132 changes: 132 additions & 0 deletions evaluate/report/collection_test.go
@@ -0,0 +1,132 @@
package report

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing"
"github.com/symflower/eval-dev-quality/language"
languagetesting "github.com/symflower/eval-dev-quality/language/testing"
"github.com/symflower/eval-dev-quality/model"
modeltesting "github.com/symflower/eval-dev-quality/model/testing"
)

func TestAssessmentPerModelPerLanguagePerRepositoryWalk(t *testing.T) {
	type testCase struct {
		Name string

		Assessments AssessmentPerModelPerLanguagePerRepository

		ExpectedOrder []metrics.Assessments
	}

	validate := func(t *testing.T, tc *testCase) {
		t.Run(tc.Name, func(t *testing.T) {
			actualOrder := []metrics.Assessments{}
			assert.NoError(t, tc.Assessments.Walk(func(m model.Model, l language.Language, r string, a metrics.Assessments) error {
				actualOrder = append(actualOrder, a)
				metricstesting.AssertAssessmentsEqual(t, tc.Assessments[m][l][r], a)

				return nil
			}))

			if assert.Equal(t, len(tc.ExpectedOrder), len(actualOrder)) {
				for i := range tc.ExpectedOrder {
					metricstesting.AssertAssessmentsEqual(t, tc.ExpectedOrder[i], actualOrder[i])
				}
			}
		})
	}

	validate(t, &testCase{
		Name: "Single Group",

		Assessments: AssessmentPerModelPerLanguagePerRepository{
			modeltesting.NewMockModelNamed("some-model"): {
				languagetesting.NewMockLanguageNamed("some-language"): {
					"some-repository": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 1,
					},
				},
			},
		},

		ExpectedOrder: []metrics.Assessments{
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 1,
			},
		},
	})

	validate(t, &testCase{
		Name: "Multiple Groups",

		Assessments: AssessmentPerModelPerLanguagePerRepository{
			modeltesting.NewMockModelNamed("some-model-a"): {
				languagetesting.NewMockLanguageNamed("some-language-a"): {
					"some-repository-a": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 1,
					},
					"some-repository-b": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 2,
					},
				},
				languagetesting.NewMockLanguageNamed("some-language-b"): {
					"some-repository-a": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 3,
					},
					"some-repository-b": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 4,
					},
				},
			},
			modeltesting.NewMockModelNamed("some-model-b"): {
				languagetesting.NewMockLanguageNamed("some-language-a"): {
					"some-repository-a": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 5,
					},
					"some-repository-b": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 6,
					},
				},
				languagetesting.NewMockLanguageNamed("some-language-b"): {
					"some-repository-a": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 7,
					},
					"some-repository-b": metrics.Assessments{
						metrics.AssessmentKeyResponseNoExcess: 8,
					},
				},
			},
		},

		ExpectedOrder: []metrics.Assessments{
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 1,
			},
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 2,
			},
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 3,
			},
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 4,
			},
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 5,
			},
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 6,
			},
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 7,
			},
			metrics.Assessments{
				metrics.AssessmentKeyResponseNoExcess: 8,
			},
		},
	})
}
13 changes: 8 additions & 5 deletions evaluate/report/csv.go
@@ -8,26 +8,29 @@ import (
pkgerrors "github.com/pkg/errors"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/model"
)

// csvHeader returns the header description as a CSV row.
func csvHeader() []string {
return append([]string{"model", "score"}, metrics.AllAssessmentKeysStrings...)
return append([]string{"model", "language", "repository", "score"}, metrics.AllAssessmentKeysStrings...)
}

// FormatCSV formats the given assessment metrics as CSV.
-func FormatCSV(assessmentsPerModel map[string]metrics.Assessments) (string, error) {
+func FormatCSV(assessments AssessmentPerModelPerLanguagePerRepository) (string, error) {
var out strings.Builder
csv := csv.NewWriter(&out)

if err := csv.Write(csvHeader()); err != nil {
return "", pkgerrors.WithStack(err)
}

-if err := metrics.WalkByScore(assessmentsPerModel, func(model string, assessment metrics.Assessments, score uint) error {
-row := assessment.StringCSV()
+if err := assessments.Walk(func(m model.Model, l language.Language, r string, a metrics.Assessments) error {
+row := a.StringCSV()
+score := a.Score()

-if err := csv.Write(append([]string{model, strconv.FormatUint(uint64(score), 10)}, row...)); err != nil {
+if err := csv.Write(append([]string{m.ID(), l.ID(), r, strconv.FormatUint(uint64(score), 10)}, row...)); err != nil {
return pkgerrors.WithStack(err)
}

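With the extended header, each CSV row now identifies the model, the language and the repository alongside the score and the individual assessment columns. A hypothetical example of the output shape (the values are made up and the assessment columns are abbreviated):

	model,language,repository,score,...
	some-model,some-language,some-repository,3,...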