Prompt Caching #234

Open

tpaulshippy wants to merge 22 commits into main from prompt-caching

Commits (22)
2e84006
13: Failing specs
tpaulshippy Jun 9, 2025
be61e48
13: Get caching specs passing for Bedrock
tpaulshippy Jun 9, 2025
edec138
13: Remove comments in specs
tpaulshippy Jun 9, 2025
971f176
13: Add unused param on other providers
tpaulshippy Jun 9, 2025
557a5ee
13: Rubocop -A
tpaulshippy Jun 9, 2025
9673b13
13: Add cassettes for bedrock cache specs
tpaulshippy Jun 9, 2025
c47d270
13: Resolve Rubocop aside from Metrics/ParameterLists
tpaulshippy Jun 9, 2025
eaf0876
13: Use large enough prompt to hit cache meaningfully
tpaulshippy Jun 9, 2025
160d9ab
13: Ensure cache tokens are being used
tpaulshippy Jun 9, 2025
d1698bf
13: Refactor completion parameters
tpaulshippy Jun 9, 2025
344729f
16: Add guide for prompt caching
tpaulshippy Jun 9, 2025
7b98277
Add real anthropic cassettes ($0.03)
tpaulshippy Jun 12, 2025
fd30f14
Merge branch 'main' into prompt-caching
tpaulshippy Jun 12, 2025
a91d07e
Switch from large_prompt.txt to 10,000 of the letter a
tpaulshippy Jul 19, 2025
f40f37d
Make that 2048 * 4 (2048 tokens for Haiku)
tpaulshippy Jul 19, 2025
109bb51
Rename properties on message class
tpaulshippy Jul 19, 2025
1c6cbf7
Revert "13: Refactor completion parameters"
tpaulshippy Jul 19, 2025
4d78a09
Address rubocop
tpaulshippy Jul 19, 2025
25b3660
Merge remote-tracking branch 'origin/main' into prompt-caching
tpaulshippy Jul 19, 2025
8e80f08
Update docs
tpaulshippy Jul 19, 2025
d42d074
Actually return the payload
tpaulshippy Jul 19, 2025
97b1ace
Add support for cache token counts in gemini and openai
tpaulshippy Jul 19, 2025
24 changes: 24 additions & 0 deletions docs/guides/chat.md
@@ -278,6 +278,30 @@ puts "Total Conversation Tokens: #{total_conversation_tokens}"

Refer to the [Working with Models Guide]({% link guides/models.md %}) for details on accessing model-specific pricing.

## Prompt Caching

### Enabling
For Anthropic models, you can opt in to prompt caching, which is documented more fully in the [Anthropic API docs](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).

Enable prompt caching using the `cache_prompts` method on your chat instance:

```ruby
chat = RubyLLM.chat(model: 'claude-3-5-haiku-20241022')

# Enable caching for different types of content
chat.cache_prompts(
system: true, # Cache system instructions
user: true, # Cache user messages
tools: true # Cache tool definitions
)
```

### Checking cached token counts
For Anthropic, OpenAI, and Gemini, you can see the number of tokens read from the cache by checking the `cached_tokens` property on output messages.

For Anthropic, the `cache_creation_tokens` property additionally reports the number of tokens written to the cache.
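
For example, a minimal sketch (it assumes the prompt is large enough to qualify for caching; Anthropic requires roughly 2048 tokens for Haiku):

```ruby
chat = RubyLLM.chat(model: 'claude-3-5-haiku-20241022')
chat.cache_prompts(system: true, user: true)

# Prompts below the model's minimum cacheable size are not written to the
# cache, so pad the prompt well past that threshold.
response = chat.ask("#{'a' * (2048 * 4)} Please summarize the text above.")

puts "Tokens written to cache: #{response.cache_creation_tokens}"
puts "Tokens read from cache:  #{response.cached_tokens}"
```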


## Chat Event Handlers

You can register blocks to be called when certain events occur during the chat lifecycle, useful for UI updates or logging.
7 changes: 7 additions & 0 deletions lib/ruby_llm/chat.rb
@@ -25,6 +25,7 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
@temperature = 0.7
@messages = []
@tools = {}
@cache_prompts = { system: false, user: false, tools: false }
@on = {
new_message: nil,
end_message: nil
@@ -92,12 +93,18 @@ def each(&)
messages.each(&)
end

def cache_prompts(system: false, user: false, tools: false)
@cache_prompts = { system: system, user: user, tools: tools }
self
end

def complete(&)
response = @provider.complete(
messages,
tools: @tools,
temperature: @temperature,
model: @model.id,
cache_prompts: @cache_prompts.dup,
connection: @connection,
&wrap_streaming_block(&)
)
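
Note that `cache_prompts` returns `self`, so it can be chained when building a chat, and `complete` forwards a copy of the settings to the provider. A brief usage sketch (model name illustrative):

```ruby
chat = RubyLLM.chat(model: 'claude-3-5-haiku-20241022').cache_prompts(system: true, tools: true)

# Subsequent completions now pass
# cache_prompts: { system: true, user: false, tools: true }
# through to the provider's render_payload.
chat.ask('What is prompt caching?')
```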
9 changes: 7 additions & 2 deletions lib/ruby_llm/message.rb
@@ -7,7 +7,8 @@ module RubyLLM
class Message
ROLES = %i[system user assistant tool].freeze

attr_reader :role, :tool_calls, :tool_call_id, :input_tokens, :output_tokens, :model_id
attr_reader :role, :tool_calls, :tool_call_id, :input_tokens, :output_tokens, :model_id,
:cached_tokens, :cache_creation_tokens

def initialize(options = {})
@role = options.fetch(:role).to_sym
@@ -17,6 +18,8 @@ def initialize(options = {})
@output_tokens = options[:output_tokens]
@model_id = options[:model_id]
@tool_call_id = options[:tool_call_id]
@cached_tokens = options[:cached_tokens]
@cache_creation_tokens = options[:cache_creation_tokens]

ensure_valid_role
end
@@ -49,7 +52,9 @@ def to_h
tool_call_id: tool_call_id,
input_tokens: input_tokens,
output_tokens: output_tokens,
model_id: model_id
model_id: model_id,
cache_creation_tokens: cache_creation_tokens,
cached_tokens: cached_tokens
}.compact
end

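
Because `to_h` ends with `.compact`, the new keys only appear when the provider actually reported cache usage. A quick sketch (values illustrative):

```ruby
msg = RubyLLM::Message.new(role: :assistant, content: 'Hi there', cached_tokens: 128)

msg.cached_tokens                       # => 128
msg.to_h[:cached_tokens]                # => 128
msg.to_h.key?(:cache_creation_tokens)   # => false, nil values are compacted away
```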
4 changes: 3 additions & 1 deletion lib/ruby_llm/provider.rb
@@ -10,13 +10,15 @@ module Provider
module Methods
extend Streaming

def complete(messages, tools:, temperature:, model:, connection:, &)
def complete(messages, tools:, temperature:, model:, connection:, # rubocop:disable Metrics/ParameterLists
cache_prompts: { system: false, user: false, tools: false }, &)
normalized_temperature = maybe_normalize_temperature(temperature, model)

payload = render_payload(messages,
tools: tools,
temperature: normalized_temperature,
model: model,
cache_prompts: cache_prompts,
stream: block_given?)

if block_given?
56 changes: 35 additions & 21 deletions lib/ruby_llm/providers/anthropic/chat.rb
@@ -11,42 +11,50 @@ def completion_url
'/v1/messages'
end

def render_payload(messages, tools:, temperature:, model:, stream: false)
def render_payload(messages, tools:, temperature:, model:, stream: false, # rubocop:disable Metrics/ParameterLists
cache_prompts: { system: false, user: false, tools: false })
system_messages, chat_messages = separate_messages(messages)
system_content = build_system_content(system_messages)
system_content = build_system_content(system_messages, cache: cache_prompts[:system])

build_base_payload(chat_messages, temperature, model, stream).tap do |payload|
add_optional_fields(payload, system_content:, tools:)
build_base_payload(chat_messages, temperature, model, stream, cache: cache_prompts[:user]).tap do |payload|
add_optional_fields(payload, system_content: system_content, tools: tools,
cache_tools: cache_prompts[:tools])
end
end

def separate_messages(messages)
messages.partition { |msg| msg.role == :system }
end

def build_system_content(system_messages)
if system_messages.length > 1
RubyLLM.logger.warn(
"Anthropic's Claude implementation only supports a single system message. " \
'Multiple system messages will be combined into one.'
)
def build_system_content(system_messages, cache: false)
system_messages.flat_map.with_index do |msg, idx|
cache = false unless idx == system_messages.size - 1
format_system_message(msg, cache:)
end

system_messages.map { |msg| format_message(msg)[:content] }.join("\n\n")
end

def build_base_payload(chat_messages, temperature, model, stream)
def build_base_payload(chat_messages, temperature, model, stream, cache: false)
messages = chat_messages.map.with_index do |msg, idx|
cache = false unless idx == chat_messages.size - 1
format_message(msg, cache:)
end

{
model: model,
messages: chat_messages.map { |msg| format_message(msg) },
messages: messages,
temperature: temperature,
stream: stream,
max_tokens: RubyLLM.models.find(model)&.max_tokens || 4096
}
end

def add_optional_fields(payload, system_content:, tools:)
payload[:tools] = tools.values.map { |t| Tools.function_for(t) } if tools.any?
def add_optional_fields(payload, system_content:, tools:, cache_tools: false)
if tools.any?
tool_definitions = tools.values.map { |t| Tools.function_for(t) }
tool_definitions[-1][:cache_control] = { type: 'ephemeral' } if cache_tools
payload[:tools] = tool_definitions
end

payload[:system] = system_content unless system_content.empty?
end

@@ -72,24 +80,30 @@ def build_message(data, content, tool_use)
tool_calls: Tools.parse_tool_calls(tool_use),
input_tokens: data.dig('usage', 'input_tokens'),
output_tokens: data.dig('usage', 'output_tokens'),
model_id: data['model']
model_id: data['model'],
cache_creation_tokens: data.dig('usage', 'cache_creation_input_tokens'),
cached_tokens: data.dig('usage', 'cache_read_input_tokens')
)
end

def format_message(msg)
def format_message(msg, cache: false)
if msg.tool_call?
Tools.format_tool_call(msg)
elsif msg.tool_result?
Tools.format_tool_result(msg)
else
format_basic_message(msg)
format_basic_message(msg, cache:)
end
end

def format_basic_message(msg)
def format_system_message(msg, cache: false)
Media.format_content(msg.content, cache:)
end

def format_basic_message(msg, cache: false)
{
role: convert_role(msg.role),
content: Media.format_content(msg.content)
content: Media.format_content(msg.content, cache:)
}
end

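
For a request with a single system message, a single user message, and `cache_prompts(system: true, user: true)`, the rendered payload comes out shaped roughly like this (an illustrative sketch, not captured API output):

```ruby
{
  model: 'claude-3-5-haiku-20241022',
  messages: [
    { role: 'user',
      content: [{ type: 'text', text: 'A long user prompt...',
                  cache_control: { type: 'ephemeral' } }] }
  ],
  temperature: 0.7,
  stream: false,
  max_tokens: 4096, # RubyLLM.models.find(model)&.max_tokens || 4096
  system: [
    { type: 'text', text: 'You are a helpful assistant.',
      cache_control: { type: 'ephemeral' } }
  ]
}
```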
114 changes: 69 additions & 45 deletions lib/ruby_llm/providers/anthropic/media.rb
@@ -7,11 +7,11 @@ module Anthropic
module Media
module_function

def format_content(content)
return [format_text(content)] unless content.is_a?(Content)
def format_content(content, cache: false)
return [format_text(content, cache:)] unless content.is_a?(Content)

parts = []
parts << format_text(content.text) if content.text
parts << format_text(content.text, cache:) if content.text

content.attachments.each do |attachment|
case attachment.type
@@ -29,60 +29,84 @@ def format_content(content)
parts
end

def format_text(text)
{
type: 'text',
text: text
}
def format_text(text, cache: false)
with_cache_control(
{
type: 'text',
text: text
},
cache:
)
end

def format_image(image)
def format_image(image, cache: false)
if image.url?
{
type: 'image',
source: {
type: 'url',
url: image.source
}
}
with_cache_control(
{
type: 'image',
source: {
type: 'url',
url: image.source
}
},
cache:
)
else
{
type: 'image',
source: {
type: 'base64',
media_type: image.mime_type,
data: image.encoded
}
}
with_cache_control(
{
type: 'image',
source: {
type: 'base64',
media_type: image.mime_type,
data: image.encoded
}
},
cache:
)
end
end

def format_pdf(pdf)
def format_pdf(pdf, cache: false)
if pdf.url?
{
type: 'document',
source: {
type: 'url',
url: pdf.source
}
}
with_cache_control(
{
type: 'document',
source: {
type: 'url',
url: pdf.source
}
},
cache:
)
else
{
type: 'document',
source: {
type: 'base64',
media_type: pdf.mime_type,
data: pdf.encoded
}
}
with_cache_control(
{
type: 'document',
source: {
type: 'base64',
media_type: pdf.mime_type,
data: pdf.encoded
}
},
cache:
)
end
end

def format_text_file(text_file)
{
type: 'text',
text: Utils.format_text_file_for_llm(text_file)
}
def format_text_file(text_file, cache: false)
with_cache_control(
{
type: 'text',
text: Utils.format_text_file_for_llm(text_file)
},
cache:
)
end

def with_cache_control(hash, cache: false)
return hash unless cache

hash.merge(cache_control: { type: 'ephemeral' })
Review comment from the PR author:
Realizing this might cause errors on older models that do not support caching. If it does, we could raise here, or just let the API validation handle it. I'm torn on whether the capabilities check complexity is worth it as these models are probably so rarely used.

end
end
end
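
`with_cache_control` is the single place where the `cache_control` attribute gets attached, which keeps the per-type formatters above uniform. A quick sketch of its behaviour:

```ruby
media = RubyLLM::Providers::Anthropic::Media

media.with_cache_control({ type: 'text', text: 'hello' }, cache: false)
# => { type: 'text', text: 'hello' }

media.with_cache_control({ type: 'text', text: 'hello' }, cache: true)
# => { type: 'text', text: 'hello', cache_control: { type: 'ephemeral' } }
```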