diff --git a/model/llm/llm.go b/model/llm/llm.go index 4bd52fee..7cca4704 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -63,7 +63,7 @@ func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { return m.metaInformation } -// llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt. +// llmSourceFilePromptContext is the base template context for an LLM generation prompt. type llmSourceFilePromptContext struct { // Language holds the programming language name. Language language.Language @@ -76,8 +76,14 @@ type llmSourceFilePromptContext struct { ImportPath string } -// llmGenerateTestForFilePromptTemplate is the template for generating an LLM test generation prompt. -var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm-generate-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(` +// llmWriteTestSourceFilePromptContext is the template context for a write test LLM prompt. +type llmWriteTestSourceFilePromptContext struct { + // llmSourceFilePromptContext holds the context for a source file prompt. + llmSourceFilePromptContext +} + +// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt. +var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-write-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(` Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code{{ with $testFramework := .Language.TestFramework }} with {{ $testFramework }} as a test framework{{ end }}. The tests should produce 100 percent code coverage and must compile. The response must contain only the test code in a fenced code block and nothing else. @@ -87,14 +93,14 @@ var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm ` + "```" + ` `))) -// llmGenerateTestForFilePrompt returns the prompt for generating an LLM test generation. -func llmGenerateTestForFilePrompt(data *llmSourceFilePromptContext) (message string, err error) { +// llmWriteTestForFilePrompt returns the prompt for generating an LLM test generation. +func llmWriteTestForFilePrompt(data *llmWriteTestSourceFilePromptContext) (message string, err error) { // Use Linux paths even when running the evaluation on Windows to ensure consistency in prompting. data.FilePath = filepath.ToSlash(data.FilePath) data.Code = strings.TrimSpace(data.Code) var b strings.Builder - if err := llmGenerateTestForFilePromptTemplate.Execute(&b, data); err != nil { + if err := llmWriteTestForFilePromptTemplate.Execute(&b, data); err != nil { return "", pkgerrors.WithStack(err) } @@ -198,12 +204,14 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath) - request, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{ - Language: ctx.Language, + request, err := llmWriteTestForFilePrompt(&llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: ctx.Language, - Code: fileContent, - FilePath: ctx.FilePath, - ImportPath: importPath, + Code: fileContent, + FilePath: ctx.FilePath, + ImportPath: importPath, + }, }) if err != nil { return nil, err diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index 6963282b..c2f9c808 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -84,12 +84,14 @@ func TestModelGenerateTestsForFile(t *testing.T) { func main() {} ` sourceFilePath := "simple.go" - promptMessage, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{ - Language: &golang.Language{}, + promptMessage, err := llmWriteTestForFilePrompt(&llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &golang.Language{}, - Code: bytesutil.StringTrimIndentations(sourceFileContent), - FilePath: sourceFilePath, - ImportPath: "native", + Code: bytesutil.StringTrimIndentations(sourceFileContent), + FilePath: sourceFilePath, + ImportPath: "native", + }, }) require.NoError(t, err) validate(t, &testCase{ @@ -291,14 +293,14 @@ func TestLLMGenerateTestForFilePrompt(t *testing.T) { type testCase struct { Name string - Data *llmSourceFilePromptContext + Data *llmWriteTestSourceFilePromptContext ExpectedMessage string } validate := func(t *testing.T, tc *testCase) { t.Run(tc.Name, func(t *testing.T) { - actualMessage, actualErr := llmGenerateTestForFilePrompt(tc.Data) + actualMessage, actualErr := llmWriteTestForFilePrompt(tc.Data) require.NoError(t, actualErr) assert.Equal(t, tc.ExpectedMessage, actualMessage) @@ -308,18 +310,20 @@ func TestLLMGenerateTestForFilePrompt(t *testing.T) { validate(t, &testCase{ Name: "Plain", - Data: &llmSourceFilePromptContext{ - Language: &golang.Language{}, + Data: &llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &golang.Language{}, - Code: bytesutil.StringTrimIndentations(` - package increment + Code: bytesutil.StringTrimIndentations(` + package increment - func increment(i int) int - return i + 1 - } - `), - FilePath: filepath.Join("path", "to", "increment.go"), - ImportPath: "increment", + func increment(i int) int + return i + 1 + } + `), + FilePath: filepath.Join("path", "to", "increment.go"), + ImportPath: "increment", + }, }, ExpectedMessage: bytesutil.StringTrimIndentations(`