Skip to content

Commit

Permalink
refactor, Separate LLM prompt template for "write test" task so we ca…
Browse files Browse the repository at this point in the history
…n add a template

Part of #350
  • Loading branch information
bauersimon committed Oct 3, 2024
1 parent 3c368f8 commit 6c30b2b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 28 deletions.
30 changes: 19 additions & 11 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
return m.metaInformation
}

// llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt.
// llmSourceFilePromptContext is the base template context for an LLM generation prompt.
type llmSourceFilePromptContext struct {
// Language holds the programming language name.
Language language.Language
Expand All @@ -76,8 +76,14 @@ type llmSourceFilePromptContext struct {
ImportPath string
}

// llmGenerateTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm-generate-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
// llmWriteTestSourceFilePromptContext is the template context for a write test LLM prompt.
type llmWriteTestSourceFilePromptContext struct {
// llmSourceFilePromptContext holds the context for a source file prompt.
llmSourceFilePromptContext
}

// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-write-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code{{ with $testFramework := .Language.TestFramework }} with {{ $testFramework }} as a test framework{{ end }}.
The tests should produce 100 percent code coverage and must compile.
The response must contain only the test code in a fenced code block and nothing else.
Expand All @@ -87,14 +93,14 @@ var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm
` + "```" + `
`)))

// llmGenerateTestForFilePrompt returns the prompt for generating an LLM test generation.
func llmGenerateTestForFilePrompt(data *llmSourceFilePromptContext) (message string, err error) {
// llmWriteTestForFilePrompt returns the prompt for generating an LLM test generation.
func llmWriteTestForFilePrompt(data *llmWriteTestSourceFilePromptContext) (message string, err error) {
// Use Linux paths even when running the evaluation on Windows to ensure consistency in prompting.
data.FilePath = filepath.ToSlash(data.FilePath)
data.Code = strings.TrimSpace(data.Code)

var b strings.Builder
if err := llmGenerateTestForFilePromptTemplate.Execute(&b, data); err != nil {
if err := llmWriteTestForFilePromptTemplate.Execute(&b, data); err != nil {
return "", pkgerrors.WithStack(err)
}

Expand Down Expand Up @@ -198,12 +204,14 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e

importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath)

request, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{
Language: ctx.Language,
request, err := llmWriteTestForFilePrompt(&llmWriteTestSourceFilePromptContext{
llmSourceFilePromptContext: llmSourceFilePromptContext{
Language: ctx.Language,

Code: fileContent,
FilePath: ctx.FilePath,
ImportPath: importPath,
Code: fileContent,
FilePath: ctx.FilePath,
ImportPath: importPath,
},
})
if err != nil {
return nil, err
Expand Down
38 changes: 21 additions & 17 deletions model/llm/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,14 @@ func TestModelGenerateTestsForFile(t *testing.T) {
func main() {}
`
sourceFilePath := "simple.go"
promptMessage, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{
Language: &golang.Language{},
promptMessage, err := llmWriteTestForFilePrompt(&llmWriteTestSourceFilePromptContext{
llmSourceFilePromptContext: llmSourceFilePromptContext{
Language: &golang.Language{},

Code: bytesutil.StringTrimIndentations(sourceFileContent),
FilePath: sourceFilePath,
ImportPath: "native",
Code: bytesutil.StringTrimIndentations(sourceFileContent),
FilePath: sourceFilePath,
ImportPath: "native",
},
})
require.NoError(t, err)
validate(t, &testCase{
Expand Down Expand Up @@ -291,14 +293,14 @@ func TestLLMGenerateTestForFilePrompt(t *testing.T) {
type testCase struct {
Name string

Data *llmSourceFilePromptContext
Data *llmWriteTestSourceFilePromptContext

ExpectedMessage string
}

validate := func(t *testing.T, tc *testCase) {
t.Run(tc.Name, func(t *testing.T) {
actualMessage, actualErr := llmGenerateTestForFilePrompt(tc.Data)
actualMessage, actualErr := llmWriteTestForFilePrompt(tc.Data)
require.NoError(t, actualErr)

assert.Equal(t, tc.ExpectedMessage, actualMessage)
Expand All @@ -308,18 +310,20 @@ func TestLLMGenerateTestForFilePrompt(t *testing.T) {
validate(t, &testCase{
Name: "Plain",

Data: &llmSourceFilePromptContext{
Language: &golang.Language{},
Data: &llmWriteTestSourceFilePromptContext{
llmSourceFilePromptContext: llmSourceFilePromptContext{
Language: &golang.Language{},

Code: bytesutil.StringTrimIndentations(`
package increment
Code: bytesutil.StringTrimIndentations(`
package increment
func increment(i int) int
return i + 1
}
`),
FilePath: filepath.Join("path", "to", "increment.go"),
ImportPath: "increment",
func increment(i int) int
return i + 1
}
`),
FilePath: filepath.Join("path", "to", "increment.go"),
ImportPath: "increment",
},
},

ExpectedMessage: bytesutil.StringTrimIndentations(`
Expand Down

0 comments on commit 6c30b2b

Please sign in to comment.