Skip to content

Commit

Permalink
Pull ollama models during evaluation
Browse files: browse the repository at this point in the history
Closes #283
  • Loading branch information
Munsio committed Jul 25, 2024
1 parent 0a52c91 commit 9bec019
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
35 changes: 28 additions & 7 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (command *Evaluate) SetArguments(args []string) {
}

// Initialize initializes the command according to the arguments.
func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.Context) {
func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.Context, cleanup func()) {
evaluationContext = &evaluate.Context{}

// Check and validate common options.
Expand Down Expand Up @@ -313,6 +313,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
}

// Gather models.
serviceShutdown := []func() (err error){}
{
// Check which providers are needed for the evaluation.
providersSelected := map[string]provider.Provider{}
Expand All @@ -321,6 +322,11 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
} else {
for _, model := range command.Models {
p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0]

if _, ok := providersSelected[p]; ok {
continue
}

if provider, ok := provider.Providers[p]; !ok {
command.logger.Panicf("Provider %q does not exist", p)
} else {
Expand Down Expand Up @@ -354,11 +360,19 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
if err != nil {
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
}
defer func() {
if err := shutdown(); err != nil {
command.logger.Panicf("ERROR: could not shutdown services of provider %q: %s", p, err)
serviceShutdown = append(serviceShutdown, shutdown)
}

// Check if a provider has the ability to pull models and do so if necessary.
if puller, ok := p.(provider.Puller); ok {
command.logger.Printf("Pulling available models for provider %q", p.ID())
for _, modelID := range command.Models {
if strings.HasPrefix(modelID, p.ID()) {
if err := puller.Pull(command.logger, modelID); err != nil {
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
}
}
}()
}
}

ms, err := p.Models()
Expand Down Expand Up @@ -394,14 +408,21 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
}
}

return evaluationContext
return evaluationContext, func() {
for _, shutdown := range serviceShutdown {
if err := shutdown(); err != nil {
command.logger.Error(err.Error())
}
}
}
}

// Execute executes the command.
func (command *Evaluate) Execute(args []string) (err error) {
command.timestamp = time.Now()

evaluationContext := command.Initialize(args)
evaluationContext, cleanup := command.Initialize(args)
defer cleanup()
if evaluationContext == nil {
command.logger.Panic("ERROR: empty evaluation context")
}
Expand Down
21 changes: 5 additions & 16 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,23 +491,9 @@ func TestEvaluateExecute(t *testing.T) {
}

{
var shutdown func() (err error)
defer func() { // Defer the shutdown in case there is a panic.
if shutdown != nil {
require.NoError(t, shutdown())
}
}()
validate(t, &testCase{
Name: "Pulled Model",

Before: func(t *testing.T, logger *log.Logger, resultPath string) {
var err error
shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL)
require.NoError(t, err)

require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, providertesting.OllamaTestModel))
},

Arguments: []string{
"--language", "golang",
"--model", "ollama/" + providertesting.OllamaTestModel,
Expand Down Expand Up @@ -1050,17 +1036,20 @@ func TestEvaluateInitialize(t *testing.T) {
tc.Command.ResultPath = strings.ReplaceAll(tc.Command.ResultPath, "$TEMP_PATH", temporaryDirectory)

var actualEvaluationContext *evaluate.Context
var cleanup func()

if tc.ValidatePanic != "" {
assert.PanicsWithValue(t, tc.ValidatePanic, func() {
actualEvaluationContext = tc.Command.Initialize([]string{})
actualEvaluationContext, cleanup = tc.Command.Initialize([]string{})
defer cleanup()
})

return
}

assert.NotPanics(t, func() {
actualEvaluationContext = tc.Command.Initialize([]string{})
actualEvaluationContext, cleanup = tc.Command.Initialize([]string{})
defer cleanup()
})

if tc.ValidateCommand != nil {
Expand Down

0 comments on commit 9bec019

Please sign in to comment.