diff --git a/lib/llm/providers/openai.rb b/lib/llm/providers/openai.rb index 29f7a9d8..79b0381a 100644 --- a/lib/llm/providers/openai.rb +++ b/lib/llm/providers/openai.rb @@ -35,6 +35,16 @@ def complete(message, **params) Response::Completion.new(res.body, self).extend(response_parser) end + %w[generation edit variation].each do |action| + define_method :"vision_#{action}" do |prompt, **params| + req = Net::HTTP::Post.new ["/v1", "images", "#{action}s"].join("/") + body = {prompt:, model: "dall-e-3", n: 1}.merge!(params) + req = preflight(req, body) + res = request @http, req + Response::Vision.new(res.body, self).extend(response_parser) + end + end + ## # @param prompt (see LLM::Provider#transform_prompt) # @return (see LLM::Provider#transform_prompt) diff --git a/lib/llm/providers/openai/response_parser.rb b/lib/llm/providers/openai/response_parser.rb index 79be36bb..364b76d1 100644 --- a/lib/llm/providers/openai/response_parser.rb +++ b/lib/llm/providers/openai/response_parser.rb @@ -26,5 +26,13 @@ def parse_completion(raw) total_tokens: raw.dig("usage", "total_tokens") } end + + def parse_vision(raw) + { + images: raw["data"].map do + URI(_1["url"]) + end + } + end end end diff --git a/lib/llm/response.rb b/lib/llm/response.rb index e97ad0fe..1b66e1d5 100644 --- a/lib/llm/response.rb +++ b/lib/llm/response.rb @@ -5,6 +5,7 @@ class Response require "json" require_relative "response/completion" require_relative "response/embedding" + require_relative "response/vision" ## # @return [Hash] diff --git a/lib/llm/response/vision.rb b/lib/llm/response/vision.rb new file mode 100644 index 00000000..7ac4e215 --- /dev/null +++ b/lib/llm/response/vision.rb @@ -0,0 +1,22 @@ +# frozen_string_literal: true + +module LLM + class Response::Vision < Response + ## + # @return [Array] + # Returns an array of image URIs + def images + parsed[:images] + end + + private + + ## + # @private + # @return [Hash] + # Returns the parsed vision response from the provider + def parsed + @parsed ||= parse_vision(raw) + end + end +end diff --git a/spec/openai/vision_spec.rb b/spec/openai/vision_spec.rb new file mode 100644 index 00000000..4f851a40 --- /dev/null +++ b/spec/openai/vision_spec.rb @@ -0,0 +1,54 @@ +# frozen_string_literal: true + +require "webmock/rspec" + +RSpec.describe "LLM::OpenAI" do + subject(:openai) { LLM.openai("") } + + before(:each, :success) do + stub_request(:post, "https://api.openai.com/v1/images/generations") + .with(headers: {"Content-Type" => "application/json"}) + .to_return( + status: 200, + body: { + created: 1731499418, + data: [ + { + revised_prompt: "Create a detailed image showing a white Siamese cat. The cat has pierce blue eyes and slightly elongated ears. It should be sitting gracefully with its tail wrapped around its legs. The Siamese cat's unique color points on its ears, face, paws and tail are in a contrast with its creamy white fur. The background is peaceful and comforting, perhaps a softly lit quieter corner of a home, with tantalizing shadows and welcoming warm colors.", + url: "https://oaidalleapiprodscus.blob.core.windows.net/private/org-onsUXMUK28Zzsh9Vv8iWj80q/user-VcliHUdhkKDdohyDGnVsJzYg/img-C5OCBxw69p4vKtcLLIlL9xCz.png?st=2024-11-13T11%3A03%3A37Z&se=2024-11-13T13%3A03%3A37Z&sp=r&sv=2024-08-04&sr=b&rscd=inline&rsct=image/png&skoid=d505667d-d6c1-4a0a-bac7-5c84a87759f8&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-11-12T19%3A49%3A57Z&ske=2024-11-13T19%3A49%3A57Z&sks=b&skv=2024-08-04&sig=9Bp9muevzDLdjymf%2BsnVuorprp6iCol/wI8Ih95xjhE%3D" + } + ] + }.to_json + ) + end + + before(:each, :unauthorized) do + stub_request(:post, "https://api.openai.com/v1/images/generations") + .with(headers: {"Content-Type" => "application/json"}) + .to_return( + status: 401, + body: '{ + "error": { + "code": null, + "message": "Invalid authorization header", + "param": null, + "type": "server_error" + } + }' + ) + end + + context "with successful vision", :success do + let(:vision) { openai.vision_generation("a white siamese cat") } + + it "returns a vision" do + expect(vision).to be_a(LLM::Response::Vision) + end + + it "has images" do + expect(vision.images.first).to be_a(URI).and have_attributes( + to_s: "https://oaidalleapiprodscus.blob.core.windows.net/private/org-onsUXMUK28Zzsh9Vv8iWj80q/user-VcliHUdhkKDdohyDGnVsJzYg/img-C5OCBxw69p4vKtcLLIlL9xCz.png?st=2024-11-13T11%3A03%3A37Z&se=2024-11-13T13%3A03%3A37Z&sp=r&sv=2024-08-04&sr=b&rscd=inline&rsct=image/png&skoid=d505667d-d6c1-4a0a-bac7-5c84a87759f8&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-11-12T19%3A49%3A57Z&ske=2024-11-13T19%3A49%3A57Z&sks=b&skv=2024-08-04&sig=9Bp9muevzDLdjymf%2BsnVuorprp6iCol/wI8Ih95xjhE%3D" + ) + end + end +end