diff --git a/Package.resolved b/Package.resolved index e65252d..4e67108 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,4 +1,5 @@ { + "originHash" : "1b286c8c0a077892b2e712977f467e5f49595367ac433daf2de7e17cc43f93a6", "pins" : [ { "identity" : "swift-docc-plugin", @@ -17,7 +18,16 @@ "revision" : "b45d1f2ed151d057b54504d653e0da5552844e34", "version" : "1.0.0" } + }, + { + "identity" : "swift-json-schema", + "kind" : "remoteSourceControl", + "location" : "https://github.com/1amageek/swift-json-schema.git", + "state" : { + "branch" : "main", + "revision" : "ef7c71dcae944c18792a2164501394d111501004" + } } ], - "version" : 2 + "version" : 3 } diff --git a/Package.swift b/Package.swift index 9a0a5f8..9a019ce 100644 --- a/Package.swift +++ b/Package.swift @@ -16,12 +16,15 @@ let package = Package( targets: ["OllamaKit"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-docc-plugin.git", .upToNextMajor(from: "1.3.0")) + .package(url: "https://github.com/apple/swift-docc-plugin.git", .upToNextMajor(from: "1.3.0")), + .package(url: "https://github.com/1amageek/swift-json-schema.git", branch: "main") ], targets: [ .target( name: "OllamaKit", - dependencies: []), + dependencies: [ + .product(name: "JSONSchema", package: "swift-json-schema") + ]), .testTarget( name: "OllamaKitTests", dependencies: ["OllamaKit"]), diff --git a/Package@swift-5.9.swift b/Package@swift-5.9.swift index 1177966..84bc4d4 100644 --- a/Package@swift-5.9.swift +++ b/Package@swift-5.9.swift @@ -16,12 +16,15 @@ let package = Package( targets: ["OllamaKit"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-docc-plugin.git", .upToNextMajor(from: "1.3.0")) + .package(url: "https://github.com/apple/swift-docc-plugin.git", .upToNextMajor(from: "1.3.0")), + .package(url: "https://github.com/1amageek/swift-json-schema.git", branch: "main") ], targets: [ .target( name: "OllamaKit", - dependencies: []), + dependencies: [ + .product(name: 
"JSONSchema", package: "swift-json-schema") + ]), .testTarget( name: "OllamaKitTests", dependencies: ["OllamaKit"]), diff --git a/Playground/OKPlayground/Views/ChatView.swift b/Playground/OKPlayground/Views/ChatView.swift index e6a91f1..09e73b5 100644 --- a/Playground/OKPlayground/Views/ChatView.swift +++ b/Playground/OKPlayground/Views/ChatView.swift @@ -13,7 +13,7 @@ struct ChatView: View { @Environment(ViewModel.self) private var viewModel @State private var model: String? = nil - @State private var temperature: Double = 0.5 + @State private var temperature: Float = 0.5 @State private var prompt = "" @State private var response = "" @State private var cancellables = Set() diff --git a/Playground/OKPlayground/Views/ChatWithFormatView.swift b/Playground/OKPlayground/Views/ChatWithFormatView.swift index fd33391..51c1807 100644 --- a/Playground/OKPlayground/Views/ChatWithFormatView.swift +++ b/Playground/OKPlayground/Views/ChatWithFormatView.swift @@ -8,6 +8,7 @@ import Combine import OllamaKit import SwiftUI +import JSONSchema struct ChatWithFormatView: View { @@ -133,19 +134,17 @@ struct ChatWithFormatView: View { .store(in: &cancellables) } - private func getFormat() -> OKJSONValue { - return - .object(["type": .string("array"), - "items": .object([ - "type" : .string("object"), - "properties": .object([ - "id": .object(["type" : .string("string")]), - "country": .object(["type" : .string("string")]), - "capital": .object(["type" : .string("string")]), - ]), - "required": .array([.string("id"), .string("country"), .string("capital")]) - ]) - ]) + private func getFormat() -> JSONSchema { + return .array( + items:.object( + properties: [ + "id": .string(), + "country": .string(), + "capital": .string() + ], + required: ["id", "country", "capital"] + ) + ) } private func decodeResponse(_ content: String) { diff --git a/Playground/OKPlayground/Views/ChatWithToolsView.swift b/Playground/OKPlayground/Views/ChatWithToolsView.swift index 3717901..0de3805 100644 --- 
a/Playground/OKPlayground/Views/ChatWithToolsView.swift +++ b/Playground/OKPlayground/Views/ChatWithToolsView.swift @@ -105,37 +105,32 @@ struct ChatWithToolsView: View { .store(in: &cancellables) } - private func getTools() -> [OKJSONValue] { + private func getTools() -> [OKTool] { return [ - .object([ - "type": .string("function"), - "function": .object([ - "name": .string("get_current_weather"), - "description": .string("Get the current weather for a location"), - "parameters": .object([ - "type": .string("object"), - "properties": .object([ - "location": .object([ - "type": .string("string"), - "description": .string("The location to get the weather for, e.g. San Francisco, CA") - ]), - "format": .object([ - "type": .string("string"), - "description": .string("The format to return the weather in, e.g. 'celsius' or 'fahrenheit'"), - "enum": .array([.string("celsius"), .string("fahrenheit")]) - ]) - ]), - "required": .array([.string("location"), .string("format")]) - ]) - ]) - ]) + .function( + .init( + name: "get_current_weather", + description: "Get the current weather for a location", + parameters: + .object( + properties: [ + "location": .string( + description: "The location to get the weather for, e.g. San Francisco, CA" + ), + "format": .enum(description: "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'", values: [.string("celsius"), .string("fahrenheit")]) + ], + required: ["location", "format"] + ) + ) + ) + ] } private func setResponses(_ function: OKChatResponse.Message.ToolCall.Function) { self.toolCalledResponse = function.name ?? "" self.argumentsResponse = "\(function.arguments ?? 
.string("No arguments"))" - + if let arguments = function.arguments { switch arguments { case .object(let argDict): diff --git a/Playground/OKPlayground/Views/EmbeddingsView.swift b/Playground/OKPlayground/Views/EmbeddingsView.swift index 27a3160..27503cf 100644 --- a/Playground/OKPlayground/Views/EmbeddingsView.swift +++ b/Playground/OKPlayground/Views/EmbeddingsView.swift @@ -14,7 +14,7 @@ struct EmbeddingsView: View { @State private var model: String? = nil @State private var prompt = "" - @State private var embedding = [Double]() + @State private var embedding = [Float]() @State private var cancellables = Set() var body: some View { diff --git a/Playground/OKPlayground/Views/GenerateView.swift b/Playground/OKPlayground/Views/GenerateView.swift index 3923d8f..81ccabc 100644 --- a/Playground/OKPlayground/Views/GenerateView.swift +++ b/Playground/OKPlayground/Views/GenerateView.swift @@ -13,7 +13,7 @@ struct GenerateView: View { @Environment(ViewModel.self) private var viewModel @State private var model: String? = nil - @State private var temperature: Double = 0.5 + @State private var temperature: Float = 0.5 @State private var prompt = "" @State private var response = "" @State private var cancellables = Set() diff --git a/Sources/OllamaKit/OllamaKit+Chat.swift b/Sources/OllamaKit/OllamaKit+Chat.swift index 71dd0bd..fbaac31 100644 --- a/Sources/OllamaKit/OllamaKit+Chat.swift +++ b/Sources/OllamaKit/OllamaKit+Chat.swift @@ -9,99 +9,60 @@ import Combine import Foundation extension OllamaKit { - /// Establishes an asynchronous stream for chat responses from the Ollama API, based on the provided data. + /// Starts a stream for chat responses from the Ollama API. /// - /// This method sets up a streaming connection using Swift's concurrency features, allowing for real-time data handling as chat responses are generated by the Ollama API. 
- /// - /// Example usage - /// - /// ```swift - /// let ollamaKit = OllamaKit() - /// let chatData = OKChatRequestData(/* parameters */) - /// - /// Task { - /// do { - /// for try await response in ollamaKit.chat(data: chatData) { - /// // Handle each chat response - /// } - /// } catch { - /// // Handle error - /// } - /// } - /// ``` - /// - /// Example usage with tools + /// This method allows real-time handling of chat responses using Swift's concurrency. /// + /// Example usage: /// ```swift /// let ollamaKit = OllamaKit() /// let chatData = OKChatRequestData( - /// /* parameters */, + /// model: "example-model", + /// messages: [ + /// .user("What's the weather like in Tokyo?") + /// ], /// tools: [ - /// .object([ - /// "type": .string("function"), - /// "function": .object([ - /// "name": .string("get_current_weather"), - /// "description": .string("Get the current weather for a location"), - /// "parameters": .object([ - /// "type": .string("object"), - /// "properties": .object([ - /// "location": .object([ - /// "type": .string("string"), - /// "description": .string("The location to get the weather for, e.g. San Francisco, CA") - /// ]), - /// "format": .object([ - /// "type": .string("string"), - /// "description": .string("The format to return the weather in, e.g. 
'celsius' or 'fahrenheit'"), - /// "enum": .array([.string("celsius"), .string("fahrenheit")]) - /// ]) - /// ]), - /// "required": .array([.string("location"), .string("format")]) - /// ]) - /// ]) - /// ]) + /// .function( + /// OKFunction( + /// name: "get_current_weather", + /// description: "Fetch current weather information.", + /// parameters: .object( + /// description: "Parameters for fetching weather", + /// properties: [ + /// "location": .string( + /// description: "The location to get the weather for, e.g., Tokyo" + /// ), + /// "format": .enum( + /// description: "The format for the weather, e.g., 'celsius'.", + /// values: [ + /// .string("celsius"), + /// .string("fahrenheit") + /// ] + /// ) + /// ], + /// required: ["location", "format"] + /// ) + /// ) + /// ) /// ] /// ) /// /// Task { /// do { /// for try await response in ollamaKit.chat(data: chatData) { - /// if let toolCalls = response.message?.toolCalls { - /// for toolCall in toolCalls { - /// if let function = toolCall.function { - /// print("Tool called: \(function.name ?? "")") - /// - /// if let arguments = function.arguments { - /// switch arguments { - /// case .object(let argDict): - /// if let location = argDict["location"], case .string(let locationValue) = location { - /// print("Location: \(locationValue)") - /// } - /// - /// if let format = argDict["format"], case .string(let formatValue) = format { - /// print("Format: \(formatValue)") - /// } - /// default: - /// print("Unexpected arguments format") - /// } - /// } else { - /// print("No arguments provided") - /// } - /// } - /// } - /// } + /// // Handle each response here + /// print(response) /// } /// } catch { - /// // Handle error + /// print("Error: \(error)") /// } /// } /// ``` - /// - /// - Parameter data: The ``OKChatRequestData`` used to initiate the chat streaming from the Ollama API. - /// - Returns: An `AsyncThrowingStream` emitting the live stream of chat responses from the Ollama API. 
+ /// - Parameter data: The ``OKChatRequestData`` containing chat request details. + /// - Returns: An `AsyncThrowingStream` emitting chat responses from the Ollama API. public func chat(data: OKChatRequestData) -> AsyncThrowingStream { do { let request = try OKRouter.chat(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.stream(request: request, with: OKChatResponse.self) } catch { return AsyncThrowingStream { continuation in @@ -110,95 +71,63 @@ extension OllamaKit { } } - /// Establishes a Combine publisher for streaming chat responses from the Ollama API, based on the provided data. - /// - /// This method sets up a streaming connection using the Combine framework, facilitating real-time data handling as chat responses are generated by the Ollama API. - /// - /// Example usage - /// - /// ```swift - /// let ollamaKit = OllamaKit() - /// let chatData = OKChatRequestData(/* parameters */) - /// - /// ollamaKit.chat(data: chatData) - /// .sink(receiveCompletion: { completion in - /// // Handle completion or error - /// }, receiveValue: { chatResponse in - /// // Handle each chat response - /// }) - /// .store(in: &cancellables) - /// ``` + /// Publishes a stream of chat responses from the Ollama API using Combine. /// - /// Example usage with tools + /// Enables real-time data handling through Combine's reactive streams. 
/// + /// Example usage: /// ```swift /// let ollamaKit = OllamaKit() /// let chatData = OKChatRequestData( - /// /* parameters */, + /// model: "example-model", + /// messages: [ + /// .user("What's the weather like in Tokyo?") + /// ], /// tools: [ - /// .object([ - /// "type": .string("function"), - /// "function": .object([ - /// "name": .string("get_current_weather"), - /// "description": .string("Get the current weather for a location"), - /// "parameters": .object([ - /// "type": .string("object"), - /// "properties": .object([ - /// "location": .object([ - /// "type": .string("string"), - /// "description": .string("The location to get the weather for, e.g. San Francisco, CA") - /// ]), - /// "format": .object([ - /// "type": .string("string"), - /// "description": .string("The format to return the weather in, e.g. 'celsius' or 'fahrenheit'"), - /// "enum": .array([.string("celsius"), .string("fahrenheit")]) - /// ]) - /// ]), - /// "required": .array([.string("location"), .string("format")]) - /// ]) - /// ]) - /// ]) + /// .function( + /// OKFunction( + /// name: "get_current_weather", + /// description: "Fetch current weather information.", + /// parameters: .object( + /// description: "Parameters for fetching weather", + /// properties: [ + /// "location": .string( + /// description: "The location to get the weather for, e.g., Tokyo" + /// ), + /// "format": .enum( + /// description: "The format for the weather, e.g., 'celsius'.", + /// values: [ + /// .string("celsius"), + /// .string("fahrenheit") + /// ] + /// ) + /// ], + /// required: ["location", "format"] + /// ) + /// ) + /// ) /// ] /// ) /// /// ollamaKit.chat(data: chatData) /// .sink(receiveCompletion: { completion in - /// // Handle completion or error - /// }, receiveValue: { chatResponse in - /// if let toolCalls = chatResponse.message?.toolCalls { - /// for toolCall in toolCalls { - /// if let function = toolCall.function { - /// print("Tool called: \(function.name ?? 
"")") - /// - /// if let arguments = function.arguments { - /// switch arguments { - /// case .object(let argDict): - /// if let location = argDict["location"], case .string(let locationValue) = location { - /// print("Location: \(locationValue)") - /// } - /// - /// if let format = argDict["format"], case .string(let formatValue) = format { - /// print("Format: \(formatValue)") - /// } - /// default: - /// print("Unexpected arguments format") - /// } - /// } else { - /// print("No arguments provided") - /// } - /// } - /// } + /// switch completion { + /// case .finished: + /// print("Stream finished") + /// case .failure(let error): + /// print("Error: \(error)") /// } + /// }, receiveValue: { response in + /// // Handle each response here + /// print(response) /// }) /// .store(in: &cancellables) /// ``` - /// - /// - Parameter data: The ``OKChatRequestData`` used to initiate the chat streaming from the Ollama API. - /// - Returns: An `AnyPublisher` emitting the live stream of chat responses from the Ollama API. + /// - Parameter data: The ``OKChatRequestData`` containing chat request details. + /// - Returns: An `AnyPublisher` emitting chat responses from the Ollama API. public func chat(data: OKChatRequestData) -> AnyPublisher { do { let request = try OKRouter.chat(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.stream(request: request, with: OKChatResponse.self) } catch { return Fail(error: error).eraseToAnyPublisher() diff --git a/Sources/OllamaKit/OllamaKit+CopyModel.swift b/Sources/OllamaKit/OllamaKit+CopyModel.swift index af63a6d..91e0caf 100644 --- a/Sources/OllamaKit/OllamaKit+CopyModel.swift +++ b/Sources/OllamaKit/OllamaKit+CopyModel.swift @@ -23,7 +23,6 @@ extension OllamaKit { /// - Throws: An error if the request to copy the model fails. 
public func copyModel(data: OKCopyModelRequestData) async throws -> Void { let request = try OKRouter.copyModel(data: data).asURLRequest(with: baseURL) - try await OKHTTPClient.shared.send(request: request) } @@ -49,7 +48,6 @@ extension OllamaKit { public func copyModel(data: OKCopyModelRequestData) -> AnyPublisher { do { let request = try OKRouter.copyModel(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.send(request: request) } catch { return Fail(error: error).eraseToAnyPublisher() diff --git a/Sources/OllamaKit/OllamaKit+DeleteModel.swift b/Sources/OllamaKit/OllamaKit+DeleteModel.swift index 82a6916..795ff8d 100644 --- a/Sources/OllamaKit/OllamaKit+DeleteModel.swift +++ b/Sources/OllamaKit/OllamaKit+DeleteModel.swift @@ -23,7 +23,6 @@ extension OllamaKit { /// - Throws: An error if the request to delete the model fails. public func deleteModel(data: OKDeleteModelRequestData) async throws -> Void { let request = try OKRouter.deleteModel(data: data).asURLRequest(with: baseURL) - try await OKHTTPClient.shared.send(request: request) } @@ -49,7 +48,6 @@ extension OllamaKit { public func deleteModel(data: OKDeleteModelRequestData) -> AnyPublisher { do { let request = try OKRouter.deleteModel(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.send(request: request) } catch { return Fail(error: error).eraseToAnyPublisher() diff --git a/Sources/OllamaKit/OllamaKit+Embeddings.swift b/Sources/OllamaKit/OllamaKit+Embeddings.swift index 85d457f..1afe411 100644 --- a/Sources/OllamaKit/OllamaKit+Embeddings.swift +++ b/Sources/OllamaKit/OllamaKit+Embeddings.swift @@ -24,7 +24,6 @@ extension OllamaKit { /// - Throws: An error if the request fails or the response can't be decoded. 
public func embeddings(data: OKEmbeddingsRequestData) async throws -> OKEmbeddingsResponse { let request = try OKRouter.embeddings(data: data).asURLRequest(with: baseURL) - return try await OKHTTPClient.shared.send(request: request, with: OKEmbeddingsResponse.self) } @@ -50,7 +49,6 @@ extension OllamaKit { public func embeddings(data: OKEmbeddingsRequestData) -> AnyPublisher { do { let request = try OKRouter.embeddings(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.send(request: request, with: OKEmbeddingsResponse.self) } catch { return Fail(error: error).eraseToAnyPublisher() diff --git a/Sources/OllamaKit/OllamaKit+Generate.swift b/Sources/OllamaKit/OllamaKit+Generate.swift index 6c8d9b8..9344d4f 100644 --- a/Sources/OllamaKit/OllamaKit+Generate.swift +++ b/Sources/OllamaKit/OllamaKit+Generate.swift @@ -33,7 +33,6 @@ extension OllamaKit { public func generate(data: OKGenerateRequestData) -> AsyncThrowingStream { do { let request = try OKRouter.generate(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.stream(request: request, with: OKGenerateResponse.self) } catch { return AsyncThrowingStream { continuation in @@ -64,7 +63,6 @@ extension OllamaKit { public func generate(data: OKGenerateRequestData) -> AnyPublisher { do { let request = try OKRouter.generate(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.stream(request: request, with: OKGenerateResponse.self) } catch { return Fail(error: error).eraseToAnyPublisher() diff --git a/Sources/OllamaKit/OllamaKit+ModelInfo.swift b/Sources/OllamaKit/OllamaKit+ModelInfo.swift index cabb515..db01560 100644 --- a/Sources/OllamaKit/OllamaKit+ModelInfo.swift +++ b/Sources/OllamaKit/OllamaKit+ModelInfo.swift @@ -24,7 +24,6 @@ extension OllamaKit { /// - Throws: An error if the request fails or the response can't be decoded. 
public func modelInfo(data: OKModelInfoRequestData) async throws -> OKModelInfoResponse { let request = try OKRouter.modelInfo(data: data).asURLRequest(with: baseURL) - return try await OKHTTPClient.shared.send(request: request, with: OKModelInfoResponse.self) } @@ -50,7 +49,6 @@ extension OllamaKit { public func modelInfo(data: OKModelInfoRequestData) -> AnyPublisher { do { let request = try OKRouter.modelInfo(data: data).asURLRequest(with: baseURL) - return OKHTTPClient.shared.send(request: request, with: OKModelInfoResponse.self) } catch { return Fail(error: error).eraseToAnyPublisher() diff --git a/Sources/OllamaKit/OllamaKit+Models.swift b/Sources/OllamaKit/OllamaKit+Models.swift index 6447301..520b0db 100644 --- a/Sources/OllamaKit/OllamaKit+Models.swift +++ b/Sources/OllamaKit/OllamaKit+Models.swift @@ -22,7 +22,6 @@ extension OllamaKit { /// - Throws: An error if the request fails or the response can't be decoded. public func models() async throws -> OKModelResponse { let request = try OKRouter.models.asURLRequest(with: baseURL) - return try await OKHTTPClient.shared.send(request: request, with: OKModelResponse.self) } @@ -46,7 +45,6 @@ extension OllamaKit { public func models() -> AnyPublisher { do { let request = try OKRouter.models.asURLRequest(with: baseURL) - return OKHTTPClient.shared.send(request: request, with: OKModelResponse.self) } catch { return Fail(error: error).eraseToAnyPublisher() diff --git a/Sources/OllamaKit/OllamaKit+Reachable.swift b/Sources/OllamaKit/OllamaKit+Reachable.swift index 20b9b4d..6cba3ed 100644 --- a/Sources/OllamaKit/OllamaKit+Reachable.swift +++ b/Sources/OllamaKit/OllamaKit+Reachable.swift @@ -48,7 +48,6 @@ extension OllamaKit { public func reachable() -> AnyPublisher { do { let request = try OKRouter.root.asURLRequest(with: baseURL) - return OKHTTPClient.shared.send(request: request) .map { _ in true } .replaceError(with: false) diff --git a/Sources/OllamaKit/RequestData/Completion/OKCompletionOptions.swift 
b/Sources/OllamaKit/RequestData/Completion/OKCompletionOptions.swift index 7dbdad0..665bbbe 100644 --- a/Sources/OllamaKit/RequestData/Completion/OKCompletionOptions.swift +++ b/Sources/OllamaKit/RequestData/Completion/OKCompletionOptions.swift @@ -19,13 +19,13 @@ public struct OKCompletionOptions: Encodable, Sendable { /// (Lower values result in slower adjustments, higher values increase responsiveness.) /// This parameter, `mirostatEta`, adjusts how quickly the algorithm reacts to feedback /// from the generated text. A default value of 0.1 provides a moderate adjustment speed. - public var mirostatEta: Double? + public var mirostatEta: Float? /// Optional double controlling the balance between coherence and diversity. /// (Lower values lead to more focused and coherent text) /// The `mirostatTau` parameter sets the target perplexity level, influencing how /// creative or constrained the text generation should be. Default is 5.0. - public var mirostatTau: Double? + public var mirostatTau: Float? /// Optional integer setting the size of the context window for token generation. /// This defines the number of previous tokens the model considers when generating new tokens. @@ -40,13 +40,13 @@ public struct OKCompletionOptions: Encodable, Sendable { /// Optional double setting the penalty strength for repetitions. /// A higher value increases the penalty for repeated tokens, discouraging repetition. /// The default value is 1.1, providing moderate repetition control. - public var repeatPenalty: Double? + public var repeatPenalty: Float? /// Optional double to control the model's creativity. /// (Higher values increase creativity and randomness) /// The `temperature` parameter adjusts the randomness of predictions; higher values /// like 0.8 make outputs more creative and diverse. The default is 0.7. - public var temperature: Double? + public var temperature: Float? /// Optional integer for setting a random number seed for generation consistency. 
/// Specifying a seed ensures the same output for the same prompt and parameters, @@ -61,7 +61,7 @@ public struct OKCompletionOptions: Encodable, Sendable { /// Optional double for tail free sampling, reducing impact of less probable tokens. /// `tfsZ` adjusts how much the model avoids unlikely tokens, with higher values /// reducing their influence. A value of 1.0 disables this feature. - public var tfsZ: Double? + public var tfsZ: Float? /// Optional integer for the maximum number of tokens to predict. /// `numPredict` sets the upper limit for the number of tokens to generate. @@ -76,14 +76,14 @@ public struct OKCompletionOptions: Encodable, Sendable { /// Optional double working with top-k to balance text diversity and focus. /// `topP` (nucleus sampling) retains tokens that cumulatively account for a certain /// probability mass, adding flexibility beyond `topK`. A value like 0.9 increases diversity. - public var topP: Double? + public var topP: Float? /// Optional double for the minimum probability threshold for token inclusion. /// `minP` ensures that tokens below a certain probability threshold are excluded, /// focusing the model's output on more probable sequences. Default is 0.0, meaning no filtering. - public var minP: Double? + public var minP: Float? - public init(mirostat: Int? = nil, mirostatEta: Double? = nil, mirostatTau: Double? = nil, numCtx: Int? = nil, repeatLastN: Int? = nil, repeatPenalty: Double? = nil, temperature: Double? = nil, seed: Int? = nil, stop: String? = nil, tfsZ: Double? = nil, numPredict: Int? = nil, topK: Int? = nil, topP: Double? = nil, minP: Double? = nil) { + public init(mirostat: Int? = nil, mirostatEta: Float? = nil, mirostatTau: Float? = nil, numCtx: Int? = nil, repeatLastN: Int? = nil, repeatPenalty: Float? = nil, temperature: Float? = nil, seed: Int? = nil, stop: String? = nil, tfsZ: Float? = nil, numPredict: Int? = nil, topK: Int? = nil, topP: Float? = nil, minP: Float? 
= nil) { self.mirostat = mirostat self.mirostatEta = mirostatEta self.mirostatTau = mirostatTau diff --git a/Sources/OllamaKit/RequestData/OKChatRequestData.swift b/Sources/OllamaKit/RequestData/OKChatRequestData.swift index c515132..9ffdcfb 100644 --- a/Sources/OllamaKit/RequestData/OKChatRequestData.swift +++ b/Sources/OllamaKit/RequestData/OKChatRequestData.swift @@ -6,6 +6,7 @@ // import Foundation +import JSONSchema /// A structure that encapsulates data for chat requests to the Ollama API. public struct OKChatRequestData: Sendable { @@ -17,22 +18,47 @@ public struct OKChatRequestData: Sendable { /// An array of ``Message`` instances representing the content to be sent to the Ollama API. public let messages: [Message] - /// An optional array of ``OKJSONValue`` representing the tools available for tool calling in the chat. - public let tools: [OKJSONValue]? + /// An optional array of ``OKTool`` representing the tools available for tool calling in the chat. + public let tools: [OKTool]? - /// Optional ``OKJSONValue`` representing the JSON schema for the response. + /// Optional ``JSONSchema`` representing the JSON schema for the response. /// Be sure to also include "return as JSON" in your prompt - public let format: OKJSONValue? + public let format: JSONSchema? /// Optional ``OKCompletionOptions`` providing additional configuration for the chat request. public var options: OKCompletionOptions? - public init(model: String, messages: [Message], tools: [OKJSONValue]? = nil, format: OKJSONValue? = nil) { + + public init( + model: String, + messages: [Message], + tools: [OKTool]? = nil, + format: JSONSchema? = nil, + options: OKCompletionOptions? = nil + ) { + self.stream = tools == nil + self.model = model + self.messages = messages + self.tools = tools + self.format = format + self.options = options + } + + public init( + model: String, + messages: [Message], + tools: [OKTool]? = nil, + format: JSONSchema? 
= nil, + with configureOptions: @Sendable (inout OKCompletionOptions) -> Void + ) { self.stream = tools == nil self.model = model self.messages = messages self.tools = tools + var options = OKCompletionOptions() + configureOptions(&options) self.format = format + self.options = options } /// A structure that represents a single message in the chat request. @@ -52,16 +78,48 @@ public struct OKChatRequestData: Sendable { self.images = images } - /// An enumeration that represents the role of the message sender. - public enum Role: String, Encodable, Sendable { - /// Indicates the message is from the system. + /// An enumeration representing the role of the message sender. + public enum Role: RawRepresentable, Encodable, Sendable { + + /// The message is from the system. case system - /// Indicates the message is from the assistant. + /// The message is from the assistant. case assistant - /// Indicates the message is from the user. + /// The message is from the user. case user + + /// A custom role with a specified name. + case custom(String) + + // Initializer for RawRepresentable conformance + public init?(rawValue: String) { + switch rawValue { + case "system": + self = .system + case "assistant": + self = .assistant + case "user": + self = .user + default: + self = .custom(rawValue) + } + } + + // Computed property to get the raw value as a string. 
+ public var rawValue: String { + switch self { + case .system: + return "system" + case .assistant: + return "assistant" + case .user: + return "user" + case .custom(let value): + return value + } + } } } } @@ -74,7 +132,6 @@ extension OKChatRequestData: Encodable { try container.encode(messages, forKey: .messages) try container.encodeIfPresent(tools, forKey: .tools) try container.encodeIfPresent(format, forKey: .format) - if let options { try options.encode(to: encoder) } @@ -84,3 +141,22 @@ extension OKChatRequestData: Encodable { case stream, model, messages, tools, format } } + +extension OKChatRequestData.Message { + + public static func system(_ content: String, images: [String]? = nil) -> OKChatRequestData.Message { + .init(role: .system, content: content, images: images) + } + + public static func user(_ content: String, images: [String]? = nil) -> OKChatRequestData.Message { + .init(role: .user, content: content, images: images) + } + + public static func assistant(_ content: String, images: [String]? = nil) -> OKChatRequestData.Message { + .init(role: .assistant, content: content, images: images) + } + + public static func custom(name: String, _ content: String, images: [String]? = nil) -> OKChatRequestData.Message { + .init(role: .custom(name), content: content, images: images) + } +} diff --git a/Sources/OllamaKit/RequestData/OKEmbeddingsRequestData.swift b/Sources/OllamaKit/RequestData/OKEmbeddingsRequestData.swift index 5c3101e..e80bc95 100644 --- a/Sources/OllamaKit/RequestData/OKEmbeddingsRequestData.swift +++ b/Sources/OllamaKit/RequestData/OKEmbeddingsRequestData.swift @@ -21,8 +21,22 @@ public struct OKEmbeddingsRequestData: Encodable, Sendable { /// Optionally control how long the model will stay loaded into memory following the request (default: 5m) public var keepAlive: String? - public init(model: String, prompt: String) { + public init(model: String, prompt: String, options: OKCompletionOptions? = nil, keepAlive: String? 
= nil) { self.model = model self.prompt = prompt + self.options = options + self.keepAlive = keepAlive + } + + public init( + model: String, + prompt: String, + with configureOptions: @Sendable (inout OKCompletionOptions) -> Void + ) { + self.model = model + self.prompt = prompt + var options = OKCompletionOptions() + configureOptions(&options) + self.options = options } } diff --git a/Sources/OllamaKit/RequestData/OKGenerateRequestData.swift b/Sources/OllamaKit/RequestData/OKGenerateRequestData.swift index 3ba5625..059eafb 100644 --- a/Sources/OllamaKit/RequestData/OKGenerateRequestData.swift +++ b/Sources/OllamaKit/RequestData/OKGenerateRequestData.swift @@ -6,6 +6,7 @@ // import Foundation +import JSONSchema /// A structure that encapsulates the data required for generating responses using the Ollama API. public struct OKGenerateRequestData: Sendable { @@ -20,9 +21,9 @@ public struct OKGenerateRequestData: Sendable { /// An optional array of base64-encoded images. public let images: [String]? - /// Optional ``OKJSONValue`` representing the JSON schema for the response. + /// Optional ``JSONSchema`` representing the JSON schema for the response. /// Be sure to also include "return as JSON" in your prompt - public let format: OKJSONValue? + public let format: JSONSchema? /// An optional string specifying the system message. public var system: String? @@ -33,11 +34,39 @@ public struct OKGenerateRequestData: Sendable { /// Optional ``OKCompletionOptions`` providing additional configuration for the generation request. public var options: OKCompletionOptions? - public init(model: String, prompt: String, images: [String]? = nil, format: OKJSONValue? = nil) { + public init( + model: String, + prompt: String, + images: [String]? = nil, + system: String? = nil, + context: [Int]? = nil, + format: JSONSchema? = nil, + options: OKCompletionOptions? 
= nil + ) { self.stream = true self.model = model self.prompt = prompt self.images = images + self.system = system + self.context = context + self.format = format + self.options = options + } + + public init( + model: String, + prompt: String, + images: [String]? = nil, + format: JSONSchema? = nil, + with configureOptions: @Sendable (inout OKCompletionOptions) -> Void + ) { + self.stream = true + self.model = model + self.prompt = prompt + self.images = images + var options = OKCompletionOptions() + configureOptions(&options) + self.options = options self.format = format } } diff --git a/Sources/OllamaKit/Responses/OKChatResponse.swift b/Sources/OllamaKit/Responses/OKChatResponse.swift index f732f26..82c73b6 100644 --- a/Sources/OllamaKit/Responses/OKChatResponse.swift +++ b/Sources/OllamaKit/Responses/OKChatResponse.swift @@ -6,6 +6,7 @@ // import Foundation +import JSONSchema /// A structure that represents the response to a chat request from the Ollama API. public struct OKChatResponse: OKCompletionResponse, Decodable, Sendable { @@ -55,7 +56,7 @@ public struct OKChatResponse: OKCompletionResponse, Decodable, Sendable { public var toolCalls: [ToolCall]? /// An enumeration representing the role of the message sender. - public enum Role: String, Decodable, Sendable { + public enum Role: RawRepresentable, Decodable, Sendable { /// The message is from the system. case system @@ -64,6 +65,37 @@ public struct OKChatResponse: OKCompletionResponse, Decodable, Sendable { /// The message is from the user. case user + + /// A custom role with a specified name. + case custom(String) + + // Initializer for RawRepresentable conformance + public init?(rawValue: String) { + switch rawValue { + case "system": + self = .system + case "assistant": + self = .assistant + case "user": + self = .user + default: + self = .custom(rawValue) + } + } + + // Computed property to get the raw value as a string. 
+ public var rawValue: String { + switch self { + case .system: + return "system" + case .assistant: + return "assistant" + case .user: + return "user" + case .custom(let value): + return value + } + } } /// A structure that represents a tool call in the response. diff --git a/Sources/OllamaKit/Responses/OKEmbeddingsResponse.swift b/Sources/OllamaKit/Responses/OKEmbeddingsResponse.swift index fe7f6b0..413558c 100644 --- a/Sources/OllamaKit/Responses/OKEmbeddingsResponse.swift +++ b/Sources/OllamaKit/Responses/OKEmbeddingsResponse.swift @@ -11,5 +11,5 @@ import Foundation public struct OKEmbeddingsResponse: Decodable, Sendable { /// An array of doubles representing the embeddings of the input prompt. - public let embedding: [Double]? + public let embedding: [Float]? } diff --git a/Sources/OllamaKit/Utils/OKHTTPClient.swift b/Sources/OllamaKit/Utils/OKHTTPClient.swift index 1cc9ac9..c767147 100644 --- a/Sources/OllamaKit/Utils/OKHTTPClient.swift +++ b/Sources/OllamaKit/Utils/OKHTTPClient.swift @@ -8,7 +8,7 @@ import Combine import Foundation -internal struct OKHTTPClient { +internal struct OKHTTPClient: Sendable { private let decoder: JSONDecoder = .default static let shared = OKHTTPClient() } diff --git a/Sources/OllamaKit/Utils/OKJSONValue.swift b/Sources/OllamaKit/Utils/OKJSONValue.swift index 32032d5..9e6a879 100644 --- a/Sources/OllamaKit/Utils/OKJSONValue.swift +++ b/Sources/OllamaKit/Utils/OKJSONValue.swift @@ -9,7 +9,7 @@ import Foundation public enum OKJSONValue: Codable, Sendable { case string(String) - case number(Double) + case number(Float) case integer(Int) case boolean(Bool) case array([OKJSONValue]) @@ -20,7 +20,7 @@ public enum OKJSONValue: Codable, Sendable { if let value = try? container.decode(String.self) { self = .string(value) - } else if let value = try? container.decode(Double.self) { + } else if let value = try? container.decode(Float.self) { self = .number(value) } else if let value = try? 
container.decode(Int.self) { self = .integer(value) diff --git a/Sources/OllamaKit/Utils/OKRouter.swift b/Sources/OllamaKit/Utils/OKRouter.swift index 0e3d280..55744c6 100644 --- a/Sources/OllamaKit/Utils/OKRouter.swift +++ b/Sources/OllamaKit/Utils/OKRouter.swift @@ -8,6 +8,7 @@ import Foundation internal enum OKRouter { + case root case models case modelInfo(data: OKModelInfoRequestData) diff --git a/Sources/OllamaKit/Utils/OKTool.swift b/Sources/OllamaKit/Utils/OKTool.swift new file mode 100644 index 0000000..1f5afd6 --- /dev/null +++ b/Sources/OllamaKit/Utils/OKTool.swift @@ -0,0 +1,47 @@ +// +// OKTool.swift +// OllamaKit +// +// Created by Norikazu Muramoto on 2025/01/11. +// + +import Foundation +import JSONSchema + +/// Represents a tool that can be used in the Ollama API chat. +public struct OKTool: Encodable, Sendable { + /// The type of the tool (e.g., "function"). + public let type: String + + /// The function details associated with the tool. + public let function: OKFunction + + public init(type: String, function: OKFunction) { + self.type = type + self.function = function + } + + /// Convenience method for creating a tool with type "function". + public static func function(_ function: OKFunction) -> OKTool { + return OKTool(type: "function", function: function) + } +} + +/// Represents a function used as a tool in the Ollama API chat. +public struct OKFunction: Encodable, Sendable { + /// The name of the function. + public let name: String + + /// A description of what the function does. + public let description: String + + /// Parameters required by the function, defined as a JSON schema. 
+ public let parameters: JSONSchema + + public init(name: String, description: String, parameters: JSONSchema) { + self.name = name + self.description = description + self.parameters = parameters + } +} + diff --git a/Tests/OllamaKitTests/ModelTests.swift b/Tests/OllamaKitTests/ModelTests.swift new file mode 100644 index 0000000..7f67f6c --- /dev/null +++ b/Tests/OllamaKitTests/ModelTests.swift @@ -0,0 +1,188 @@ +// +// ModelTests.swift +// OllamaKit +// +// Created by Norikazu Muramoto on 2025/01/22. +// + + +import Testing +import Foundation +import OllamaKit + +@Test("Basic chat functionality for all models") +func testAllModelsBasicChat() async throws { + let models = ["deepseek-r1:8b", "phi4", "llama3.2"] + let ollamaKit = OllamaKit() + + for model in models { + print("Testing model: \(model)") + + // Basic chat test + let chatData = OKChatRequestData( + model: model, + messages: [.user("What is 2+2?")] + ) + + var response = "" + for try await chatResponse in ollamaKit.chat(data: chatData) { + if let content = chatResponse.message?.content { + response += content + } + } + + #expect(!response.isEmpty, "Model \(model) should return a response") + #expect(response.contains("4"), "Model \(model) should correctly answer 2+2=4") + } +} + +@Test("Tool calling functionality for all models") +func testAllModelsToolCalling() async throws { + let models = ["deepseek-r1:8b", "phi4", "llama3.2"] + let ollamaKit = OllamaKit() + + // Weather function definition + let weatherFunction = OKFunction( + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: .object( + description: "Parameters for the weather function", + properties: [ + "location": .string(description: "The city and state, e.g. 
San Francisco, CA"), + "unit": .string(description: "The temperature unit to use: 'celsius' or 'fahrenheit'") + ], + required: ["location", "unit"] + ) + ) + + for model in models { + print("Testing model: \(model) with tool calling") + + let chatData = OKChatRequestData( + model: model, + messages: [ + .system("You are a helpful assistant that uses the provided weather function when asked about weather."), + .user("What's the weather like in Tokyo?") + ], + tools: [.function(weatherFunction)] + ) + + var hasToolCall = false + var functionName: String? + + for try await chatResponse in ollamaKit.chat(data: chatData) { + if let toolCalls = chatResponse.message?.toolCalls { + hasToolCall = true + functionName = toolCalls.first?.function?.name + } + } + + #expect(hasToolCall, "Model \(model) should attempt to use the weather tool") + #expect(functionName == "get_current_weather", "Model \(model) should call the correct function") + } +} + +@Test("Response format validation for all models") +func testAllModelsResponseFormat() async throws { + let models = ["deepseek-r1:8b", "phi4", "llama3.2"] + let ollamaKit = OllamaKit() + + for model in models { + print("Testing model: \(model) response format") + + let chatData = OKChatRequestData( + model: model, + messages: [ + .system("You should respond in complete sentences."), + .user("List three colors.") + ] + ) + + var response = "" + var responseComplete = false + + for try await chatResponse in ollamaKit.chat(data: chatData) { + if let content = chatResponse.message?.content { + response += content + } + if chatResponse.done { + responseComplete = true + } + } + + #expect(responseComplete, "Model \(model) should complete its response") + #expect(response.contains("."), "Model \(model) should respond in complete sentences") + #expect(response.components(separatedBy: .whitespaces).count > 5, "Model \(model) should provide a substantial response") + } +} + +@Test("Error handling for all models") +func 
testAllModelsErrorHandling() async throws { + let models = ["deepseek-r1:8b", "phi4", "llama3.2"] + let ollamaKit = OllamaKit() + + for model in models { + print("Testing model: \(model) error handling") + + // Test with empty message + do { + let chatData = OKChatRequestData( + model: model, + messages: [] + ) + + var receivedResponse = false + for try await _ in ollamaKit.chat(data: chatData) { + receivedResponse = true + } + #expect(!receivedResponse, "Model \(model) should not process empty messages") + } catch { + // Error is expected + } + + // Test with invalid model name + do { + let chatData = OKChatRequestData( + model: model + "_invalid", + messages: [.user("Test")] + ) + + var receivedResponse = false + for try await _ in ollamaKit.chat(data: chatData) { + receivedResponse = true + } + #expect(!receivedResponse, "Invalid model name should not return response") + } catch { + // Error is expected + } + } +} + +@Test("Context handling for all models") +func testAllModelsContextHandling() async throws { + let models = ["deepseek-r1:8b", "phi4", "llama3.2"] + let ollamaKit = OllamaKit() + + for model in models { + print("Testing model: \(model) context handling") + + let chatData = OKChatRequestData( + model: model, + messages: [ + .system("You are a helpful assistant."), + .user("My name is Alice."), + .assistant("Nice to meet you, Alice!"), + .user("What's my name?") + ] + ) + + var response = "" + for try await chatResponse in ollamaKit.chat(data: chatData) { + if let content = chatResponse.message?.content { + response += content + } + } + + #expect(response.contains("Alice"), "Model \(model) should remember context") + } +} diff --git a/Tests/OllamaKitTests/OKChatRequestTests.swift b/Tests/OllamaKitTests/OKChatRequestTests.swift new file mode 100644 index 0000000..c84d077 --- /dev/null +++ b/Tests/OllamaKitTests/OKChatRequestTests.swift @@ -0,0 +1,203 @@ +import Testing +import Foundation +import JSONSchema +@testable import OllamaKit + +@Test("Basic 
chat request initialization") +func testBasicChatRequestInit() { + let messages = [ + OKChatRequestData.Message.user("Hello, how are you?") + ] + + let chatRequest = OKChatRequestData( + model: "llama2", + messages: messages + ) + + #expect(chatRequest.model == "llama2") + #expect(chatRequest.messages.count == 1) + #expect(chatRequest.messages[0].role.rawValue == "user") + #expect(chatRequest.messages[0].content == "Hello, how are you?") + #expect(chatRequest.tools == nil) + #expect(chatRequest.format == nil) +} + +@Test("Chat request with all message types") +func testChatRequestWithAllMessageTypes() { + let messages: [OKChatRequestData.Message] = [ + .system("You are a helpful assistant."), + .user("What's the weather?"), + .assistant("The weather is sunny."), + .custom(name: "weather_bot", "Temperature is 25°C") + ] + + let chatRequest = OKChatRequestData( + model: "llama2", + messages: messages + ) + + #expect(chatRequest.messages.count == 4) + #expect(chatRequest.messages[0].role.rawValue == "system") + #expect(chatRequest.messages[1].role.rawValue == "user") + #expect(chatRequest.messages[2].role.rawValue == "assistant") + #expect(chatRequest.messages[3].role.rawValue == "weather_bot") +} + +@Test("Chat request with tools and JSON schema") +func testChatRequestWithToolsAndSchema() { + let weatherFunction = OKFunction( + name: "get_weather", + description: "Get current weather", + parameters: .object( + description: "Weather parameters", + properties: [ + "location": .string(description: "City name"), + "unit": .string(description: "Temperature unit") + ], + required: ["location"] + ) + ) + + let responseSchema = JSONSchema.object( + description: "Weather response", + properties: [ + "temperature": .number(description: "Current temperature"), + "condition": .string(description: "Weather condition") + ], + required: ["temperature", "condition"] + ) + + let chatRequest = OKChatRequestData( + model: "llama2", + messages: [.user("What's the weather in Tokyo?")], + 
tools: [.function(weatherFunction)], + format: responseSchema + ) + + #expect(chatRequest.tools?.count == 1) + #expect(chatRequest.tools?[0].type == "function") + #expect(chatRequest.tools?[0].function.name == "get_weather") + #expect(chatRequest.format != nil) +} + +@Test("Chat request with options configuration") +func testChatRequestWithOptions() { + let chatRequest = OKChatRequestData( + model: "llama2", + messages: [.user("Hello")], + with: { options in + options.temperature = 0.7 + options.topP = 0.9 + options.seed = 42 + } + ) + + #expect(chatRequest.options?.temperature == 0.7) + #expect(chatRequest.options?.topP == 0.9) + #expect(chatRequest.options?.seed == 42) +} + +@Test("Chat request with images") +func testChatRequestWithImages() { + let imageBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" + + let messages = [ + OKChatRequestData.Message.user("What's in this image?", images: [imageBase64]) + ] + + let chatRequest = OKChatRequestData( + model: "llama2", + messages: messages + ) + + #expect(chatRequest.messages[0].images?.count == 1) + #expect(chatRequest.messages[0].images?[0] == imageBase64) +} + +@Test("Chat request encoding") +func testChatRequestEncoding() throws { + let messages = [ + OKChatRequestData.Message.user("Hello") + ] + + let chatRequest = OKChatRequestData( + model: "llama2", + messages: messages + ) + + let encoder = JSONEncoder() + let data = try encoder.encode(chatRequest) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + #expect(json?["model"] as? String == "llama2") + #expect(json?["stream"] as? Bool == true) + #expect((json?["messages"] as? 
[[String: Any]])?.count == 1) +} + +@Test("Message role raw value conversion") +func testMessageRoleRawValue() { + let systemRole = OKChatRequestData.Message.Role.system + let customRole = OKChatRequestData.Message.Role.custom("test_bot") + + #expect(systemRole.rawValue == "system") + #expect(customRole.rawValue == "test_bot") + + let fromRawSystem = OKChatRequestData.Message.Role(rawValue: "system")! + let fromRawCustom = OKChatRequestData.Message.Role(rawValue: "test_bot")! + + #expect(fromRawSystem == .system) + #expect(fromRawCustom == .custom("test_bot")) +} + +@Test("Complex tool configuration") +func testComplexTools() { + let complexFunction = OKFunction( + name: "analyze_data", + description: "Analyze complex data structure", + parameters: .object( + description: "Analysis parameters", + properties: [ + "data": .array( + description: "Input data points", + items: .object( + properties: [ + "value": .number(), + "label": .string(), + "metadata": .object( + properties: [ + "timestamp": .string(description: "ISO8601 formatted timestamp"), + "source": .string() + ] + ) + ], + required: ["value", "label"] + ) + ), + "options": .object( + properties: [ + "algorithm": .enum( + values: [ + .string("mean"), + .string("median"), + .string("mode") + ] + ), + "precision": .integer(minimum: 0, maximum: 10) + ] + ) + ], + required: ["data"] + ) + ) + + let chatRequest = OKChatRequestData( + model: "llama2", + messages: [.user("Analyze this data")], + tools: [.function(complexFunction)] + ) + + let encoder = JSONEncoder() + #expect(throws: Never.self) { + _ = try encoder.encode(chatRequest) + } +}