Skip to content

Commit

Permalink
Add simple code completion functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
xyproto committed Oct 4, 2024
1 parent 6000d98 commit 54dd20e
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 0 deletions.
127 changes: 127 additions & 0 deletions v2/code.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package ollamaclient

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
)

// GetBetweenResponse is given the start of code and end of code and will try to complete what is in between
// This function will ignore oc.TrimSpace and not trim blanks.
func (oc *Config) GetBetweenResponse(prompt, suffix string) (OutputResponse, error) {
var (
temperature float64
cacheKey string
seed = oc.SeedOrNegative
)
if prompt == "" {
return OutputResponse{}, errors.New("the prompt can not be empty")
}
if seed < 0 {
temperature = oc.TemperatureIfNegativeSeed
} else {
temperature = 0 // Since temperature is set to 0 when seed >=0
// The cache is only used for fixed seeds and a temperature of 0
keyData := struct {
Prompts []string
ModelName string
Seed int
Temperature float64
}{
Prompts: []string{prompt, suffix},
ModelName: oc.ModelName,
Seed: seed,
Temperature: temperature,
}
keyDataBytes, err := json.Marshal(keyData)
if err != nil {
return OutputResponse{}, err
}
hash := sha256.Sum256(keyDataBytes)
cacheKey = hex.EncodeToString(hash[:])
if Cache == nil {
if err := InitCache(); err != nil {
return OutputResponse{}, err
}
}
if entry, err := Cache.Get(cacheKey); err == nil {
var res OutputResponse
err = json.Unmarshal(entry, &res)
if err != nil {
return OutputResponse{}, err
}
return res, nil
}
}
var reqBody GenerateRequest
reqBody = GenerateRequest{
Model: oc.ModelName,
Prompt: prompt,
Suffix: suffix,
Options: RequestOptions{
Seed: seed, // set to -1 to make it random
Temperature: temperature, // set to 0 together with a specific seed to make output reproducible
},
}
if oc.ContextLength != 0 {
reqBody.Options.ContextLength = oc.ContextLength
}
reqBytes, err := json.Marshal(reqBody)
if err != nil {
return OutputResponse{}, err
}
if oc.Verbose {
fmt.Printf("Sending request to %s/api/generate: %s\n", oc.ServerAddr, string(reqBytes))
}
HTTPClient := &http.Client{
Timeout: oc.HTTPTimeout,
}
resp, err := HTTPClient.Post(oc.ServerAddr+"/api/generate", mimeJSON, bytes.NewBuffer(reqBytes))
if err != nil {
return OutputResponse{}, err
}
defer resp.Body.Close()
response := OutputResponse{
Role: "assistant",
}
var sb strings.Builder
decoder := json.NewDecoder(resp.Body)
for {
var genResp GenerateResponse
if err := decoder.Decode(&genResp); err != nil {
break
}
sb.WriteString(genResp.Response)
if genResp.Done {
response.PromptTokens = genResp.PromptEvalCount
response.ResponseTokens = genResp.EvalCount
break
}
}
response.Response = strings.TrimPrefix(sb.String(), "\n")
if cacheKey != "" {
data, err := json.Marshal(response)
if err != nil {
return OutputResponse{}, err
}
Cache.Set(cacheKey, data)
}
return response, nil
}

// Complete is a convenience function for completing code between two given strings of code
func (oc *Config) Complete(codeStart, codeEnd string) (string, error) {
if err := oc.PullIfNeeded(true); err != nil {
return "", err
}
response, err := oc.GetBetweenResponse(codeStart, codeEnd)
if err != nil {
return "", err
}
return response.Response, nil
}
40 changes: 40 additions & 0 deletions v2/code_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ollamaclient

import (
"fmt"
"testing"
)

const codeCompleteModel = "deepseek-coder-v2:latest"

func TestBetween(t *testing.T) {
const (
codeStart = "def compute_gcd(a, b):"
codeEnd = " return result"
)
oc := New(codeCompleteModel)
oc.Verbose = true
if err := oc.PullIfNeeded(true); err != nil {
t.Fatalf("Failed to pull model: %v", err)
}
response, err := oc.GetBetweenResponse(codeStart, codeEnd)
if err != nil {
t.Fatalf("Failed to get code completion: %v", err)
}
fmt.Printf("%s\n%s\n%s\n", codeStart, response.Response, codeEnd)
}

func TestCodeCompletion(t *testing.T) {
const (
codeStart = "def compute_gcd(a, b):"
codeEnd = " return result"
verbose = true
)
oc := New(codeCompleteModel)
oc.Verbose = true
codeBetween, err := oc.Complete(codeStart, codeEnd)
if err != nil {
t.Fatal(err)
}
fmt.Printf("%s\n%s\n%s\n", codeStart, codeBetween, codeEnd)
}
2 changes: 2 additions & 0 deletions v2/generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type GenerateRequest struct {
Images []string `json:"images,omitempty"` // base64 encoded images
Stream bool `json:"stream,omitempty"`
Options RequestOptions `json:"options,omitempty"`
Suffix string `json:"suffix,omitempty"`
}

// GenerateResponse represents the response data from the generate API call
Expand Down Expand Up @@ -35,6 +36,7 @@ type GenerateChatRequest struct {
Stream bool `json:"stream"`
Tools []Tool `json:"tools,omitempty"`
Options RequestOptions `json:"options,omitempty"`
Suffix string `json:"suffix,omitempty"`
}

// GenerateChatResponse represents the response data from the generate chat API call
Expand Down

0 comments on commit 54dd20e

Please sign in to comment.