Skip to content

Commit

Permalink
Unified generation outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanhecl authored and xyproto committed Aug 6, 2024
1 parent e6a6b2d commit abd7b76
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 91 deletions.
5 changes: 4 additions & 1 deletion v2/cmd/describeimage/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,14 @@ func main() {
promptAndImages := append([]string{prompt}, images...)

logVerbose("[%s] Generating... ", oc.ModelName)
output, err := oc.GetOutput(promptAndImages...)
response, err := oc.GetOutput(promptAndImages...)
if err != nil {
fmt.Printf("error: %s\n", err)
os.Exit(1)
}

output := response.Response

logVerbose("OK\n")

if output == "" {
Expand Down
4 changes: 2 additions & 2 deletions v2/cmd/fortune/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func main() {
oc.SetRandom()

generatedOutput := oc.MustOutput(prompt)
if generatedOutput == "" {
if generatedOutput.Response == "" {
log.Println("Could not generate output.")
}

fmt.Println(ollamaclient.Massage(generatedOutput))
fmt.Println(ollamaclient.Massage(generatedOutput.Response))
}
4 changes: 3 additions & 1 deletion v2/cmd/summarize/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ func main() {
}

logVerbose("[%s] Generating... ", oc.ModelName)
output, err := oc.GetOutput(prompt)
response, err := oc.GetOutput(prompt)
if err != nil {
fmt.Printf("error: %s\n", err)
os.Exit(1)
}

output := response.Response

logVerbose("OK\n")

if wrapWidth > 0 {
Expand Down
4 changes: 2 additions & 2 deletions v2/describeimage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func TestDescribeImage(t *testing.T) {

prompt := "Describe this image:"
generatedOutput := oc.MustOutput(prompt, base64image)
if generatedOutput == "" {
if generatedOutput.Response == "" {
t.Fatalf("Generated output for the prompt %s is empty.\n", prompt)
}
fmt.Println(Massage(generatedOutput))
fmt.Println(Massage(generatedOutput.Response))
}
63 changes: 63 additions & 0 deletions v2/generation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package ollamaclient

// GenerateRequest represents the JSON request payload sent to the
// Ollama /api/generate endpoint.
type GenerateRequest struct {
	Model  string `json:"model"`            // name of the model to generate with
	System string `json:"system,omitempty"` // optional system prompt
	Prompt string `json:"prompt,omitempty"` // the user prompt
	Images []string `json:"images,omitempty"` // base64 encoded images
	Stream bool `json:"stream,omitempty"` // if true, the server streams the response in chunks
	Options RequestOptions `json:"options,omitempty"` // sampling options (seed, temperature, context length, ...)
}

// GenerateResponse represents one JSON object in the (possibly streamed)
// response from the Ollama /api/generate call. When streaming, Response
// holds only the text fragment for this chunk; the counters and durations
// are populated on the final object, where Done is true.
type GenerateResponse struct {
	Model              string `json:"model"`                // model that produced this response
	CreatedAt          string `json:"created_at"`           // server-side creation timestamp
	Response           string `json:"response"`             // generated text (fragment when streaming)
	Context            []int  `json:"context,omitempty"`    // conversation context tokens, reusable in follow-up requests
	TotalDuration      int64  `json:"total_duration,omitempty"`       // durations as reported by the server
	LoadDuration       int64  `json:"load_duration,omitempty"`
	SampleCount        int    `json:"sample_count,omitempty"`
	SampleDuration     int64  `json:"sample_duration,omitempty"`
	PromptEvalCount    int    `json:"prompt_eval_count,omitempty"`    // number of prompt tokens evaluated
	PromptEvalDuration int64  `json:"prompt_eval_duration,omitempty"`
	EvalCount          int    `json:"eval_count,omitempty"`           // number of response tokens generated
	EvalDuration       int64  `json:"eval_duration,omitempty"`
	Done               bool   `json:"done"` // true on the final object of a stream
}

// GenerateChatRequest represents the JSON request payload sent to the
// Ollama /api/chat endpoint.
type GenerateChatRequest struct {
	Model    string    `json:"model"`              // name of the model to chat with
	Messages []Message `json:"messages,omitempty"` // conversation history, oldest first
	Images   []string  `json:"images,omitempty"`   // base64 encoded images
	Stream   bool      `json:"stream"` // no omitempty: an explicit false must be sent to disable streaming
	Tools    []Tool    `json:"tools,omitempty"`    // tool/function definitions the model may call
	Options  RequestOptions `json:"options,omitempty"` // sampling options (seed, temperature, context length, ...)
}

// GenerateChatResponse represents one JSON object in the (possibly streamed)
// response from the Ollama /api/chat call. The counters, durations and
// DoneReason are populated on the final object, where Done is true.
type GenerateChatResponse struct {
	Model              string          `json:"model"`      // model that produced this response
	CreatedAt          string          `json:"created_at"` // server-side creation timestamp
	Message            MessageResponse `json:"message"`    // role, content and any tool calls for this chunk
	DoneReason         string          `json:"done_reason"` // why generation stopped (set on the final object)
	Done               bool            `json:"done"`        // true on the final object of a stream
	TotalDuration      int64           `json:"total_duration,omitempty"` // durations as reported by the server
	LoadDuration       int64           `json:"load_duration,omitempty"`
	PromptEvalCount    int             `json:"prompt_eval_count,omitempty"`    // number of prompt tokens evaluated
	PromptEvalDuration int64           `json:"prompt_eval_duration,omitempty"`
	EvalCount          int             `json:"eval_count,omitempty"`           // number of response tokens generated
	EvalDuration       int64           `json:"eval_duration,omitempty"`
}

// OutputResponse represents the unified output from Ollama, returned by
// both GetOutput (/api/generate) and GetOutputChat (/api/chat), so callers
// handle a single result type regardless of which API was used.
type OutputResponse struct {
	Role           string     `json:"role"`            // role of the responder, e.g. "assistant"
	Response       string     `json:"response"`        // the accumulated generated text
	ToolCalls      []ToolCall `json:"tool_calls"`      // tool invocations requested by the model (chat only)
	PromptTokens   int        `json:"prompt_tokens"`   // tokens evaluated from the prompt
	ResponseTokens int        `json:"response_tokens"` // tokens generated in the response
	Error          string     `json:"error"`           // error message when produced via the Must* helpers
}
87 changes: 37 additions & 50 deletions v2/ollamaclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,6 @@ type RequestOptions struct {
ContextLength int64 `json:"num_ctx,omitempty"`
}

// GenerateRequest represents the request payload for generating output
type GenerateRequest struct {
Model string `json:"model"`
System string `json:"system,omitempty"`
Prompt string `json:"prompt,omitempty"`
Images []string `json:"images,omitempty"` // base64 encoded images
Stream bool `json:"stream,omitempty"`
Options RequestOptions `json:"options,omitempty"`
}

// GenerateResponse represents the response data from the generate API call
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Response string `json:"response"`
Context []int `json:"context,omitempty"`
TotalDuration int64 `json:"total_duration,omitempty"`
LoadDuration int64 `json:"load_duration,omitempty"`
SampleCount int `json:"sample_count,omitempty"`
SampleDuration int64 `json:"sample_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int64 `json:"eval_duration,omitempty"`
Done bool `json:"done"`
}

// Model represents a downloaded model
type Model struct {
Modified time.Time `json:"modified_at"`
Expand Down Expand Up @@ -184,13 +157,13 @@ func (oc *Config) SetTool(tool Tool) {
}

// GetOutputChat sends a request to the Ollama API and returns the generated output.
func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat, error) {
func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputResponse, error) {
var (
temperature float64
seed = oc.SeedOrNegative
)
if len(promptAndOptionalImages) == 0 {
return OutputChat{}, errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
return OutputResponse{}, errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
}
prompt := promptAndOptionalImages[0]
var images []string
Expand Down Expand Up @@ -239,7 +212,7 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
}
reqBytes, err := json.Marshal(reqBody)
if err != nil {
return OutputChat{}, err
return OutputResponse{}, err
}
if oc.Verbose {
fmt.Printf("Sending request to %s/api/chat: %s\n", oc.ServerAddr, string(reqBytes))
Expand All @@ -249,10 +222,10 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
}
resp, err := HTTPClient.Post(oc.ServerAddr+"/api/chat", mimeJSON, bytes.NewBuffer(reqBytes))
if err != nil {
return OutputChat{}, err
return OutputResponse{}, err
}
defer resp.Body.Close()
var res = OutputChat{}
var res = OutputResponse{}
var sb strings.Builder
decoder := json.NewDecoder(resp.Body)
for {
Expand All @@ -264,25 +237,27 @@ func (oc *Config) GetOutputChat(promptAndOptionalImages ...string) (OutputChat,
if genResp.Done {
res.Role = genResp.Message.Role
res.ToolCalls = genResp.Message.ToolCalls
res.PromptTokens = genResp.PromptEvalCount
res.ResponseTokens = genResp.EvalCount
break
}
}
res.Content = strings.TrimPrefix(sb.String(), "\n")
res.Response = strings.TrimPrefix(sb.String(), "\n")
if oc.TrimSpace {
res.Content = strings.TrimSpace(res.Content)
res.Response = strings.TrimSpace(res.Response)
}
return res, nil
}

// GetOutput sends a request to the Ollama API and returns the generated output.
func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
func (oc *Config) GetOutput(promptAndOptionalImages ...string) (OutputResponse, error) {
var (
temperature float64
cacheKey string
seed = oc.SeedOrNegative
)
if len(promptAndOptionalImages) == 0 {
return "", errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
return OutputResponse{}, errors.New("at least one prompt must be given (and then optionally, base64 encoded JPG or PNG image strings)")
}
prompt := promptAndOptionalImages[0]
var images []string
Expand All @@ -296,11 +271,13 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
cacheKey = prompt + "-" + oc.ModelName
if Cache == nil {
if err := InitCache(); err != nil {
return "", err
return OutputResponse{}, err
}
}
if entry, err := Cache.Get(cacheKey); err == nil {
return string(entry), nil
var res OutputResponse
json.Unmarshal(entry, &res)
return res, nil
}
}
var reqBody GenerateRequest
Expand Down Expand Up @@ -329,7 +306,7 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
}
reqBytes, err := json.Marshal(reqBody)
if err != nil {
return "", err
return OutputResponse{}, err
}
if oc.Verbose {
fmt.Printf("Sending request to %s/api/generate: %s\n", oc.ServerAddr, string(reqBytes))
Expand All @@ -339,9 +316,12 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
}
resp, err := HTTPClient.Post(oc.ServerAddr+"/api/generate", mimeJSON, bytes.NewBuffer(reqBytes))
if err != nil {
return "", err
return OutputResponse{}, err
}
defer resp.Body.Close()
response := OutputResponse{
Role: "assistant",
}
var sb strings.Builder
decoder := json.NewDecoder(resp.Body)
for {
Expand All @@ -351,33 +331,40 @@ func (oc *Config) GetOutput(promptAndOptionalImages ...string) (string, error) {
}
sb.WriteString(genResp.Response)
if genResp.Done {
response.PromptTokens = genResp.PromptEvalCount
response.ResponseTokens = genResp.EvalCount
break
}
}
outputString := strings.TrimPrefix(sb.String(), "\n")
if oc.TrimSpace {
outputString = strings.TrimSpace(outputString)
}
response.Response = outputString

if cacheKey != "" {
Cache.Set(cacheKey, []byte(outputString))
var data []byte
json.Unmarshal([]byte(data), &response)
Cache.Set(cacheKey, []byte(data))
}
return outputString, nil

return response, nil
}

// MustOutput returns the generated output from Ollama; on failure it returns an OutputResponse whose Error field holds the error message.
func (oc *Config) MustOutput(promptAndOptionalImages ...string) string {
func (oc *Config) MustOutput(promptAndOptionalImages ...string) OutputResponse {
output, err := oc.GetOutput(promptAndOptionalImages...)
if err != nil {
return err.Error()
return OutputResponse{Error: err.Error()}
}
return output
}

// MustOutputChat returns the generated chat output from Ollama; on failure it returns an OutputResponse whose Error field holds the error message.
func (oc *Config) MustOutputChat(promptAndOptionalImages ...string) OutputChat {
func (oc *Config) MustOutputChat(promptAndOptionalImages ...string) OutputResponse {
output, err := oc.GetOutputChat(promptAndOptionalImages...)
if err != nil {
return OutputChat{Error: err.Error()}
return OutputResponse{Error: err.Error()}
}
return output
}
Expand Down Expand Up @@ -481,18 +468,18 @@ func ClearCache() {
// DescribeImages can load a slice of image filenames into base64 encoded strings
// and build a prompt that starts with "Describe this/these image(s):" followed
// by the encoded images, and return a result. Typically used together with the "llava" model.
func (oc *Config) DescribeImages(imageFilenames []string, desiredWordCount int) (string, error) {
func (oc *Config) DescribeImages(imageFilenames []string, desiredWordCount int) (OutputResponse, error) {
var errNoImages = errors.New("must be given at least one image file to describe")

if len(imageFilenames) == 0 {
return "", errNoImages
return OutputResponse{}, errNoImages
}

var images []string
for _, imageFilename := range imageFilenames {
base64image, err := Base64EncodeFile(imageFilename)
if err != nil {
return "", fmt.Errorf("could not base64 encode %s: %v", imageFilename, err)
return OutputResponse{}, fmt.Errorf("could not base64 encode %s: %v", imageFilename, err)
}
// append the base64 encoded image to the "images" string slice
images = append(images, base64image)
Expand All @@ -501,7 +488,7 @@ func (oc *Config) DescribeImages(imageFilenames []string, desiredWordCount int)
var prompt string
switch len(images) {
case 0:
return "", errNoImages
return OutputResponse{}, errNoImages
case 1:
if desiredWordCount > 0 {
prompt = fmt.Sprintf("Describe this image using a maximum of %d words:", desiredWordCount)
Expand Down
4 changes: 2 additions & 2 deletions v2/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestPullGemmaIntegration(t *testing.T) {

prompt := "Generate an imperative sentence. Keep it brief. Only output the sentence itself. Skip explanations, introductions or preamble."
generatedOutput := oc.MustOutput(prompt)
if generatedOutput == "" {
if generatedOutput.Response == "" {
t.Fatalf("Generated output for the prompt %s is empty.\n", prompt)
}
fmt.Println(Massage(generatedOutput))
fmt.Println(Massage(generatedOutput.Response))
}
25 changes: 0 additions & 25 deletions v2/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,6 @@ import (
"net/http"
)

// GenerateChatRequest represents the request payload for generating chat output
type GenerateChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages,omitempty"`
Images []string `json:"images,omitempty"` // base64 encoded images
Stream bool `json:"stream"`
Tools []Tool `json:"tools,omitempty"`
Options RequestOptions `json:"options,omitempty"`
}

// GenerateChatResponse represents the response data from the generate chat API call
type GenerateChatResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message MessageResponse `json:"message"`
DoneReason string `json:"done_reason"`
Done bool `json:"done"`
TotalDuration int64 `json:"total_duration,omitempty"`
LoadDuration int64 `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int64 `json:"eval_duration,omitempty"`
}

// Message is a chat message
type Message struct {
Role string `json:"role"`
Expand Down
8 changes: 0 additions & 8 deletions v2/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,3 @@ type ToolCallFunction struct {
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}

// OutputChat represents the output from a chat request, including the role, content, tool calls, and any errors
type OutputChat struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
Error string `json:"error"`
}

0 comments on commit abd7b76

Please sign in to comment.