Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unified generation outputs #6

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion v2/cmd/describeimage/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,14 @@ func main() {
promptAndImages := append([]string{prompt}, images...)

logVerbose("[%s] Generating... ", oc.ModelName)
output, err := oc.GetOutput(promptAndImages...)
response, err := oc.GetOutput(promptAndImages...)
if err != nil {
fmt.Printf("error: %s\n", err)
os.Exit(1)
}

output := response.Response

logVerbose("OK\n")

if output == "" {
Expand Down
4 changes: 2 additions & 2 deletions v2/cmd/fortune/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func main() {
oc.SetRandom()

generatedOutput := oc.MustOutput(prompt)
if generatedOutput == "" {
if generatedOutput.Response == "" {
log.Println("Could not generate output.")
}

fmt.Println(ollamaclient.Massage(generatedOutput))
fmt.Println(ollamaclient.Massage(generatedOutput.Response))
}
4 changes: 3 additions & 1 deletion v2/cmd/summarize/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ func main() {
}

logVerbose("[%s] Generating... ", oc.ModelName)
output, err := oc.GetOutput(prompt)
response, err := oc.GetOutput(prompt)
if err != nil {
fmt.Printf("error: %s\n", err)
os.Exit(1)
}

output := response.Response

logVerbose("OK\n")

if wrapWidth > 0 {
Expand Down
4 changes: 2 additions & 2 deletions v2/describeimage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func TestDescribeImage(t *testing.T) {

prompt := "Describe this image:"
generatedOutput := oc.MustOutput(prompt, base64image)
if generatedOutput == "" {
if generatedOutput.Response == "" {
t.Fatalf("Generated output for the prompt %s is empty.\n", prompt)
}
fmt.Println(Massage(generatedOutput))
fmt.Println(Massage(generatedOutput.Response))
}
63 changes: 63 additions & 0 deletions v2/generation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package ollamaclient

// GenerateRequest represents the request payload for generating output
// GenerateRequest represents the request payload for the Ollama
// /api/generate endpoint.
type GenerateRequest struct {
	Model  string   `json:"model"`            // name of the model to generate with
	System string   `json:"system,omitempty"` // optional system prompt
	Prompt string   `json:"prompt,omitempty"` // the user prompt
	Images []string `json:"images,omitempty"` // base64 encoded images
	// Stream deliberately has no omitempty: an explicit "stream": false must
	// be serialized to disable streaming (the server streams by default).
	// This also matches GenerateChatRequest, which serializes the field
	// unconditionally.
	Stream  bool           `json:"stream"`
	Options RequestOptions `json:"options,omitempty"` // model/sampling options (temperature, seed, ...)
}

// GenerateResponse represents the response data from the generate API call
// GenerateResponse represents the response data from the generate API call.
// When streaming, one GenerateResponse is decoded per JSON line; the final
// message has Done set to true and carries the timing/token statistics.
type GenerateResponse struct {
	Model              string `json:"model"`                          // name of the model that produced the response
	CreatedAt          string `json:"created_at"`                     // server-supplied creation timestamp string
	Response           string `json:"response"`                       // generated text (a fragment per message when streaming)
	Context            []int  `json:"context,omitempty"`              // conversation context tokens, usable in a follow-up request
	TotalDuration      int64  `json:"total_duration,omitempty"`       // durations are presumably nanoseconds — confirm against the Ollama API docs
	LoadDuration       int64  `json:"load_duration,omitempty"`        // time spent loading the model
	SampleCount        int    `json:"sample_count,omitempty"`         // number of samples drawn
	SampleDuration     int64  `json:"sample_duration,omitempty"`      // time spent sampling
	PromptEvalCount    int    `json:"prompt_eval_count,omitempty"`    // number of prompt tokens evaluated
	PromptEvalDuration int64  `json:"prompt_eval_duration,omitempty"` // time spent evaluating the prompt
	EvalCount          int    `json:"eval_count,omitempty"`           // number of response tokens generated
	EvalDuration       int64  `json:"eval_duration,omitempty"`        // time spent generating the response
	Done               bool   `json:"done"`                           // true on the final message of a stream
}

// GenerateChatRequest represents the request payload for generating chat output
// GenerateChatRequest represents the request payload for generating chat
// output via the Ollama /api/chat endpoint.
type GenerateChatRequest struct {
	Model    string         `json:"model"`              // name of the model to chat with
	Messages []Message      `json:"messages,omitempty"` // conversation history, oldest first
	Images   []string       `json:"images,omitempty"`   // base64 encoded images
	Stream   bool           `json:"stream"`             // no omitempty: "stream": false must be sent explicitly to disable streaming
	Tools    []Tool         `json:"tools,omitempty"`    // tools the model may call
	Options  RequestOptions `json:"options,omitempty"`  // model/sampling options (temperature, seed, ...)
}

// GenerateChatResponse represents the response data from the generate chat API call
// GenerateChatResponse represents the response data from the generate chat
// API call. When streaming, one GenerateChatResponse is decoded per JSON
// line; the final message has Done set to true and carries the statistics.
type GenerateChatResponse struct {
	Model              string          `json:"model"`                          // name of the model that produced the response
	CreatedAt          string          `json:"created_at"`                     // server-supplied creation timestamp string
	Message            MessageResponse `json:"message"`                        // the generated chat message (role, content, tool calls)
	DoneReason         string          `json:"done_reason"`                    // why generation stopped, on the final message
	Done               bool            `json:"done"`                           // true on the final message of a stream
	TotalDuration      int64           `json:"total_duration,omitempty"`       // durations are presumably nanoseconds — confirm against the Ollama API docs
	LoadDuration       int64           `json:"load_duration,omitempty"`        // time spent loading the model
	PromptEvalCount    int             `json:"prompt_eval_count,omitempty"`    // number of prompt tokens evaluated
	PromptEvalDuration int64           `json:"prompt_eval_duration,omitempty"` // time spent evaluating the prompt
	EvalCount          int             `json:"eval_count,omitempty"`           // number of response tokens generated
	EvalDuration       int64           `json:"eval_duration,omitempty"`        // time spent generating the response
}

// OutputResponse represents the output from Ollama
// OutputResponse represents the output from Ollama. It is the unified result
// type returned by both the generate and chat code paths, so callers read
// .Response regardless of which endpoint produced the text.
type OutputResponse struct {
	Role           string     `json:"role"`            // role of the responder (e.g. the assistant)
	Response       string     `json:"response"`        // the generated text
	ToolCalls      []ToolCall `json:"tool_calls"`      // tool calls requested by the model, if any
	PromptTokens   int        `json:"prompt_tokens"`   // number of tokens in the prompt
	ResponseTokens int        `json:"response_tokens"` // number of tokens in the response
	Error          string     `json:"error"`           // error text when the Must* helpers swallow an error
}
87 changes: 37 additions & 50 deletions v2/ollamaclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,6 @@ type RequestOptions struct {
ContextLength int64 `json:"num_ctx,omitempty"`
}

// GenerateRequest represents the request payload for generating output
type GenerateRequest struct {
Model string `json:"model"`
System string `json:"system,omitempty"`
Prompt string `json:"prompt,omitempty"`
Images []string `json:"images,omitempty"` // base64 encoded images
Stream bool `json:"stream,omitempty"`
Options RequestOptions `json:"options,omitempty"`
}

// GenerateResponse represents the response data from the generate API call
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Response string `json:"response"`
Context []int `json:"context,omitempty"`
TotalDuration int64 `json:"total_duration,omitempty"`
LoadDuration int64 `json:"load_duration,omitempty"`
SampleCount int `json:"sample_count,omitempty"`
SampleDuration int64 `json:"sample_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int64 `json:"eval_duration,omitempty"`
Done bool `json:"done"`
}

// Model represents a downloaded model
type Model struct {
Modified time.Time `json:"modified_at"`
Expand Down Expand Up @@ -184,13 +157,13 @@ func (oc *Config) SetTool(tool Tool) {
}

// GetOutputChat sends a request to the Ollama API and returns the generated output.
func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat, error) {
func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputResponse, error) {
var (
temperature float64
seed = oc.SeedOrNegative
)
if len(promptAndOptionalImages) == 0 {
return OutputChat{}, errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
return OutputResponse{}, errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
}
prompt := promptAndOptionalImages[0]
var images []string
Expand Down Expand Up @@ -239,7 +212,7 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
}
reqBytes, err := json.Marshal(reqBody)
if err != nil {
return OutputChat{}, err
return OutputResponse{}, err
}
if oc.Verbose {
fmt.Printf("Sending request to %s/api/chat: %s\n", oc.ServerAddr, string(reqBytes))
Expand All @@ -249,10 +222,10 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
}
resp, err := HTTPClient.Post(oc.ServerAddr+"/api/chat", mimeJSON, bytes.NewBuffer(reqBytes))
if err != nil {
return OutputChat{}, err
return OutputResponse{}, err
}
defer resp.Body.Close()
var res = OutputChat{}
var res = OutputResponse{}
var sb strings.Builder
decoder := json.NewDecoder(resp.Body)
for {
Expand All @@ -264,25 +237,27 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
if genResp.Done {
res.Role = genResp.Message.Role
res.ToolCalls = genResp.Message.ToolCalls
res.PromptTokens = genResp.PromptEvalCount
res.ResponseTokens = genResp.EvalCount
break
}
}
res.Content = strings.TrimPrefix(sb.String(), "\n")
res.Response = strings.TrimPrefix(sb.String(), "\n")
if oc.TrimSpace {
res.Content = strings.TrimSpace(res.Content)
res.Response = strings.TrimSpace(res.Response)
}
return res, nil
}

// GetOutput sends a request to the Ollama API and returns the generated output.
func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
func (oc *Config) GetOutput(promptAndOptionalImages ...string) (OutputResponse, error) {
var (
temperature float64
cacheKey string
seed = oc.SeedOrNegative
)
if len(promptAndOptionalImages) == 0 {
return "", errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
return OutputResponse{}, errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
}
prompt := promptAndOptionalImages[0]
var images []string
Expand All @@ -296,11 +271,13 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
cacheKey = prompt + "-" + oc.ModelName
if Cache == nil {
if err := InitCache(); err != nil {
return "", err
return OutputResponse{}, err
}
}
if entry, err := Cache.Get(cacheKey); err == nil {
return string(entry), nil
var res OutputResponse
json.Unmarshal(entry, &res)
return res, nil
}
}
var reqBody GenerateRequest
Expand Down Expand Up @@ -329,7 +306,7 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
}
reqBytes, err := json.Marshal(reqBody)
if err != nil {
return "", err
return OutputResponse{}, err
}
if oc.Verbose {
fmt.Printf("Sending request to %s/api/generate: %s\n", oc.ServerAddr, string(reqBytes))
Expand All @@ -339,9 +316,12 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
}
resp, err := HTTPClient.Post(oc.ServerAddr+"/api/generate", mimeJSON, bytes.NewBuffer(reqBytes))
if err != nil {
return "", err
return OutputResponse{}, err
}
defer resp.Body.Close()
response := OutputResponse{
Role: "assistant",
}
var sb strings.Builder
decoder := json.NewDecoder(resp.Body)
for {
Expand All @@ -351,33 +331,40 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
}
sb.WriteString(genResp.Response)
if genResp.Done {
response.PromptTokens = genResp.PromptEvalCount
response.ResponseTokens = genResp.EvalCount
break
}
}
outputString := strings.TrimPrefix(sb.String(), "\n")
if oc.TrimSpace {
outputString = strings.TrimSpace(outputString)
}
response.Response = outputString

if cacheKey != "" {
Cache.Set(cacheKey, []byte(outputString))
var data []byte
json.Unmarshal([]byte(data), &response)
Cache.Set(cacheKey, []byte(data))
}
return outputString, nil

return response, nil
}

// MustOutput returns the output from Ollama, or the error as a string if not
func (oc *Config) MustOutput(promptAndOptionalImages ...string) string {
func (oc *Config) MustOutput(promptAndOptionalImages ...string) OutputResponse {
output, err := oc.GetOutput(promptAndOptionalImages...)
if err != nil {
return err.Error()
return OutputResponse{Error: err.Error()}
}
return output
}

// MustOutputChat returns the output from Ollama, or the error as a string if not
func (oc *Config) MustOutputChat(promptAndOptionalImages ...string) OutputChat {
func (oc *Config) MustOutputChat(promptAndOptionalImages ...string) OutputResponse {
output, err := oc.GetOutputChat(promptAndOptionalImages...)
if err != nil {
return OutputChat{Error: err.Error()}
return OutputResponse{Error: err.Error()}
}
return output
}
Expand Down Expand Up @@ -481,18 +468,18 @@ func ClearCache() {
// DescribeImages can load a slice of image filenames into base64 encoded strings
// and build a prompt that starts with "Describe this/these image(s):" followed
// by the encoded images, and return a result. Typically used together with the "llava" model.
func (oc *Config) DescribeImages(imageFilenames []string, desiredWordCount int) (string, error) {
func (oc *Config) DescribeImages(imageFilenames []string, desiredWordCount int) (OutputResponse, error) {
var errNoImages = errors.New("must be given at least one image file to describe")

if len(imageFilenames) == 0 {
return "", errNoImages
return OutputResponse{}, errNoImages
}

var images []string
for _, imageFilename := range imageFilenames {
base64image, err := Base64EncodeFile(imageFilename)
if err != nil {
return "", fmt.Errorf("could not base64 encode %s: %v", imageFilename, err)
return OutputResponse{}, fmt.Errorf("could not base64 encode %s: %v", imageFilename, err)
}
// append the base64 encoded image to the "images" string slice
images = append(images, base64image)
Expand All @@ -501,7 +488,7 @@ func (oc *Config) DescribeImages(imageFilenames []string, desiredWordCount int)
var prompt string
switch len(images) {
case 0:
return "", errNoImages
return OutputResponse{}, errNoImages
case 1:
if desiredWordCount > 0 {
prompt = fmt.Sprintf("Describe this image using a maximum of %d words:", desiredWordCount)
Expand Down
4 changes: 2 additions & 2 deletions v2/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestPullGemmaIntegration(t *testing.T) {

prompt := "Generate an imperative sentence. Keep it brief. Only output the sentence itself. Skip explanations, introductions or preamble."
generatedOutput := oc.MustOutput(prompt)
if generatedOutput == "" {
if generatedOutput.Response == "" {
t.Fatalf("Generated output for the prompt %s is empty.\n", prompt)
}
fmt.Println(Massage(generatedOutput))
fmt.Println(Massage(generatedOutput.Response))
}
25 changes: 0 additions & 25 deletions v2/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,6 @@ import (
"net/http"
)

// GenerateChatRequest represents the request payload for generating chat output
type GenerateChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages,omitempty"`
Images []string `json:"images,omitempty"` // base64 encoded images
Stream bool `json:"stream"`
Tools []Tool `json:"tools,omitempty"`
Options RequestOptions `json:"options,omitempty"`
}

// GenerateChatResponse represents the response data from the generate chat API call
type GenerateChatResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message MessageResponse `json:"message"`
DoneReason string `json:"done_reason"`
Done bool `json:"done"`
TotalDuration int64 `json:"total_duration,omitempty"`
LoadDuration int64 `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int64 `json:"eval_duration,omitempty"`
}

// Message is a chat message
type Message struct {
Role string `json:"role"`
Expand Down
8 changes: 0 additions & 8 deletions v2/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,3 @@ type ToolCallFunction struct {
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}

// OutputChat represents the output from a chat request, including the role, content, tool calls, and any errors
type OutputChat struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
Error string `json:"error"`
}
Loading