Skip to content

Commit

Permalink
feat: adds format support
Browse files Browse the repository at this point in the history
  • Loading branch information
macistador committed Dec 30, 2024
1 parent 95d4249 commit eec9a25
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 34 deletions.
15 changes: 13 additions & 2 deletions Playground/OKPlayground.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
objects = {

/* Begin PBXBuildFile section */
04A2EC422D22A513009C9AED /* ChatWithFormatView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 04A2EC412D22A513009C9AED /* ChatWithFormatView.swift */; };
04C1CC092D21E8B8005362B0 /* OllamaKit in Frameworks */ = {isa = PBXBuildFile; productRef = 04C1CC082D21E8B8005362B0 /* OllamaKit */; };
0A5648C42C1468E0008FB5F6 /* OKPlaygroundApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0A5648C32C1468E0008FB5F6 /* OKPlaygroundApp.swift */; };
0A5648C82C1468E1008FB5F6 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0A5648C72C1468E1008FB5F6 /* Assets.xcassets */; };
0A5648CB2C1468E1008FB5F6 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0A5648CA2C1468E1008FB5F6 /* Preview Assets.xcassets */; };
Expand All @@ -25,6 +27,7 @@
/* End PBXBuildFile section */

/* Begin PBXFileReference section */
04A2EC412D22A513009C9AED /* ChatWithFormatView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatWithFormatView.swift; sourceTree = "<group>"; };
0A5648C02C1468E0008FB5F6 /* OKPlayground.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = OKPlayground.app; sourceTree = BUILT_PRODUCTS_DIR; };
0A5648C32C1468E0008FB5F6 /* OKPlaygroundApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OKPlaygroundApp.swift; sourceTree = "<group>"; };
0A5648C72C1468E1008FB5F6 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
Expand All @@ -49,6 +52,7 @@
files = (
0A5648EE2C150C66008FB5F6 /* OllamaKit in Frameworks */,
0AEA9B902CFCFD0100227D01 /* OllamaKit in Frameworks */,
04C1CC092D21E8B8005362B0 /* OllamaKit in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down Expand Up @@ -104,6 +108,7 @@
0A5648E72C14E00E008FB5F6 /* ChatView.swift */,
0A5648E92C1504CD008FB5F6 /* GenerateView.swift */,
0A7B752C2C55E79400624336 /* ChatWithToolsView.swift */,
04A2EC412D22A513009C9AED /* ChatWithFormatView.swift */,
);
path = Views;
sourceTree = "<group>";
Expand Down Expand Up @@ -151,6 +156,7 @@
packageProductDependencies = (
0A5648ED2C150C66008FB5F6 /* OllamaKit */,
0AEA9B8F2CFCFD0100227D01 /* OllamaKit */,
04C1CC082D21E8B8005362B0 /* OllamaKit */,
);
productName = OKPlayground;
productReference = 0A5648C02C1468E0008FB5F6 /* OKPlayground.app */;
Expand Down Expand Up @@ -181,7 +187,7 @@
);
mainGroup = 0A5648B72C1468E0008FB5F6;
packageReferences = (
0AEA9B8E2CFCFD0100227D01 /* XCLocalSwiftPackageReference "../../OllamaKit" */,
04C1CC072D21E8B8005362B0 /* XCLocalSwiftPackageReference "../../OllamaKit" */,
);
productRefGroup = 0A5648C12C1468E0008FB5F6 /* Products */;
projectDirPath = "";
Expand Down Expand Up @@ -220,6 +226,7 @@
0A5648E42C14D342008FB5F6 /* ModelInfoView.swift in Sources */,
0A5648E22C14C7E1008FB5F6 /* ViewModel.swift in Sources */,
0A5648DF2C14C583008FB5F6 /* EmbeddingsView.swift in Sources */,
04A2EC422D22A513009C9AED /* ChatWithFormatView.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down Expand Up @@ -429,13 +436,17 @@
/* End XCConfigurationList section */

/* Begin XCLocalSwiftPackageReference section */
0AEA9B8E2CFCFD0100227D01 /* XCLocalSwiftPackageReference "../../OllamaKit" */ = {
04C1CC072D21E8B8005362B0 /* XCLocalSwiftPackageReference "../../OllamaKit" */ = {
isa = XCLocalSwiftPackageReference;
relativePath = ../../OllamaKit;
};
/* End XCLocalSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
04C1CC082D21E8B8005362B0 /* OllamaKit */ = {
isa = XCSwiftPackageProductDependency;
productName = OllamaKit;
};
0A5648ED2C150C66008FB5F6 /* OllamaKit */ = {
isa = XCSwiftPackageProductDependency;
productName = OllamaKit;
Expand Down

This file was deleted.

6 changes: 5 additions & 1 deletion Playground/OKPlayground/Views/AppView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ struct AppView: View {
NavigationLink("Chat with Tools") {
ChatWithToolsView()
}


NavigationLink("Chat with Format") {
ChatWithFormatView()
}

NavigationLink("Generate") {
GenerateView()
}
Expand Down
171 changes: 171 additions & 0 deletions Playground/OKPlayground/Views/ChatWithFormatView.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
//
// ChatWithFormatView.swift
// OKPlayground
//
// Created by Michel-Andre Chirita on 30/12/2024.
//

import Combine
import OllamaKit
import SwiftUI

struct ChatWithFormatView: View {

enum ViewState {
case idle
case loading
case error(String)
}

@Environment(ViewModel.self) private var viewModel

@State private var model: String? = nil
/// TIP: be sure to include "return as JSON" in your prompt
@State private var prompt = "Lists of the 10 biggest countries in the world with their iso code as id, name and capital, return as JSON"
@State private var cancellables = Set<AnyCancellable>()
@State private var viewState: ViewState = .idle

@State private var responseItems: [ResponseItem] = []

var body: some View {
NavigationStack {
Form {
Section {
Picker("Model", selection: $model) {
ForEach(viewModel.models, id: \.self) { model in
Text(model)
.tag(model as String?)
}
}

TextField("Prompt", text: $prompt, axis: .vertical)
.lineLimit(5)
}

Section {
Button("Chat Async", action: actionAsync)
Button("Chat Combine", action: actionCombine)
}

switch viewState {
case .idle:
EmptyView()

case .loading:
ProgressView()
.id(UUID())

case .error(let error):
Text(error)
.foregroundStyle(.red)
}

Section("Response") {
ForEach(responseItems) { item in
Text("Country: " + item.country + ", capital: " + item.capital)
}
}
}
.navigationTitle("Chat with Format")
.navigationBarTitleDisplayMode(.inline)
.onAppear {
model = viewModel.models.first
}
}
}

func actionAsync() {
clearResponse()

guard let model = model else { return }
let messages = [OKChatRequestData.Message(role: .user, content: prompt)]
var data = OKChatRequestData(model: model, messages: messages, format: getFormat())
data.options = OKCompletionOptions(temperature: 0) /// TIP: better results with temperature = 0
self.viewState = .loading

Task {
do {
var message: String = ""
for try await chunk in viewModel.ollamaKit.chat(data: data) {
if let content = chunk.message?.content {
message.append(content)
}
if chunk.done {
self.viewState = .idle
decodeResponse(message)
}
}
} catch {
print("Error:", error.localizedDescription)
self.viewState = .error(error.localizedDescription)
}
}
}

func actionCombine() {
clearResponse()

guard let model = model else { return }
let messages = [OKChatRequestData.Message(role: .user, content: prompt)]
var data = OKChatRequestData(model: model, messages: messages, format: getFormat())
data.options = OKCompletionOptions(temperature: 0) /// TIP: better results with temperature = 0
self.viewState = .loading

var message: String = ""
viewModel.ollamaKit.chat(data: data)
.compactMap { $0.message?.content }
.scan("", { result, nextChunk in
result + nextChunk
})
.sink { completion in
switch completion {
case .finished:
print("Finished")
decodeResponse(message)
self.viewState = .idle
case .failure(let error):
print("Error:", error.localizedDescription)
self.viewState = .error(error.localizedDescription)
}
} receiveValue: { value in
message = value
}
.store(in: &cancellables)
}

private func getFormat() -> OKJSONValue {
return
.object(["type": .string("array"),
"items": .object([
"type" : .string("object"),
"properties": .object([
"id": .object(["type" : .string("string")]),
"country": .object(["type" : .string("string")]),
"capital": .object(["type" : .string("string")]),
]),
"required": .array([.string("id"), .string("country"), .string("capital")])
])
])
}

private func decodeResponse(_ content: String) {
do {
guard let data = content.data(using: .utf8) else { return }
let response = try JSONDecoder().decode([ResponseItem].self, from: data)
self.responseItems = response
} catch {
print("Error message: \(error)")
self.viewState = .error(error.localizedDescription)
}
}

private func clearResponse() {
self.responseItems = []
}
}

struct ResponseItem: Identifiable, Codable {
let id: String
let country: String
let capital: String
}
14 changes: 10 additions & 4 deletions Sources/OllamaKit/RequestData/OKChatRequestData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@ public struct OKChatRequestData: Sendable {

/// An optional array of ``OKJSONValue`` representing the tools available for tool calling in the chat.
public let tools: [OKJSONValue]?


/// Optional ``OKJSONValue`` representing the JSON schema for the response.
/// Be sure to also include "return as JSON" in your prompt
public let format: OKJSONValue?

/// Optional ``OKCompletionOptions`` providing additional configuration for the chat request.
public var options: OKCompletionOptions?

public init(model: String, messages: [Message], tools: [OKJSONValue]? = nil) {
public init(model: String, messages: [Message], tools: [OKJSONValue]? = nil, format: OKJSONValue? = nil) {
self.stream = tools == nil
self.model = model
self.messages = messages
self.tools = tools
self.format = format
}

/// A structure that represents a single message in the chat request.
Expand Down Expand Up @@ -68,13 +73,14 @@ extension OKChatRequestData: Encodable {
try container.encode(model, forKey: .model)
try container.encode(messages, forKey: .messages)
try container.encodeIfPresent(tools, forKey: .tools)

try container.encodeIfPresent(format, forKey: .format)

if let options {
try options.encode(to: encoder)
}
}

private enum CodingKeys: String, CodingKey {
case stream, model, messages, tools
case stream, model, messages, tools, format
}
}
14 changes: 10 additions & 4 deletions Sources/OllamaKit/RequestData/OKGenerateRequestData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ public struct OKGenerateRequestData: Sendable {
/// A string containing the initial input or prompt.
public let prompt: String

/// /// An optional array of base64-encoded images.
/// An optional array of base64-encoded images.
public let images: [String]?


/// Optional ``OKJSONValue`` representing the JSON schema for the response.
/// Be sure to also include "return as JSON" in your prompt
public let format: OKJSONValue?

/// An optional string specifying the system message.
public var system: String?

Expand All @@ -29,11 +33,12 @@ public struct OKGenerateRequestData: Sendable {
/// Optional ``OKCompletionOptions`` providing additional configuration for the generation request.
public var options: OKCompletionOptions?

public init(model: String, prompt: String, images: [String]? = nil) {
public init(model: String, prompt: String, images: [String]? = nil, format: OKJSONValue? = nil) {
self.stream = true
self.model = model
self.prompt = prompt
self.images = images
self.format = format
}
}

Expand All @@ -44,6 +49,7 @@ extension OKGenerateRequestData: Encodable {
try container.encode(model, forKey: .model)
try container.encode(prompt, forKey: .prompt)
try container.encodeIfPresent(images, forKey: .images)
try container.encodeIfPresent(format, forKey: .format)
try container.encodeIfPresent(system, forKey: .system)
try container.encodeIfPresent(context, forKey: .context)

Expand All @@ -53,6 +59,6 @@ extension OKGenerateRequestData: Encodable {
}

private enum CodingKeys: String, CodingKey {
case stream, model, prompt, images, system, context
case stream, model, prompt, images, format, system, context
}
}

0 comments on commit eec9a25

Please sign in to comment.