diff --git a/README.md b/README.md index b452533..f3459df 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,15 @@ npm install fluent-ai zod@next fluent-ai includes support for multiple AI providers and modalities. -| provider | chat completion | embedding | image generation | list models | -| --------- | ------------------ | ------------------ | ------------------ | ------------------ | -| anthropic | :white_check_mark: | | | :white_check_mark: | -| fal | | | :white_check_mark: | | -| google | :white_check_mark: | | | | -| ollama | :white_check_mark: | :white_check_mark: | | :white_check_mark: | -| openai | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| voyage | | :white_check_mark: | | | +| provider | chat completion | embedding | image generation | list models | text to speech | +| ---------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | +| anthropic | :white_check_mark: | | | :white_check_mark: | | +| elevenlabs | | | | | :white_check_mark: | +| fal | | | :white_check_mark: | | | +| google | :white_check_mark: | | | | | +| ollama | :white_check_mark: | :white_check_mark: | | :white_check_mark: | | +| openai | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| voyage | | :white_check_mark: | | | | By default, API keys for providers are read from environment variable (`process.env`) following the format `_API_KEY` (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`). @@ -193,6 +194,15 @@ import { openai } from "fluent-ai"; const models = await openai().models().run(); ``` +## Text to Speech + +```ts +import { openai } from "fluent-ai"; + +const job = openai().model("tts-1").text("hi"); +const result = await job.run(); +``` + ## Support Feel free to [open an issue](https://github.com/modalityml/fluent-ai/issues) or [start a discussion](https://github.com/modalityml/fluent-ai/discussions) if you have any questions. If you would like to request support for a new AI provider, please create an issue with details about the provider's API. [Join our Discord community](https://discord.gg/HzGZWbY8Fx) for help and updates. diff --git a/bun.lock b/bun.lock index a13c8b3..f764d80 100644 --- a/bun.lock +++ b/bun.lock @@ -10,6 +10,7 @@ "devDependencies": { "@types/bun": "latest", "bun-plugin-dts": "^0.3.0", + "prettier": "^3.5.3", }, "peerDependencies": { "typescript": "^5.0.0", @@ -58,6 +59,8 @@ "partial-json": ["partial-json@0.1.7", "", {}, "sha512-Njv/59hHaokb/hRUjce3Hdv12wd60MtM9Z5Olmn+nehe0QDAsRtRbJPvJ0Z91TusF0SuZRIvnM+S4l6EIP8leA=="], + "prettier": ["prettier@3.5.3", "", { "bin": { "prettier": "bin/prettier.cjs" } }, "sha512-QQtaxnoDJeAkDvDKWCLiwIXkTgRhwYDEQCghU9Z6q03iyek/rxRh/2lC3HB7P8sWT2xC/y5JDctPLBIGzHKbhw=="], + "require-directory": ["require-directory@2.1.1", "", {}, "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q=="], "resolve-pkg-maps": ["resolve-pkg-maps@1.0.0", "", {}, "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw=="], diff --git a/examples/openai-chat-stream.ts b/examples/openai-chat-stream.ts index 290ef7a..8330698 100644 --- a/examples/openai-chat-stream.ts +++ b/examples/openai-chat-stream.ts @@ -2,9 +2,8 @@ import { text, openai } from "../src"; const job = openai() .chat("gpt-4o-mini") - .messages([{ role: "user", content: "generate a 50 words text" }]) + .prompt("generate a 50 words text") .stream(); -const stream = await job.run(); -for await (const chunk of stream) { - process.stdout.write(text(chunk)); +for await (const chunk of job) { + console.log(chunk?.message); } diff --git a/examples/openai-chat-tool-stream.ts b/examples/openai-chat-tool-stream.ts index 4ba4483..b417464 100644 --- a/examples/openai-chat-tool-stream.ts +++ b/examples/openai-chat-tool-stream.ts @@ -7,7 +7,7 @@ const weatherTool = tool("get_current_weather") z.object({ location: z.string(), unit: z.enum(["celsius", "fahrenheit"]).optional(), - }) + }), ); const job = openai() .chat("gpt-4o-mini") diff --git a/examples/openai-chat-tool.ts b/examples/openai-chat-tool.ts index 8886a73..835dcaf 100644 --- a/examples/openai-chat-tool.ts +++ b/examples/openai-chat-tool.ts @@ -7,7 +7,7 @@ const weatherTool = tool("get_current_weather") z.object({ location: z.string(), unit: z.enum(["celsius", "fahrenheit"]).optional(), - }) + }), ); const job = openai() .chat("gpt-4o-mini") diff --git a/examples/openai-chat.ts b/examples/openai-chat.ts index 5b03733..2731e4d 100644 --- a/examples/openai-chat.ts +++ b/examples/openai-chat.ts @@ -4,4 +4,4 @@ const job = openai({}) .chat("gpt-4o-mini") .messages([system("you are a helpful assistant"), user("hi")]); const result = await job.run(); -console.log(text(result)); +console.log(result?.message); diff --git a/examples/openai-embedding.ts b/examples/openai-embedding.ts index fb714ea..3ea9662 100644 --- a/examples/openai-embedding.ts +++ b/examples/openai-embedding.ts @@ -2,4 +2,4 @@ import { openai } from "../src"; const job = openai().embedding("text-embedding-3-small").value("hello"); const result = await job.run(); -console.log(result.embedding); +console.log(result!.embedding); diff --git a/examples/openai-image-edit.ts b/examples/openai-image-edit.ts new file mode 100644 index 0000000..71cf4e1 --- /dev/null +++ b/examples/openai-image-edit.ts @@ -0,0 +1,14 @@ +import { openai } from "../src"; +import { readFileSync, writeFileSync } from "node:fs"; + +const job = openai() + .image("gpt-image-1") // TODO: add support for dall-e-2 + .edit( + new File([readFileSync("./cat.jpg")], "cat.jpg", { type: "image/jpeg" }), + ) + .prompt("add a hat to the cat") + .size("1024x1024"); + +const result = await job.run(); +const buffer = Buffer.from(result!.raw.data[0].b64_json, "base64"); +writeFileSync("cat_edit.jpg", buffer); diff --git a/examples/openai-image.ts b/examples/openai-image.ts index 6e0cfab..e9b3aa3 100644 --- a/examples/openai-image.ts +++ b/examples/openai-image.ts @@ -1,9 +1,13 @@ import { openai } from "../src"; +import { writeFileSync } from "node:fs"; const job = openai() - .image("dalle-2") + .image("dall-e-2") .prompt("a cat") - .size({ width: 512, height: 512 }); -const result = await job.run(); + .size("512x512") + .outputFormat("jpeg") + .responseFormat("b64_json"); -console.log(result); +const result = await job.run(); +const buffer = Buffer.from(result!.raw.data[0].b64_json, "base64"); +writeFileSync("cat.jpg", buffer); diff --git a/examples/openai-models.ts b/examples/openai-models.ts index b64d8a7..da8aa9a 100644 --- a/examples/openai-models.ts +++ b/examples/openai-models.ts @@ -2,5 +2,6 @@ import { openai } from "../src"; const job = openai().models(); const result = await job.run(); - -console.log(result); +for (const model of result!) { + console.log(model); +} diff --git a/examples/workflow.ts b/examples/workflow.ts new file mode 100644 index 0000000..c5b3ad2 --- /dev/null +++ b/examples/workflow.ts @@ -0,0 +1,31 @@ +import { z } from "zod"; +import { openai, workflow } from "../src"; + +const flow = workflow("workflow1") + .input( + z.object({ + description: z.string(), + }), + ) + .step("step1", ({ context }) => { + return openai() + .chat("gpt-4o-mini") + .prompt( + `generate a story based on following description: ${context.input.description}`, + ) + .jsonSchema( + z.object({ + story: z.string(), + }), + ); + }) + .step("step2", ({ context }) => { + return elevenlabs() + .tts("eleven_multilingual_v2") + .text(context.steps.step1.story); + }); + +const result = await flow.run({ + input: { description: "fire engine and a cat" }, +}); +console.log(result); diff --git a/package.json b/package.json index 5f006b6..c45b918 100644 --- a/package.json +++ b/package.json @@ -33,7 +33,8 @@ ], "devDependencies": { "@types/bun": "latest", - "bun-plugin-dts": "^0.3.0" + "bun-plugin-dts": "^0.3.0", + "prettier": "^3.5.3" }, "peerDependencies": { "typescript": "^5.0.0", @@ -43,5 +44,8 @@ "repository": { "type": "git", "url": "git+https://github.com/modalityml/fluent-ai.git" + }, + "prettier": { + "trailingComma": "all" } } diff --git a/src/client.ts b/src/client.ts new file mode 100644 index 0000000..ab6d472 --- /dev/null +++ b/src/client.ts @@ -0,0 +1,70 @@ +import type { Job } from "./jobs/load"; + +export interface ClientOptions { + url: string; + apiKey: string; +} + +export class Client { + url: string; + apiKey: string; + + constructor(options: ClientOptions) { + this.url = options.url; + this.apiKey = options.apiKey; + } + + async createJob(job: Job) { + // TODO: reuse fetch error handling + const response = await fetch(`${this.url}/api/jobs`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify(job), + }); + + const data = await response.json(); + return data; + } + + async streamJob(jobId: string) { + const response = await fetch(`${this.url}/api/jobs/${jobId}/stream`, { + headers: { + Authorization: `Bearer ${this.apiKey}`, + }, + }); + + const reader = response.body!.getReader(); + const decoder = new TextDecoder(); + + async function* streamGenerator() { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value, { stream: true }); + const lines = chunk.split("\n").filter((line) => line.trim()); + + for (const line of lines) { + if (line.startsWith("data: ")) { + const jsonStr = line.slice(6); + try { + const data = JSON.parse(jsonStr); + yield data; + } catch (e) { + console.error("Error parsing SSE data:", e); + } + } + } + } + } + + return streamGenerator(); + } +} + +export function createClient(options: ClientOptions): Client { + return new Client(options); +} diff --git a/src/index.ts b/src/index.ts index bdb2a62..21a038e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,6 +6,7 @@ export * from "./jobs/embedding"; export * from "./jobs/models"; export * from "./providers/anthropic"; +export * from "./providers/elevenlabs"; export * from "./providers/deepseek"; export * from "./providers/fal"; export * from "./providers/fireworks"; @@ -15,3 +16,7 @@ export * from "./providers/ollama"; export * from "./providers/openai"; export * from "./providers/together"; export * from "./providers/voyage"; + +export * from "./workflow"; + +export * from "./client"; diff --git a/src/jobs/builder.ts b/src/jobs/builder.ts index f331891..dda0267 100644 --- a/src/jobs/builder.ts +++ b/src/jobs/builder.ts @@ -1,37 +1,36 @@ import { version } from "../../package.json"; -import type { Job } from "./load"; -import type { - JobCost, - JobOptions, - JobPerformance, - JobProvider, - JobType, -} from "./schema"; +import type { BaseJob } from "./schema"; export class HTTPError extends Error { status: number; json?: any; constructor(message: string, status: number, json?: any) { + if (json && json.error && json.error.message) { + message = json.error.message; + } super(message); this.status = status; this.json = json; } } -export class JobBuilder { - provider!: JobProvider; - options!: JobOptions; - type!: JobType; - input?: Input; - output?: Output; - cost?: JobCost; - performance?: JobPerformance; // TODO: track job performance +export abstract class JobBuilder { + provider!: Job["provider"]; + options!: Job["options"]; + type!: Job["type"]; + input?: Job["input"]; + output?: Job["output"]; + cost?: Job["cost"]; + performance?: Job["performance"]; // TODO: track job performance + + abstract makeRequest(): Request; - makeRequest?: () => Request; - handleResponse?: (response: Response) => any; + async handleResponse(response: Response): Promise { + throw new Error("Not implemented"); + } - async run(): Promise { + async run(): Promise { const request = this.makeRequest!(); const response = await fetch(request); if (!response.ok) { @@ -41,9 +40,9 @@ export class JobBuilder { } catch (e) {} throw new HTTPError( - `Fetch error: ${response.statusText}`, + `HTTP error: ${response.statusText}`, response.status, - json + json, ); } return await this.handleResponse!(response); @@ -56,9 +55,9 @@ export class JobBuilder { options: this.options, type: this.type, input: this.input!, - output: this.output as any, + output: this.output, cost: this.cost, performance: this.performance, - } as Job; + }; } } diff --git a/src/jobs/chat/builder.ts b/src/jobs/chat/builder.ts index 8cb5369..ca40ade 100644 --- a/src/jobs/chat/builder.ts +++ b/src/jobs/chat/builder.ts @@ -1,16 +1,18 @@ import { z } from "zod"; -import { JobBuilder } from "~/jobs/builder"; +import { HTTPError, JobBuilder } from "~/jobs/builder"; import type { - ChatInput, - ChatOutput, + ChatJob, ChatStreamOptions, + ChatToolChoiceSchema, Message, ResponseFormat, } from "./schema"; import type { ChatTool } from "./tool"; -export class ChatJobBuilder extends JobBuilder { - input: ChatInput; +export abstract class ChatJobBuilder< + Job extends ChatJob, +> extends JobBuilder { + input: Job["input"]; constructor(model: string) { super(); @@ -21,6 +23,27 @@ export class ChatJobBuilder extends JobBuilder { }; } + async *handleStream(response: Response): AsyncGenerator { + throw new Error("Not implemented"); + } + + async *stream(options?: ChatStreamOptions): AsyncGenerator { + this.input.stream = true; + this.input.streamOptions = options; + if (!this.handleStream) { + throw new Error("Stream not supported"); + } + const request = this.makeRequest!(); + const response = await fetch(request); + if (!response.ok) { + throw new HTTPError( + `Fetch error: ${response.statusText}`, + response.status, + ); + } + yield* this.handleStream(response); + } + system(system: string) { this.input.system = system; return this; @@ -31,6 +54,14 @@ export class ChatJobBuilder extends JobBuilder { return this; } + prompt(prompt: string) { + this.input.messages.push({ + role: "user", + content: prompt, + }); + return this; + } + temperature(temperature: number) { this.input.temperature = temperature; return this; @@ -64,7 +95,7 @@ export class ChatJobBuilder extends JobBuilder { return this; } - toolChoice(toolChoice: string) { + toolChoice(toolChoice: z.infer) { this.input.toolChoice = toolChoice; return this; } @@ -83,12 +114,4 @@ export class ChatJobBuilder extends JobBuilder { return this; } - - stream(streamOptions?: ChatStreamOptions) { - this.input.stream = true; - if (streamOptions) { - this.input.streamOptions = streamOptions; - } - return this; - } } diff --git a/src/jobs/chat/schema.ts b/src/jobs/chat/schema.ts index 4bda766..3aa18b5 100644 --- a/src/jobs/chat/schema.ts +++ b/src/jobs/chat/schema.ts @@ -1,24 +1,73 @@ import { z } from "zod"; import { BaseJobSchema } from "~/jobs/schema"; -export const MessageSchema = z.object({ - role: z.enum(["system", "user", "assistant"]), - content: z.union([ - z.string(), - z.array( - z.union([ - z.object({ type: z.literal("text"), text: z.string() }), - z.object({ - type: z.literal("image_url"), - image_url: z.object({ - url: z.string(), - }), - }), - ]) - ), - ]), +export const MessageContentSchema = z.union([ + z.string(), + z.array( + z.object({ + type: z.literal("text"), + text: z.string(), + }), + ), + z.array( + z.object({ + type: z.literal("image"), + image_url: z.string().optional(), + source: z + .object({ + type: z.literal("base64"), + data: z.string(), + media_type: z.enum([ + "image/jpeg", + "image/png", + "image/webp", + "image/gif", + ]), + }) + .optional(), + }), + ), +]); + +export const ToolCallSchema = z.object({ + name: z.string().optional(), + type: z.string().optional(), + id: z.string().optional(), + call_id: z.string().optional(), + arguments: z.record(z.string(), z.any()), +}); + +export const BaseMessageSchema = z.object({ + role: z.enum(["assistant", "user", "system", "tool"]), + content: MessageContentSchema, +}); + +export const AIMessageSchema = BaseMessageSchema.extend({ + role: z.literal("assistant"), + tool_calls: z.array(ToolCallSchema).optional(), +}); + +export const HumanMessageSchema = BaseMessageSchema.extend({ + role: z.literal("user"), }); +export const SystemMessageSchema = BaseMessageSchema.extend({ + role: z.literal("system"), +}); + +export const ToolMessageSchema = BaseMessageSchema.extend({ + role: z.literal("tool"), + call_id: z.string().optional(), + result: z.any().optional(), +}); + +export const MessageSchema = z.discriminatedUnion("role", [ + AIMessageSchema, + HumanMessageSchema, + SystemMessageSchema, + ToolMessageSchema, +]); + export type Message = z.infer; export const ChatStreamOptionsSchema = z.object({ @@ -48,26 +97,9 @@ export const JsonSchemaDefSchema = z.object({ export const ChunkSchema = z.object({}); -export const ChatResultSchema = z.object({ - message: z.object({ - role: z.literal("assistant"), - content: z.string().nullable(), - }), - usage: z - .object({ - prompt_tokens: z.number(), - completion_tokens: z.number(), - total_tokens: z.number(), - }) - .optional(), - tool_calls: z - .array( - z.object({ - name: z.string(), - arguments: z.record(z.string(), z.any()), - }) - ) - .optional(), +export const ChatToolChoiceSchema = z.object({ + mode: z.enum(["auto", "none", "any"]), + allowed_tools: z.array(z.string()).optional(), }); export const ChatInputSchema = z.object({ @@ -78,7 +110,8 @@ export const ChatInputSchema = z.object({ maxTokens: z.number().optional(), messages: z.array(MessageSchema), tools: z.array(ChatToolSchema).optional(), - toolChoice: z.string().optional(), + //TODO: support gemini and anthropic tool choice, might need refactor + toolChoice: ChatToolChoiceSchema.optional(), responseFormat: ResponseFormatSchema.optional(), topP: z.number().optional(), topK: z.number().optional(), @@ -87,7 +120,10 @@ export const ChatInputSchema = z.object({ }); // TODO: Add a schema for the output -export const ChatOutputSchema = z.any(); +export const ChatOutputSchema = z.object({ + message: AIMessageSchema.optional(), + raw: z.any().optional(), +}); export const ChatJobSchema = BaseJobSchema.extend({ type: z.literal("chat"), @@ -95,6 +131,6 @@ export const ChatJobSchema = BaseJobSchema.extend({ output: ChatOutputSchema.optional(), }); +export type ChatJob = z.infer; export type ChatInput = z.infer; - export type ChatOutput = z.infer; diff --git a/src/jobs/chat/utils.ts b/src/jobs/chat/utils.ts index fe03d15..077b393 100644 --- a/src/jobs/chat/utils.ts +++ b/src/jobs/chat/utils.ts @@ -3,7 +3,10 @@ import { parse } from "partial-json"; import type { ChatToolSchema, Message } from "./schema"; import { ChatTool } from "./tool"; +// TODO: move to providers, different providers have different tool formats +// @deprecated export function convertTools(tools: z.infer[]) { + console.warn("convertTools is deprecated, move to providers"); return tools.map((tool) => ({ type: "function", function: { diff --git a/src/jobs/embedding/builder.ts b/src/jobs/embedding/builder.ts index ef48dca..60f9b46 100644 --- a/src/jobs/embedding/builder.ts +++ b/src/jobs/embedding/builder.ts @@ -1,11 +1,10 @@ import { JobBuilder } from "~/jobs/builder"; -import type { EmbeddingInput, EmbeddingOutput } from "./schema"; +import type { EmbeddingJob } from "./schema"; -export class EmbeddingJobBuilder extends JobBuilder< - EmbeddingInput, - EmbeddingOutput -> { - input: EmbeddingInput; +export abstract class EmbeddingJobBuilder< + Job extends EmbeddingJob, +> extends JobBuilder { + input: Job["input"]; constructor(model: string) { super(); diff --git a/src/jobs/embedding/schema.ts b/src/jobs/embedding/schema.ts index 3c878ac..d0c3f5f 100644 --- a/src/jobs/embedding/schema.ts +++ b/src/jobs/embedding/schema.ts @@ -19,6 +19,8 @@ export const EmbeddingJobSchema = BaseJobSchema.extend({ output: EmbeddingOutputSchema.optional(), }); +export type EmbeddingJob = z.infer; + export type EmbeddingInput = z.infer; export type EmbeddingOutput = z.infer; diff --git a/src/jobs/image/builder.ts b/src/jobs/image/builder.ts index a773793..3befb30 100644 --- a/src/jobs/image/builder.ts +++ b/src/jobs/image/builder.ts @@ -1,8 +1,10 @@ import { JobBuilder } from "~/jobs/builder"; -import type { ImageInput, ImageOutput, ImageSize } from "./schema"; +import type { ImageJob, ImageSize } from "./schema"; -export class ImageJobBuilder extends JobBuilder { - input: ImageInput; +export abstract class ImageJobBuilder< + Job extends ImageJob, +> extends JobBuilder { + input: Job["input"]; constructor(model: string) { super(); @@ -17,6 +19,20 @@ export class ImageJobBuilder extends JobBuilder { return this; } + edit(image: File | Array) { + if (Array.isArray(image)) { + this.input.images = image; + } else { + this.input.images = [image]; + } + return this; + } + + mask(mask: File) { + this.input.mask = mask; + return this; + } + n(numImages: number) { this.input.n = numImages; return this; @@ -72,8 +88,23 @@ export class ImageJobBuilder extends JobBuilder { return this; } - stream() { - this.input.stream = true; + moderation(moderation: string) { + this.input.moderation = moderation; + return this; + } + + outputCompression(outputCompression: string) { + this.input.outputCompression = outputCompression; + return this; + } + + outputFormat(outputFormat: string) { + this.input.outputFormat = outputFormat; + return this; + } + + background(background: string) { + this.input.background = background; return this; } } diff --git a/src/jobs/image/schema.ts b/src/jobs/image/schema.ts index 9b1b8b0..e1cc675 100644 --- a/src/jobs/image/schema.ts +++ b/src/jobs/image/schema.ts @@ -1,7 +1,8 @@ -import { z } from "zod"; +import { size, z } from "zod"; import { BaseJobSchema } from "~/jobs/schema"; export const ImageSizeSchema = z.union([ + z.string(), z.literal("square_hd"), z.literal("square"), z.literal("portrait_4_3"), @@ -19,6 +20,8 @@ export type ImageSize = z.infer; export const ImageInputSchema = z.object({ model: z.string(), prompt: z.string().optional(), + images: z.array(z.any()).optional(), // TODO: fix any + mask: z.any().optional(), // TODO: fix any n: z.number().optional(), quality: z.string().optional(), responseFormat: z.string().optional(), @@ -31,6 +34,10 @@ export const ImageInputSchema = z.object({ syncMode: z.boolean().optional(), enableSafetyChecker: z.boolean().optional(), stream: z.boolean().optional(), + moderation: z.string().optional(), + outputCompression: z.string().optional(), + outputFormat: z.string().optional(), + background: z.string().optional(), }); const ImageOuputSchema = z.object({ @@ -43,7 +50,7 @@ const ImageOuputSchema = z.object({ z.object({ base64: z.string(), }), - ]) + ]), ), metadata: z .object({ @@ -60,6 +67,6 @@ export const ImageJobSchema = BaseJobSchema.extend({ output: ImageOuputSchema.optional(), }); +export type ImageJob = z.infer; export type ImageInput = z.infer; - export type ImageOutput = z.infer; diff --git a/src/jobs/load.ts b/src/jobs/load.ts index ce99ab7..7de18c2 100644 --- a/src/jobs/load.ts +++ b/src/jobs/load.ts @@ -1,6 +1,7 @@ import { z } from "zod"; import { anthropic, AnthropicJobSchema } from "~/providers/anthropic"; import { deepseek, DeepseekJobSchema } from "~/providers/deepseek"; +import { elevenlabs, ElevenlabsJobSchema } from "~/providers/elevenlabs"; import { fal, FalJobSchema } from "~/providers/fal"; import { GoogleJobSchema } from "~/providers/google"; import { LumaJobSchema } from "~/providers/luma"; @@ -11,6 +12,7 @@ import { voyage, VoyageJobSchema } from "~/providers/voyage"; export const JobSchema = z.union([ AnthropicJobSchema, DeepseekJobSchema, + ElevenlabsJobSchema, FalJobSchema, GoogleJobSchema, LumaJobSchema, @@ -31,6 +33,8 @@ export function load(obj: Job) { provider = anthropic(obj.options); } else if (obj.provider === "deepseek") { provider = deepseek(obj.options); + } else if (obj.provider === "elevenlabs") { + provider = elevenlabs(obj.options); } else if (obj.provider === "fal") { provider = fal(obj.options); } else if (obj.provider === "ollama") { @@ -59,6 +63,9 @@ export function load(obj: Job) { if (obj.type === "models" && "models" in provider) { builder = provider.models(); } + if (obj.type === "speech" && "speech" in provider) { + builder = provider.speech(obj.input.model); + } if (!builder) { throw new Error("Failed to load job"); diff --git a/src/jobs/models/builder.ts b/src/jobs/models/builder.ts index 9b40ab8..dc1b4ff 100644 --- a/src/jobs/models/builder.ts +++ b/src/jobs/models/builder.ts @@ -1,8 +1,10 @@ import { JobBuilder } from "~/jobs/builder"; -import type { ModelsInput, ModelsOutput } from "./schema"; +import type { ModelsJob } from "./schema"; -export class ModelsJobBuilder extends JobBuilder { - input: ModelsInput; +export abstract class ModelsJobBuilder< + Job extends ModelsJob, +> extends JobBuilder { + input: Job["input"]; constructor() { super(); diff --git a/src/jobs/models/schema.ts b/src/jobs/models/schema.ts index 51f0899..ff43a5c 100644 --- a/src/jobs/models/schema.ts +++ b/src/jobs/models/schema.ts @@ -3,7 +3,16 @@ import { BaseJobSchema } from "~/jobs/schema"; export const ModelsInputSchema = z.object({}); -export const ModelsOuputSchema = z.object({}); +export const ModelsOuputSchema = z.object({ + models: z.array( + z.object({ + id: z.string(), + created: z.number(), + owned_by: z.string(), + }), + ), + raw: z.any(), +}); export const ModelsJobSchema = BaseJobSchema.extend({ type: z.literal("models"), @@ -11,6 +20,7 @@ export const ModelsJobSchema = BaseJobSchema.extend({ output: ModelsOuputSchema.optional(), }); +export type ModelsJob = z.infer; export type ModelsInput = z.infer; export type ModelsOutput = z.infer; diff --git a/src/jobs/schema.ts b/src/jobs/schema.ts index 43a92e6..2c31131 100644 --- a/src/jobs/schema.ts +++ b/src/jobs/schema.ts @@ -2,6 +2,7 @@ import { z } from "zod"; export const JobProviderSchema = z.enum([ "anthropic", + "elevenlabs", "deepseek", "fal", "google", @@ -11,7 +12,13 @@ export const JobProviderSchema = z.enum([ "voyage", ]); -export const JobTypeSchema = z.enum(["chat", "image", "models", "embedding"]); +export const JobTypeSchema = z.enum([ + "chat", + "embedding", + "image", + "models", + "speech", +]); export const JobOptionsSchema = z.object({ apiKey: z.string().optional(), @@ -26,18 +33,26 @@ export const JobCostSchema = z.object({ export const JobPerformance = z.object({}); -export type JobCost = z.infer; - -export type JobPerformance = z.infer; +export const RemoteJobSchema = z.object({ + id: z.string().optional(), + status: z.enum(["pending", "running", "completed", "failed"]), + createdAt: z.date().optional(), +}); export const BaseJobSchema = z.object({ version: z.string().optional(), + provider: JobProviderSchema, options: JobOptionsSchema.optional(), cost: JobCostSchema.optional(), + type: JobTypeSchema, + input: z.any(), + output: z.any().optional(), performance: JobPerformance.optional(), }); +export type JobCost = z.infer; +export type JobPerformance = z.infer; +export type BaseJob = z.infer; export type JobProvider = z.infer; export type JobType = z.infer; -export type BaseJob = z.infer; export type JobOptions = z.infer; diff --git a/src/jobs/speech/builder.ts b/src/jobs/speech/builder.ts new file mode 100644 index 0000000..4ca4002 --- /dev/null +++ b/src/jobs/speech/builder.ts @@ -0,0 +1,16 @@ +import { JobBuilder } from "~/jobs/builder"; +import type { SpeechJob } from "./schema"; + +export abstract class SpeechJobBuilder< + Job extends SpeechJob, +> extends JobBuilder { + input: Job["input"]; + + constructor(model: string) { + super(); + this.type = "speech"; + this.input = { + model: model, + }; + } +} diff --git a/src/jobs/speech/index.ts b/src/jobs/speech/index.ts new file mode 100644 index 0000000..e73f1db --- /dev/null +++ b/src/jobs/speech/index.ts @@ -0,0 +1,2 @@ +export * from "./schema"; +export * from "./builder"; diff --git a/src/jobs/speech/schema.ts b/src/jobs/speech/schema.ts new file mode 100644 index 0000000..d085dbb --- /dev/null +++ b/src/jobs/speech/schema.ts @@ -0,0 +1,18 @@ +import { z } from "zod"; +import { BaseJobSchema } from "~/jobs/schema"; + +export const SpeechInputSchema = z.object({ + model: z.string(), +}); + +export const SpeechOutputSchema = z.object({}); + +export const SpeechJobSchema = BaseJobSchema.extend({ + type: z.literal("speech"), + input: SpeechInputSchema, + output: SpeechOutputSchema.optional(), +}); + +export type SpeechJob = z.infer; +export type SpeechInput = z.infer; +export type SpeechOutput = z.infer; diff --git a/src/providers/anthropic/chat.ts b/src/providers/anthropic/chat.ts index 8f3aa2d..f84475a 100644 --- a/src/providers/anthropic/chat.ts +++ b/src/providers/anthropic/chat.ts @@ -1,14 +1,15 @@ import { ChatJobBuilder, convertTools } from "~/jobs/chat"; import type { JobOptions } from "~/jobs/schema"; +import type { AnthropicChatJob } from "./schema"; -export class AnthropicChatJobBuilder extends ChatJobBuilder { +export class AnthropicChatJobBuilder extends ChatJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "anthropic"; this.options = options; } - makeRequest = () => { + makeRequest() { const requestParams = { model: this.input.model, max_tokens: this.input.maxTokens, @@ -22,7 +23,7 @@ export class AnthropicChatJobBuilder extends ChatJobBuilder { const headers = { "anthropic-version": "2023-06-01", - "x-api-key": this.options.apiKey!, + "x-api-key": this.options!.apiKey!, "Content-Type": "application/json", }; @@ -31,10 +32,10 @@ export class AnthropicChatJobBuilder extends ChatJobBuilder { headers: headers, body: JSON.stringify(requestParams), }); - }; + } - handleResponse = async (response: Response) => { + async handleResponse(response: Response) { const raw = await response.json(); return { raw }; - }; + } } diff --git a/src/providers/anthropic/models.ts b/src/providers/anthropic/models.ts index a6c553d..1df31e9 100644 --- a/src/providers/anthropic/models.ts +++ b/src/providers/anthropic/models.ts @@ -1,17 +1,18 @@ import { ModelsJobBuilder } from "~/jobs/models"; import type { JobOptions } from "~/jobs/schema"; +import type { AnthropicModelsJob } from "./schema"; -export class AnthropicModelsJobBuilder extends ModelsJobBuilder { +export class AnthropicModelsJobBuilder extends ModelsJobBuilder { constructor(options: JobOptions) { super(); this.provider = "anthropic"; this.options = options; } - makeRequest = () => { + makeRequest() { const headers = { "anthropic-version": "2023-06-01", - "x-api-key": this.options.apiKey!, + "x-api-key": this.options!.apiKey!, "Content-Type": "application/json", }; @@ -19,10 +20,10 @@ export class AnthropicModelsJobBuilder extends ModelsJobBuilder { method: "GET", headers: headers, }); - }; + } - handleResponse = async (response: Response) => { + async handleResponse(response: Response) { const json = await response.json(); - return json; - }; + return { raw: json, models: [] }; + } } diff --git a/src/providers/anthropic/schema.ts b/src/providers/anthropic/schema.ts index 7e43a6c..8ba56de 100644 --- a/src/providers/anthropic/schema.ts +++ b/src/providers/anthropic/schema.ts @@ -7,17 +7,18 @@ export const AnthropicBaseJobSchema = z.object({ }); export const AnthropicChatJobSchema = ChatJobSchema.extend( - AnthropicBaseJobSchema + AnthropicBaseJobSchema, ); -export type AnthropicChatJob = z.infer; export const AnthropicModelsJobSchema = ModelsJobSchema.extend( - AnthropicBaseJobSchema + AnthropicBaseJobSchema, ); -export type AnthropicModelsJob = z.infer; export const AnthropicJobSchema = z.discriminatedUnion("type", [ AnthropicChatJobSchema, AnthropicModelsJobSchema, ]); + export type AnthropicJob = z.infer; +export type AnthropicChatJob = z.infer; +export type AnthropicModelsJob = z.infer; diff --git a/src/providers/anthropic/type.ts b/src/providers/anthropic/type.ts new file mode 100644 index 0000000..e69de29 diff --git a/src/providers/deepseek/index.ts b/src/providers/deepseek/index.ts index 386e31d..3daf181 100644 --- a/src/providers/deepseek/index.ts +++ b/src/providers/deepseek/index.ts @@ -5,22 +5,19 @@ import type { JobOptions } from "~/jobs/schema"; import { OpenAIChatJobBuilder } from "~/providers/openai"; import { OpenAIModelsJobBuilder } from "~/providers/openai/models"; -export const BaseDeepseekJobSchema = z.object({ +export const DeepseekBaseJobSchema = z.object({ provider: z.literal("deepseek"), }); export const DeepseekChatJobSchema = ChatJobSchema.extend( - BaseDeepseekJobSchema + DeepseekBaseJobSchema, ); -export type DeepseekChatJob = z.infer; export const DeepseekModelsJobSchema = ModelsJobSchema.extend( - BaseDeepseekJobSchema + DeepseekBaseJobSchema, ); -export type DeepseekModelsJob = z.infer; export const DeepseekJobSchema = z.discriminatedUnion("type", [ DeepseekChatJobSchema, DeepseekModelsJobSchema, ]); -export type DeepseekJob = z.infer; export function deepseek(options?: JobOptions) { options = options || {}; @@ -33,7 +30,7 @@ export function deepseek(options?: JobOptions) { ...options, baseURL: "https://api.deepseek.com", }, - model + model, ); }, diff --git a/src/providers/elevenlabs/index.ts b/src/providers/elevenlabs/index.ts new file mode 100644 index 0000000..59e4f63 --- /dev/null +++ b/src/providers/elevenlabs/index.ts @@ -0,0 +1,46 @@ +import { z } from "zod"; +import type { JobOptions } from "~/jobs/schema"; +import { SpeechJobBuilder, SpeechJobSchema } from "~/jobs/speech"; + +export function elevenlabs(options?: JobOptions) { + options = options || {}; + options.apiKey = options.apiKey || process.env.ELEVENLABS_API_KEY; + + return { + speech(model: string) { + return new ElevenlabsSpeechJobBuilder(options, model); + }, + }; +} + +class ElevenlabsSpeechJobBuilder extends SpeechJobBuilder { + constructor(options: JobOptions, model: string) { + super(model); + this.provider = "elevenlabs"; + this.options = options; + } + + makeRequest() { + return new Request("https://api.elevenlabs.io/v1/text-to-speech", {}); + } + + async handleResponse(response: Response) { + const raw = await response.json(); + return { raw }; + } +} + +export const ElevenlabsBaseJobSchema = z.object({ + provider: z.literal("elevenlabs"), +}); + +export const ElevenlabsSpeechJobSchema = SpeechJobSchema.extend( + ElevenlabsBaseJobSchema, +); + +export const ElevenlabsJobSchema = z.discriminatedUnion("type", [ + ElevenlabsSpeechJobSchema, +]); + +export type ElevenlabsJob = z.infer; +export type ElevenlabsSpeechJob = z.infer; diff --git a/src/providers/fal/image.ts b/src/providers/fal/image.ts index eb1667b..01d1c03 100644 --- a/src/providers/fal/image.ts +++ b/src/providers/fal/image.ts @@ -1,7 +1,8 @@ -import { ImageJobBuilder } from "~/jobs/image"; import type { JobOptions } from "~/jobs/schema"; +import { ImageJobBuilder } from "~/jobs/image"; +import type { FalImageJob } from "./schema"; -export class FalImageJobBuilder extends ImageJobBuilder { +export class FalImageJobBuilder extends ImageJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "fal"; @@ -12,7 +13,7 @@ export class FalImageJobBuilder extends ImageJobBuilder { return new Request(`https://queue.fal.run/${this.input.model}`, { method: "POST", headers: { - Authorization: `Key ${this.options.apiKey}`, + Authorization: `Key ${this.options!.apiKey}`, "Content-Type": "application/json", }, body: JSON.stringify({ diff --git a/src/providers/fal/schema.ts b/src/providers/fal/schema.ts index a194d31..2db8161 100644 --- a/src/providers/fal/schema.ts +++ b/src/providers/fal/schema.ts @@ -1,19 +1,13 @@ import { z } from "zod"; import { ImageJobSchema } from "~/jobs/image"; -export type FalImage = { - url: string; - width: number; - height: number; - contentType: string; -}; - export const FalBaseJobSchema = z.object({ provider: z.literal("fal"), }); export const FalImageJobSchema = ImageJobSchema.extend(FalBaseJobSchema); -export type FalImageJob = z.infer; export const FalJobSchema = z.discriminatedUnion("type", [FalImageJobSchema]); + export type FalJob = z.infer; +export type FalImageJob = z.infer; diff --git a/src/providers/fireworks/index.ts b/src/providers/fireworks/index.ts index a83d631..4d13c23 100644 --- a/src/providers/fireworks/index.ts +++ b/src/providers/fireworks/index.ts @@ -12,7 +12,7 @@ export function fireworks(options?: JobOptions) { ...options, baseURL: "https://api.fireworks.ai/inference/v1", }, - model + model, ); }, }; diff --git a/src/providers/google/chat.ts b/src/providers/google/chat.ts index 382222b..b21291a 100644 --- a/src/providers/google/chat.ts +++ b/src/providers/google/chat.ts @@ -1,7 +1,8 @@ -import { ChatJobBuilder } from "~/jobs/chat"; +import { ChatJobBuilder, type Message } from "~/jobs/chat"; import type { JobOptions } from "~/jobs/schema"; +import type { GoogleChatJob } from "./schema"; -export class GoogleChatJobBuilder extends ChatJobBuilder { +export class GoogleChatJobBuilder extends ChatJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "google"; @@ -10,19 +11,19 @@ export class GoogleChatJobBuilder extends ChatJobBuilder { makeRequest = () => { return new Request( - `https://generativelanguage.googleapis.com/v1beta/models/${this.input.model}:generateContent?key=${this.options.apiKey}`, + `https://generativelanguage.googleapis.com/v1beta/models/${this.input.model}:generateContent?key=${this.options!.apiKey}`, { method: "POST", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ - contents: this.input.messages.map((msg) => ({ + contents: this.input.messages.map((msg: Message) => ({ role: msg.role === "user" ? "user" : "model", parts: [{ text: msg.content }], })), }), - } + }, ); }; diff --git a/src/providers/google/schema.ts b/src/providers/google/schema.ts index 53e509b..ed2a46c 100644 --- a/src/providers/google/schema.ts +++ b/src/providers/google/schema.ts @@ -6,9 +6,10 @@ export const GoogleBaseJobSchema = z.object({ }); export const GoogleChatJobSchema = ChatJobSchema.extend(GoogleBaseJobSchema); -export type GoogleChatJob = z.infer; export const GoogleJobSchema = z.discriminatedUnion("type", [ GoogleChatJobSchema, ]); + export type GoogleJob = z.infer; +export type GoogleChatJob = z.infer; diff --git a/src/providers/luma/index.ts b/src/providers/luma/index.ts index 6aa5b2f..ccb6845 100644 --- a/src/providers/luma/index.ts +++ b/src/providers/luma/index.ts @@ -9,16 +9,14 @@ export const LumaBaseJobSchema = z.object({ export const LumaImageJobSchema = ImageJobSchema.extend(LumaBaseJobSchema); export const LumaJobSchema = z.discriminatedUnion("type", [LumaImageJobSchema]); + export type LumaJob = z.infer; +export type LumaImageJob = z.infer; export function luma(options?: JobOptions) { options = options || {}; options.apiKey = options.apiKey || process.env.LUMA_API_KEY; - if (!options.apiKey) { - throw new Error("Luma API key is required"); - } - return { image(model: string) { return new LumaImageJobBuilder(options, model); @@ -26,18 +24,18 @@ export function luma(options?: JobOptions) { }; } -export class LumaImageJobBuilder extends ImageJobBuilder { +export class LumaImageJobBuilder extends ImageJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "luma"; this.options = options; } - makeRequest = () => { + makeRequest() { return new Request("https://api.lumalabs.ai/dream-machine/v1/generations", { method: "POST", headers: { - Authorization: `Bearer ${this.options.apiKey}`, + Authorization: `Bearer ${this.options!.apiKey}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -45,5 +43,11 @@ export class LumaImageJobBuilder extends ImageJobBuilder { prompt: this.input.prompt, }), }); - }; + } + + async handleResponse(response: Response) { + const raw = await response.json(); + //TODO: handle raw.images + return { raw, images: raw.images }; + } } diff --git a/src/providers/mistral/index.ts b/src/providers/mistral/index.ts new file mode 100644 index 0000000..6c22615 --- /dev/null +++ b/src/providers/mistral/index.ts @@ -0,0 +1,8 @@ +export function mistral(options?: JobOptions) { + options = options || {}; + options.apiKey = options.apiKey || process.env.MISTRAL_API_KEY; + + return { + chat(model: string) {}, + }; +} diff --git a/src/providers/ollama/chat.ts b/src/providers/ollama/chat.ts index 9e1ac74..60180cf 100644 --- a/src/providers/ollama/chat.ts +++ b/src/providers/ollama/chat.ts @@ -1,14 +1,15 @@ import { ChatJobBuilder, convertTools } from "~/jobs/chat"; import type { JobOptions } from "~/jobs/schema"; +import type { OllamaChatJob } from "./schema"; -export class OllamaChatJobBuilder extends ChatJobBuilder { +export class OllamaChatJobBuilder extends ChatJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "ollama"; this.options = options; } - makeRequest = () => { + makeRequest() { const requestBody = { model: this.input.model, messages: this.input.messages, @@ -23,10 +24,9 @@ export class OllamaChatJobBuilder extends ChatJobBuilder { method: "POST", body: JSON.stringify(requestBody), }); - }; + } - handleResponse = async (response: Response) => { - const json = await response.json(); - return json; - }; + handleResponse(response: Response) { + return response.json(); + } } diff --git a/src/providers/ollama/embedding.ts b/src/providers/ollama/embedding.ts index 4f705ad..cbaaf6c 100644 --- a/src/providers/ollama/embedding.ts +++ b/src/providers/ollama/embedding.ts @@ -1,14 +1,15 @@ import { EmbeddingJobBuilder } from "~/jobs/embedding"; import type { JobOptions } from "~/jobs/schema"; +import type { OllamaEmbeddingJob } from "./schema"; -export class OllamaEmbeddingJobBuilder extends EmbeddingJobBuilder { +export class OllamaEmbeddingJobBuilder extends EmbeddingJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "ollama"; this.options = options; } - makeRequest = () => { + makeRequest() { return new Request("http://localhost:11434/api/embed", { method: "POST", body: JSON.stringify({ @@ -16,10 +17,10 @@ export class OllamaEmbeddingJobBuilder extends EmbeddingJobBuilder { input: this.input.value, }), }); - }; + } - handleResponse = async (response: Response) => { + async handleResponse(response: Response) { const raw = await response.json(); - return { raw, embeddings: raw.embeddings }; - }; + return { raw, embedding: raw.embedding }; + } } diff --git a/src/providers/ollama/models.ts b/src/providers/ollama/models.ts index 10d3ba8..0e2f3e1 100644 --- a/src/providers/ollama/models.ts +++ b/src/providers/ollama/models.ts @@ -1,19 +1,20 @@ import type { JobOptions } from "~/jobs/schema"; import { ModelsJobBuilder } from "~/jobs/models"; +import type { OllamaModelsJob } from "./schema"; -export class OllamaModelsJobBuilder extends ModelsJobBuilder { +export class OllamaModelsJobBuilder extends ModelsJobBuilder { constructor(options: JobOptions) { super(); this.provider = "ollama"; this.options = options; } - makeRequest = () => { + makeRequest() { return new Request("http://localhost:11434/api/tags", { method: "GET" }); - }; + } - handleResponse = async (response: Response) => { + async handleResponse(response: Response) { const json = await response.json(); - return json; - }; + return { raw: json, models: [] }; + } } diff --git a/src/providers/ollama/schema.ts b/src/providers/ollama/schema.ts index 5b7ca83..d6e5913 100644 --- a/src/providers/ollama/schema.ts +++ b/src/providers/ollama/schema.ts @@ -8,19 +8,20 @@ export const OllamaBaseJobSchema = z.object({ }); export const OllamaChatJobSchema = ChatJobSchema.extend(OllamaBaseJobSchema); -export type OllamaChatJob = z.infer; export const OllamaEmbeddingJobSchema = EmbeddingJobSchema.extend(OllamaBaseJobSchema); -export type OllamaEmbeddingJob = z.infer; export const OllamaModelsJobSchema = ModelsJobSchema.extend(OllamaBaseJobSchema); -export type OllamaModelsJob = z.infer; export const OllamaJobSchema = z.discriminatedUnion("type", [ OllamaChatJobSchema, OllamaEmbeddingJobSchema, OllamaModelsJobSchema, ]); + export type OllamaJob = z.infer; +export type OllamaChatJob = z.infer; +export type OllamaEmbeddingJob = z.infer; +export type OllamaModelsJob = z.infer; diff --git a/src/providers/openai/chat.ts b/src/providers/openai/chat.ts index eb3fd0e..014fd41 100644 --- a/src/providers/openai/chat.ts +++ b/src/providers/openai/chat.ts @@ -1,10 +1,23 @@ import { z } from "zod"; -import { ChatJobBuilder, convertTools } from "~/jobs/chat"; +import { ChatJobBuilder, type ChatOutput } from "~/jobs/chat"; import type { JobOptions } from "~/jobs/schema"; -import { jobStream } from "~/jobs/stream"; -import { OPENAI_BASE_URL } from "./schema"; +import { OPENAI_BASE_URL, type OpenAIChatJob } from "./schema"; +import { EventSourceParserStream } from "eventsource-parser/stream"; +import type { + OpenAIResponse, + OpenAIStreamResponse, + OpenAIToolCall, + OpenAIDelta, + FormattedToolCall, + OpenAIChatCompletionRequest, + OpenAIMessageRequest, + OpenAIResponseFormat, + OpenAIJSONSchemaResponseFormat, + OpenAIToolDefinition, +} from "./types"; +import { convertToolChoice, convertTools } from "./utils"; -export class OpenAIChatJobBuilder extends ChatJobBuilder { +export class OpenAIChatJobBuilder extends ChatJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "openai"; @@ -12,8 +25,8 @@ export class OpenAIChatJobBuilder extends ChatJobBuilder { } makeRequest = () => { - const baseURL = this.options.baseURL || OPENAI_BASE_URL; - const messages = this.input.messages; + const baseURL = this.options!.baseURL || OPENAI_BASE_URL; + const messages = this.input.messages as unknown as OpenAIMessageRequest[]; if (this.input.system) { messages.unshift({ @@ -21,34 +34,53 @@ export class OpenAIChatJobBuilder extends ChatJobBuilder { content: this.input.system, }); } - const requestBody = { - messages: messages, + + const requestBody: OpenAIChatCompletionRequest = { + messages, model: this.input.model, - temperature: this.input.temperature, stream: this.input.stream, - response_format: this.input.responseFormat, - } as any; + }; + + if (this.input.temperature !== undefined) { + requestBody.temperature = this.input.temperature; + } + + if (this.input.maxTokens !== undefined) { + requestBody.max_tokens = this.input.maxTokens; + } + + if (this.input.topP !== undefined) { + requestBody.top_p = this.input.topP; + } if (this.input.tools && this.input.tools.length) { requestBody.tools = convertTools(this.input.tools); - requestBody.tool_choice = this.input.toolChoice; + + if (this.input.toolChoice) { + requestBody.tool_choice = convertToolChoice(this.input.toolChoice); + } + } + + if (this.input.responseFormat) { + requestBody.response_format = this.input + .responseFormat as OpenAIResponseFormat; } if (this.input.jsonSchema) { const schema = z.toJSONSchema(this.input.jsonSchema.schema); requestBody.response_format = { type: "json_schema", - json_schema: { + schema: { name: this.input.jsonSchema.name, description: this.input.jsonSchema.description, - schema: schema, + ...schema, }, - }; + } as OpenAIJSONSchemaResponseFormat; } return new Request(`${baseURL}/chat/completions`, { headers: { - Authorization: `Bearer ${this.options.apiKey}`, + Authorization: `Bearer ${this.options!.apiKey}`, "Content-Type": "application/json", }, method: "POST", @@ -56,17 +88,118 @@ export class OpenAIChatJobBuilder extends ChatJobBuilder { }); }; - handleResponse = async (response: Response) => { - if (this.input.stream) { - return jobStream(response); + async handleResponse(response: Response) { + // Handle non-streaming response + const raw = (await response.json()) as OpenAIResponse; + + // Record cost/usage information + if (raw.usage) { + this.cost = { + promptTokens: raw.usage.prompt_tokens, + completionTokens: raw.usage.completion_tokens, + totalTokens: raw.usage.total_tokens, + }; } - const raw = await response.json(); - this.cost = { - promptTokens: raw.usage.prompt_tokens, - completionTokens: raw.usage.completion_tokens, - totalTokens: raw.usage.total_tokens, + // Format the response to match ChatOutput schema + const choice = raw.choices?.[0]; + const message = choice?.message; + + const output: ChatOutput = { + raw, }; - return { raw }; - }; + + if (message) { + output.message = { + role: "assistant", + content: message.content ?? "", + }; + + // Handle tool calls if present + if (message.tool_calls && message.tool_calls.length > 0) { + output.message.tool_calls = message.tool_calls.map( + (call: OpenAIToolCall): FormattedToolCall => { + const result: FormattedToolCall = { + id: call.id, + type: call.type, + name: call.function?.name || "", + arguments: {} as Record, + }; + + if (call.function) { + // Safely handle arguments + if (call.function.arguments) { + try { + result.arguments = JSON.parse(call.function.arguments); + } catch (e) { + // If parsing fails, return as string + result.arguments = { _raw: call.function.arguments }; + } + } + } + + return result; + }, + ); + } + } + + return output; + } + + async *handleStream(response: Response) { + const eventStream = response + .body!.pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()); + const reader = eventStream.getReader(); + for (;;) { + const { done, value } = await reader.read(); + if (done || value.data === "[DONE]") { + break; + } + + const chunk = JSON.parse(value.data) as OpenAIStreamResponse; + const delta = chunk.choices?.[0]?.delta; + + // Create a properly formatted ChatOutput object + const output: ChatOutput = { raw: chunk }; + + if (delta) { + output.message = { + role: "assistant", + content: delta.content || "", + }; + + // Handle tool calls in stream + if (delta.tool_calls && delta.tool_calls.length > 0) { + output.message.tool_calls = delta.tool_calls.map( + (call): FormattedToolCall => { + const result: FormattedToolCall = { + id: call.id || "", + type: call.type || "", + name: call.function?.name || "", + arguments: {} as Record, + }; + + if (call.function) { + // Safely handle arguments which might be partial JSON in streaming + if (call.function.arguments) { + try { + result.arguments = JSON.parse(call.function.arguments); + } catch (e) { + // If parsing fails (for partial JSON), return as string + result.arguments = { _raw: call.function.arguments }; + } + } + } + + return result; + }, + ); + } + } + + yield output; + } + } } diff --git a/src/providers/openai/embedding.ts b/src/providers/openai/embedding.ts index 59a24b8..cfe783f 100644 --- a/src/providers/openai/embedding.ts +++ b/src/providers/openai/embedding.ts @@ -1,8 +1,9 @@ import { EmbeddingJobBuilder } from "~/jobs/embedding"; import type { JobOptions } from "~/jobs/schema"; import { OPENAI_BASE_URL } from "./schema"; +import type { OpenAIEmbeddingJob } from "./schema"; -export class OpenAIEmbeddingJobBuilder extends EmbeddingJobBuilder { +export class OpenAIEmbeddingJobBuilder extends EmbeddingJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "openai"; @@ -10,10 +11,10 @@ export class OpenAIEmbeddingJobBuilder extends EmbeddingJobBuilder { } makeRequest = () => { - const baseURL = this.options.baseURL || OPENAI_BASE_URL; + const baseURL = this.options!.baseURL || OPENAI_BASE_URL; return new Request(`${baseURL}/embeddings`, { headers: { - Authorization: `Bearer ${this.options.apiKey}`, + Authorization: `Bearer ${this.options!.apiKey}`, "Content-Type": "application/json", }, method: "POST", diff --git a/src/providers/openai/image.ts b/src/providers/openai/image.ts index 358321e..88f8bd0 100644 --- a/src/providers/openai/image.ts +++ b/src/providers/openai/image.ts @@ -1,32 +1,92 @@ -import { ImageJobBuilder } from "~/jobs/image"; import type { JobOptions } from "~/jobs/schema"; -import { OPENAI_BASE_URL } from "./schema"; +import { ImageJobBuilder } from "~/jobs/image"; +import { OPENAI_BASE_URL, type OpenAIImageJob } from "./schema"; -export class OpenAIImageJobBuilder extends ImageJobBuilder { +export class OpenAIImageJobBuilder extends ImageJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "openai"; this.options = options; } - makeRequest = () => { - const baseURL = this.options.baseURL || OPENAI_BASE_URL; - return new Request(`${baseURL}/image/generations`, { + makeRequest() { + const baseURL = this.options!.baseURL || OPENAI_BASE_URL; + + if (this.input.images && this.input.images.length > 0) { + return this.makeEditRequest(baseURL); + } + + const url = `${baseURL}/images/generations`; + const body = { + prompt: this.input.prompt, + model: this.input.model, + n: this.input.n, + quality: this.input.quality, + output_format: this.input.outputFormat, + size: this.input.size, + style: this.input.style, + user: this.input.user, + }; + + return new Request(url, { headers: { - Authorization: `Bearer ${this.options.apiKey}`, + Authorization: `Bearer ${this.options!.apiKey}`, "Content-Type": "application/json", }, method: "POST", - body: JSON.stringify({ - prompt: this.input.prompt, - model: this.input.model, - n: this.input.n, - quality: this.input.quality, - response_format: this.input.responseFormat, - size: this.input.size, - style: this.input.style, - user: this.input.user, - }), + body: JSON.stringify(body), + }); + } + + makeEditRequest(baseURL: string) { + const url = `${baseURL}/images/edits`; + + const formData = new FormData(); + formData.append("prompt", this.input.prompt || ""); + formData.append("model", this.input.model); + + for (const image of this.input.images!) { + formData.append("image[]", image, image.name); + } + + if (this.input.mask) { + formData.append("mask", this.input.mask, "mask.png"); + } + + if (this.input.quality) { + formData.append("quality", String(this.input.quality)); + } + + if (this.input.n) { + formData.append("n", String(this.input.n)); + } + if (this.input.size) { + formData.append("size", String(this.input.size)); + } + if (this.input.responseFormat) { + formData.append("response_format", this.input.responseFormat); + } + if (this.input.user) { + formData.append("user", this.input.user); + } + + return new Request(url, { + headers: { + Authorization: `Bearer ${this.options!.apiKey}`, + }, + method: "POST", + body: formData, }); - }; + } + + async handleResponse(response: Response) { + const raw = await response.json(); + // TODO: handle raw.images + return { + raw, + images: raw.data.map((image: any) => + image.url ? { url: image.url } : { base64: image.b64_json }, + ), + }; + } } diff --git a/src/providers/openai/index.ts b/src/providers/openai/index.ts index b566148..fdebc0f 100644 --- a/src/providers/openai/index.ts +++ b/src/providers/openai/index.ts @@ -3,15 +3,12 @@ import { OpenAIChatJobBuilder } from "~/providers/openai/chat"; import { OpenAIImageJobBuilder } from "~/providers/openai/image"; import { OpenAIEmbeddingJobBuilder } from "~/providers/openai/embedding"; import { OpenAIModelsJobBuilder } from "~/providers/openai/models"; +import { OpenAISpeechJobBuilder } from "~/providers/openai/speech"; export function openai(options?: JobOptions) { options = options || {}; options.apiKey = options.apiKey || process.env.OPENAI_API_KEY; - if (!options.apiKey) { - throw new Error("OpenAI API key is required"); - } - return { chat(model: string) { return new OpenAIChatJobBuilder(options, model); @@ -25,6 +22,9 @@ export function openai(options?: JobOptions) { models() { return new OpenAIModelsJobBuilder(options); }, + speech(model: string) { + return new OpenAISpeechJobBuilder(options, model); + }, }; } diff --git a/src/providers/openai/models.ts b/src/providers/openai/models.ts index 13e3fd7..707f5a2 100644 --- a/src/providers/openai/models.ts +++ b/src/providers/openai/models.ts @@ -1,26 +1,41 @@ import type { JobOptions } from "~/jobs/schema"; import { ModelsJobBuilder } from "~/jobs/models"; -import { OPENAI_BASE_URL } from "./schema"; +import { OPENAI_BASE_URL, type OpenAIModelsJob } from "./schema"; -export class OpenAIModelsJobBuilder extends ModelsJobBuilder { +export class OpenAIModelsJobBuilder extends ModelsJobBuilder { constructor(options: JobOptions) { super(); this.provider = "openai"; this.options = options; } - makeRequest = () => { - const baseURL = this.options.baseURL || OPENAI_BASE_URL; + makeRequest() { + const baseURL = this.options!.baseURL || OPENAI_BASE_URL; return new Request(`${baseURL}/models`, { headers: { - Authorization: `Bearer ${this.options.apiKey}`, + Authorization: `Bearer ${this.options!.apiKey}`, "Content-Type": "application/json", }, method: "GET", }); - }; + } - handleResponse = async (response: Response) => { - return await response.json(); - }; + async handleResponse(response: Response) { + const raw: { + data: { + id: string; + object: string; + created: number; + owned_by: string; + }[]; + } = await response.json(); + return { + raw: raw, + models: raw.data.map((model) => ({ + id: model.id, + created: model.created, + owned_by: model.owned_by, + })), + }; + } } diff --git a/src/providers/openai/schema.ts b/src/providers/openai/schema.ts index 957abf8..204f3e9 100644 --- a/src/providers/openai/schema.ts +++ b/src/providers/openai/schema.ts @@ -3,31 +3,38 @@ import { ChatJobSchema } from "~/jobs/chat"; import { EmbeddingJobSchema } from "~/jobs/embedding"; import { ImageJobSchema } from "~/jobs/image"; import { ModelsJobSchema } from "~/jobs/models"; +import { SpeechJobSchema } from "~/jobs/speech"; export const OPENAI_BASE_URL = "https://api.openai.com/v1"; -export const BaseOpenAIJobSchema = z.object({ +export const OpenAIBaseJobSchema = z.object({ provider: z.literal("openai"), }); -export const OpenAIChatJobSchema = ChatJobSchema.extend(BaseOpenAIJobSchema); -export type OpenAIChatJob = z.infer; +export const OpenAIChatJobSchema = ChatJobSchema.extend(OpenAIBaseJobSchema); export const OpenAIEmbeddingJobSchema = - EmbeddingJobSchema.extend(BaseOpenAIJobSchema); -export type OpenAIEmbeddingJob = z.infer; + EmbeddingJobSchema.extend(OpenAIBaseJobSchema); -export const OpenAIImageJobSchema = ImageJobSchema.extend(BaseOpenAIJobSchema); -export type OpenAIImageJob = z.infer; +export const OpenAIImageJobSchema = ImageJobSchema.extend(OpenAIBaseJobSchema); export const OpenAIModelsJobSchema = - ModelsJobSchema.extend(BaseOpenAIJobSchema); -export type OpenAIModelsJob = z.infer; + ModelsJobSchema.extend(OpenAIBaseJobSchema); + +export const OpenAISpeechJobSchema = + SpeechJobSchema.extend(OpenAIBaseJobSchema); export const OpenAIJobSchema = z.discriminatedUnion("type", [ OpenAIChatJobSchema, OpenAIEmbeddingJobSchema, OpenAIImageJobSchema, OpenAIModelsJobSchema, + OpenAISpeechJobSchema, ]); + export type OpenAIJob = z.infer; +export type OpenAIChatJob = z.infer; +export type OpenAIEmbeddingJob = z.infer; +export type OpenAIImageJob = z.infer; +export type OpenAIModelsJob = z.infer; +export type OpenAISpeechJob = z.infer; diff --git a/src/providers/openai/speech.ts b/src/providers/openai/speech.ts new file mode 100644 index 0000000..fa6670d --- /dev/null +++ b/src/providers/openai/speech.ts @@ -0,0 +1,20 @@ +import { type JobOptions } from "~/jobs/schema"; +import { SpeechJobBuilder } from "~/jobs/speech"; +import type { OpenAISpeechJob } from "./schema"; + +export class OpenAISpeechJobBuilder extends SpeechJobBuilder { + constructor(options: JobOptions, model: string) { + super(model); + this.provider = "openai"; + this.options = options; + } + + makeRequest(): Request { + return new Request("https://api.openai.com/v1", {}); + } + + async handleResponse(response: Response) { + const raw = await response.json(); + return { raw }; + } +} diff --git a/src/providers/openai/types.ts b/src/providers/openai/types.ts new file mode 100644 index 0000000..a1f0bb8 --- /dev/null +++ b/src/providers/openai/types.ts @@ -0,0 +1,184 @@ +/** + * Type definitions for OpenAI API responses + * Based on https://platform.openai.com/docs/api-reference/chat/create + */ + +// Common types +export interface OpenAIFunctionCall { + name: string; + arguments: string; +} + +export interface OpenAIToolCall { + id: string; + type: string; + function: OpenAIFunctionCall; +} + +// Adding our internal format for tool calls that matches what we send to clients +export interface FormattedToolCall { + id: string; + type: string; + name: string; + arguments: Record; +} + +export interface OpenAIMessage { + role: string; + content: string | null; + tool_calls?: OpenAIToolCall[]; +} + +export interface OpenAIChoice { + index: number; + message: OpenAIMessage; + logprobs: null | { + content: Array<{ + token: string; + logprob: number; + top_logprobs: Record; + }>; + }; + finish_reason: string; +} + +export interface OpenAIUsage { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; +} + +// Non-streaming response +export interface OpenAIResponse { + id: string; + object: string; + created: number; + model: string; + choices: OpenAIChoice[]; + usage: OpenAIUsage; + system_fingerprint: string; +} + +// Streaming response types +export interface OpenAIDelta { + role?: string; + content?: string; + tool_calls?: Array<{ + id?: string; + type?: string; + index?: number; + function?: { + name?: string; + arguments?: string; + }; + }>; +} + +export interface OpenAIStreamChoice { + index: number; + delta: OpenAIDelta; + logprobs: null | { + content: Array<{ + token: string; + logprob: number; + top_logprobs: Record; + }>; + }; + finish_reason: string | null; +} + +export interface OpenAIStreamResponse { + id: string; + object: string; + created: number; + model: string; + choices: OpenAIStreamChoice[]; + system_fingerprint: string; +} + +// Request types + +// Text or image content +export type OpenAIMessageContent = + | { type: "text"; text: string } + | { + type: "image_url"; + image_url: { + url: string; + detail?: "auto" | "low" | "high"; + } + } + | { + type: "image"; + image_url?: string; + source?: { + type: "base64"; + data: string; + media_type: string; + }; + }; + +// Message request with flexible content type +export interface OpenAIMessageRequest { + role: "system" | "user" | "assistant" | "tool"; + content: string | null | Array; + name?: string; + tool_calls?: Array<{ + type?: string; + name?: string; + id?: string; + call_id?: string; + arguments: Record; + }>; + tool_call_id?: string; +} + +export interface OpenAIFunctionDefinition { + name: string; + description?: string; + parameters: Record; +} + +export interface OpenAIToolDefinition { + type: "function"; + function: OpenAIFunctionDefinition; +} + +export type OpenAIToolChoice = + | "none" + | "auto" + | { type: "function"; function: { name: string } }; + +export interface OpenAIResponseFormat { + type: "text" | "json_object"; +} + +export interface OpenAIJSONSchemaResponseFormat { + type: "json_schema"; + schema: Record; +} + +export interface OpenAILogprobs { + top_logprobs: number; +} + +export interface OpenAIChatCompletionRequest { + model: string; + messages: OpenAIMessageRequest[]; + frequency_penalty?: number; + logit_bias?: Record; + logprobs?: boolean; + top_logprobs?: number; + max_tokens?: number; + n?: number; + presence_penalty?: number; + response_format?: OpenAIResponseFormat | OpenAIJSONSchemaResponseFormat; + seed?: number; + stop?: string | string[]; + stream?: boolean; + temperature?: number; + top_p?: number; + tools?: OpenAIToolDefinition[]; + tool_choice?: OpenAIToolChoice; + user?: string; +} \ No newline at end of file diff --git a/src/providers/openai/utils.ts b/src/providers/openai/utils.ts new file mode 100644 index 0000000..0972ca3 --- /dev/null +++ b/src/providers/openai/utils.ts @@ -0,0 +1,37 @@ +import z from "zod"; +import type { ChatToolChoiceSchema, ChatToolSchema } from "~/jobs/chat"; +import type { OpenAIToolDefinition, OpenAIToolChoice } from "./types"; +export function convertTools(tools: z.infer[]): OpenAIToolDefinition[] { + return tools.map((tool) => ({ + type: "function", + function: { + name: tool.name, + description: tool.description, + parameters: z.toJSONSchema(tool.parameters), + }, + })); +} + + +export function convertToolChoice(toolChoice: z.infer): OpenAIToolChoice { + if (!toolChoice) return "auto"; + + switch (toolChoice.mode) { + case "none": + return "none"; + case "auto": + return "auto"; + case "any": + // If allowed_tools is specified and has exactly one tool, specify that tool + if (toolChoice.allowed_tools && toolChoice.allowed_tools.length === 1) { + return { + type: "function", + function: { name: toolChoice.allowed_tools[0] } + }; + } + // Otherwise, use auto + return "auto"; + default: + return "auto"; + } +} diff --git a/src/providers/together/index.ts b/src/providers/together/index.ts index 3f1def0..aa6139c 100644 --- a/src/providers/together/index.ts +++ b/src/providers/together/index.ts @@ -12,7 +12,7 @@ export function together(options?: JobOptions) { ...options, baseURL: "https://api.together.xyz/v1", }, - model + model, ); }, }; diff --git a/src/providers/voyage/index.ts b/src/providers/voyage/index.ts index ec04faa..d59565d 100644 --- a/src/providers/voyage/index.ts +++ b/src/providers/voyage/index.ts @@ -14,6 +14,7 @@ export const VoyageJobSchema = z.discriminatedUnion("type", [ ]); export type VoyageJob = z.infer; +export type VoyageEmbeddingJob = z.infer; export function voyage(options?: JobOptions) { options = options || {}; @@ -26,7 +27,7 @@ export function voyage(options?: JobOptions) { }; } -export class VoyageEmbeddingJobBuilder extends EmbeddingJobBuilder { +export class VoyageEmbeddingJobBuilder extends EmbeddingJobBuilder { constructor(options: JobOptions, model: string) { super(model); this.provider = "voyage"; @@ -37,7 +38,7 @@ export class VoyageEmbeddingJobBuilder extends EmbeddingJobBuilder { return new Request("https://api.voyageai.com/v1/embeddings", { method: "POST", headers: { - Authorization: `Bearer ${this.options.apiKey}`, + Authorization: `Bearer ${this.options!.apiKey}`, "Content-Type": "application/json", }, body: JSON.stringify({ @@ -47,8 +48,9 @@ export class VoyageEmbeddingJobBuilder extends EmbeddingJobBuilder { }); }; - handleResponse = async (response: Response) => { + async handleResponse(response: Response) { const json = await response.json(); - return json; - }; + //TODO: handle raw.embedding + return { raw: json, embedding: json.embedding }; + } } diff --git a/src/workflow/index.ts b/src/workflow/index.ts new file mode 100644 index 0000000..c0c8c54 --- /dev/null +++ b/src/workflow/index.ts @@ -0,0 +1,28 @@ +import { z } from "zod"; + +export function workflow(name: string) { + return new Workflow(name); +} + +class Workflow { + name: string; + _input?: z.ZodType; + steps: any[] = []; + + constructor(name: string) { + this.name = name; + } + + input(inputSchema: z.ZodType) { + this._input = inputSchema; + return this; + } + + async run() {} + + step(name: string, fn: any) { + return this; + } + + dump() {} +} diff --git a/test/__snapshots__/speech.test.ts.snap b/test/__snapshots__/speech.test.ts.snap new file mode 100644 index 0000000..c24d6e4 --- /dev/null +++ b/test/__snapshots__/speech.test.ts.snap @@ -0,0 +1,53 @@ +// Bun Snapshot v1, https://goo.gl/fbAQLP + +exports[`dump 1`] = ` +{ + "cost": undefined, + "input": { + "model": "tts-1", + }, + "options": { + "apiKey": undefined, + }, + "output": undefined, + "performance": undefined, + "provider": "openai", + "type": "speech", + "version": "0.3.0", +} +`; + +exports[`dump 2`] = ` +{ + "cost": undefined, + "input": { + "model": "model", + }, + "options": { + "apiKey": undefined, + }, + "output": undefined, + "performance": undefined, + "provider": "elevenlabs", + "type": "speech", + "version": "0.3.0", +} +`; + +exports[`speech 1`] = ` +{ + "body": "", + "headers": {}, + "method": "GET", + "url": "https://api.openai.com/v1", +} +`; + +exports[`speech 2`] = ` +{ + "body": "", + "headers": {}, + "method": "GET", + "url": "https://api.elevenlabs.io/v1/text-to-speech", +} +`; diff --git a/test/__snapshots__/workflow.test.ts.snap b/test/__snapshots__/workflow.test.ts.snap new file mode 100644 index 0000000..f91134b --- /dev/null +++ b/test/__snapshots__/workflow.test.ts.snap @@ -0,0 +1,3 @@ +// Bun Snapshot v1, https://goo.gl/fbAQLP + +exports[`workflow 1`] = `undefined`; diff --git a/test/chat.test.ts b/test/chat.test.ts index 2719be1..1609570 100644 --- a/test/chat.test.ts +++ b/test/chat.test.ts @@ -18,8 +18,8 @@ test("chat", async () => { job .messages([system("you are a helpful assistant"), user("hi")]) .temperature(0.5) - .makeRequest() - ) + .makeRequest(), + ), ).toMatchSnapshot(); } }); @@ -30,9 +30,8 @@ test("stream", async () => { await requestObject( job .messages([system("you are a helpful assistant"), user("hi")]) - .stream() - .makeRequest() - ) + .makeRequest(), + ), ).toMatchSnapshot(); } }); @@ -58,8 +57,8 @@ test("json_object", async () => { .chat("gpt-4o-mini") .messages([user("hi")]) .responseFormat({ type: "json_object" }) - .makeRequest() - ) + .makeRequest(), + ), ).toMatchSnapshot(); }); @@ -70,7 +69,7 @@ test("tool", async () => { z.object({ location: z.string(), unit: z.enum(["celsius", "fahrenheit"]).optional(), - }) + }), ); for (const job of createJobs()) { @@ -79,8 +78,8 @@ test("tool", async () => { job .tool(weatherTool) .messages([user("What's the weather like in Boston today?")]) - .makeRequest() - ) + .makeRequest(), + ), ).toMatchSnapshot(); } }); @@ -99,8 +98,8 @@ test("jsonSchema", async () => { user("generate a person with name and age in json format"), ]) .jsonSchema(personSchema, "person") - .makeRequest() - ) + .makeRequest(), + ), ).toMatchSnapshot(); } }); diff --git a/test/speech.test.ts b/test/speech.test.ts new file mode 100644 index 0000000..77eea66 --- /dev/null +++ b/test/speech.test.ts @@ -0,0 +1,31 @@ +import { test, expect } from "bun:test"; +import { load, elevenlabs, openai } from "../src"; +import { requestObject } from "./utils"; + +function createJobs() { + // prettier-ignore + return [ + openai().speech('tts-1'), + elevenlabs().speech("model"), + ] +} + +test("speech", async () => { + for (const job of createJobs()) { + expect(await requestObject(job.makeRequest())).toMatchSnapshot(); + } +}); + +test("dump", () => { + for (const job of createJobs()) { + expect(job.dump()).toMatchSnapshot(); + } +}); + +test("load", async () => { + for (const job of createJobs()) { + const req1 = await requestObject(load(job.dump()).makeRequest!()); + const req2 = await requestObject(job.makeRequest()); + expect(req1).toEqual(req2); + } +}); diff --git a/test/workflow.test.ts b/test/workflow.test.ts new file mode 100644 index 0000000..b41f357 --- /dev/null +++ b/test/workflow.test.ts @@ -0,0 +1,33 @@ +import { test, expect } from "bun:test"; +import { z } from "zod"; +import { elevenlabs, load, openai, workflow } from "../src"; + +test("workflow", () => { + const flow1 = workflow("workflow1") + .input( + z.object({ + description: z.string(), + }), + ) + .step("step1", ({ context }) => { + return openai() + .chat("gpt-4o-mini") + .prompt( + `generate a story based on following description: ${context.input.description}`, + ) + .jsonSchema( + z.object({ + story: z.string(), + }), + ); + }) + .step("step2", ({ context }) => { + return elevenlabs() + .tts("eleven_multilingual_v2") + .text(context.steps.step1.story); + }); + + const payload = flow1.dump(); + expect(payload).toMatchSnapshot(); + // expect(load(payload)).toEqual(flow1); +});