Skip to content

Commit a0ee716

Browse files
committed
llms: Add more sophisticated json marshaling
1 parent 4ceae99 commit a0ee716

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

llms/generatecontent.go

+29
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,23 @@ type ToolCall struct {
123123
FunctionCall *FunctionCall `json:"function,omitempty"`
124124
}
125125

126+
func (bc ToolCall) MarshalJSON() ([]byte, error) {
127+
fc, err := json.Marshal(bc.FunctionCall)
128+
if err != nil {
129+
return nil, err
130+
}
131+
132+
m := map[string]any{
133+
"type": "tool_call",
134+
"tool_call": map[string]any{
135+
"id": bc.ID,
136+
"type": bc.Type,
137+
"fc": json.RawMessage(fc),
138+
},
139+
}
140+
return json.Marshal(m)
141+
}
142+
126143
func (ToolCall) isPart() {}
127144

128145
// ToolCallResponse is the response returned by a tool call.
@@ -135,6 +152,18 @@ type ToolCallResponse struct {
135152
Content string `json:"content"`
136153
}
137154

155+
func (tc ToolCallResponse) MarshalJSON() ([]byte, error) {
156+
m := map[string]any{
157+
"type": "tool_response",
158+
"tool_response": map[string]string{
159+
"tool_call_id": tc.ToolCallID,
160+
"name": tc.Name,
161+
"content": tc.Content,
162+
},
163+
}
164+
return json.Marshal(m)
165+
}
166+
138167
func (ToolCallResponse) isPart() {}
139168

140169
// ContentResponse is the response returned by a GenerateContent call.

llms/marshaling.go

+81-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package llms
22

33
import (
44
"encoding/base64"
5+
"encoding/json"
56
"fmt"
67

78
"gopkg.in/yaml.v3"
@@ -62,6 +63,65 @@ func (mc *MessageContent) UnmarshalYAML(unmarshal func(interface{}) error) error
6263
return nil
6364
}
6465

66+
func (mc *MessageContent) UnmarshalJSON(data []byte) error {
67+
var m struct {
68+
Role ChatMessageType `json:"role"`
69+
Text string `json:"text"`
70+
Parts []struct {
71+
Type string `json:"type"`
72+
Text string `json:"text"`
73+
ImageURL struct {
74+
URL string `json:"url"`
75+
} `json:"image_url"`
76+
Binary struct {
77+
MIMEType string `json:"mime_type"`
78+
Data []byte `json:"data"`
79+
} `json:"binary"`
80+
ToolCall struct {
81+
ID string `json:"id"`
82+
Type string `json:"type"`
83+
FuncCall *FunctionCall `json:"fc"`
84+
} `json:"tool_call"`
85+
ToolResp json.RawMessage `json:"tool_response"`
86+
} `json:"parts"`
87+
}
88+
if err := json.Unmarshal(data, &m); err != nil {
89+
return err
90+
}
91+
mc.Role = m.Role
92+
mc.Parts = make([]ContentPart, len(m.Parts))
93+
94+
for i, part := range m.Parts {
95+
switch part.Type {
96+
case "text":
97+
mc.Parts[i] = TextContent{Text: part.Text}
98+
case "image_url":
99+
mc.Parts[i] = ImageURLContent{URL: part.ImageURL.URL}
100+
case "binary":
101+
mc.Parts[i] = BinaryContent{MIMEType: part.Binary.MIMEType, Data: part.Binary.Data}
102+
case "tool_call":
103+
mc.Parts[i] = ToolCall{
104+
ID: part.ToolCall.ID,
105+
Type: part.ToolCall.Type,
106+
FunctionCall: part.ToolCall.FuncCall,
107+
}
108+
case "tool_response":
109+
var tr ToolCallResponse
110+
if err := json.Unmarshal(part.ToolResp, &tr); err != nil {
111+
return err
112+
}
113+
mc.Parts[i] = tr
114+
default:
115+
return fmt.Errorf("unknown content type: %s", part.Type)
116+
}
117+
}
118+
// Special case: handle single text part directly:
119+
if len(mc.Parts) == 0 && m.Text != "" {
120+
mc.Parts = []ContentPart{TextContent{Text: m.Text}}
121+
}
122+
return nil
123+
}
124+
65125
// MarshalYAML custom marshaling logic for MessageContent.
66126
func (mc MessageContent) MarshalYAML() (interface{}, error) {
67127
// Special case: handle single text part directly
@@ -101,10 +161,30 @@ func (mc MessageContent) MarshalYAML() (interface{}, error) {
101161
raw := make(map[string]interface{})
102162
raw["role"] = mc.Role
103163
raw["parts"] = parts
104-
105164
return raw, nil
106165
}
107166

167+
func (mc MessageContent) MarshalJSON() ([]byte, error) {
168+
hasSingleTextPart := false
169+
if len(mc.Parts) == 1 {
170+
_, hasSingleTextPart = mc.Parts[0].(TextContent)
171+
}
172+
if hasSingleTextPart {
173+
tp, _ := mc.Parts[0].(TextContent)
174+
return json.Marshal(struct {
175+
Role ChatMessageType `json:"role"`
176+
Text string `json:"text"`
177+
}{Role: mc.Role, Text: tp.Text})
178+
}
179+
return json.Marshal(struct {
180+
Role ChatMessageType `json:"role"`
181+
Parts []ContentPart `json:"parts"`
182+
}{
183+
Role: mc.Role,
184+
Parts: mc.Parts,
185+
})
186+
}
187+
108188
// Helper function to map raw data to struct.
109189
func mapToStruct(data map[string]interface{}, target interface{}) error {
110190
bytes, err := yaml.Marshal(data)

0 commit comments

Comments
 (0)