Skip to content
Open
20 changes: 20 additions & 0 deletions litellm/litellm_core_utils/prompt_templates/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3186,8 +3186,18 @@ async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
tool_content: List[BedrockContentBlock] = []
while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
message_block=cast(
OpenAIMessageContentListBlock, messages[msg_i]
),
block_type="content_block",
)
)

tool_content.append(tool_call_result)
if _cache_point_block is not None:
tool_content.append(_cache_point_block)
msg_i += 1
if tool_content:
# if last message was a 'user' message, then add a blank assistant message (bedrock requires alternating roles)
Expand Down Expand Up @@ -3516,8 +3526,18 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
tool_content: List[BedrockContentBlock] = []
while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
message_block=cast(
OpenAIMessageContentListBlock, messages[msg_i]
),
block_type="content_block",
)
)

tool_content.append(tool_call_result)
if _cache_point_block is not None:
tool_content.append(_cache_point_block)
msg_i += 1
if tool_content:
# if last message was a 'user' message, then add a blank assistant message (bedrock requires alternating roles)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,81 @@ def test_bedrock_get_document_format_mimetypes_success():
)
assert result == "docx", f"Expected 'docx', got '{result}'"

def test_bedrock_converse_messages_pt():
    """
    Verify that _bedrock_converse_messages_pt appends a cachePoint content
    block when a tool-result message carries a cache_control marker.
    """
    from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_converse_messages_pt

    # Conversation ending in a tool result tagged for prompt caching.
    conversation = [
        {"role": "user", "content": "query the data of user_id:10000"},
        {
            "role": "assistant",
            "content": '<thinking> I need to call the function to query the PV power generation data for the user with ID 10000. </thinking> ',
            "tool_calls": [
                {
                    "id": "tooluse_Tx3Zhu0RTLawyVv7yfsAYg",
                    "type": "function",
                    "function": {
                        "name": "query_pv_power",
                        "arguments": '{"user_id":"10000"}',
                    },
                }
            ],
        },
        {
            "role": "tool",
            "tool_call_id": "tooluse_Tx3Zhu0RTLawyVv7yfsAYg",
            "content": "1000kw",
            "cache_control": {"type": "ephemeral"},
        },
    ]

    transformed = _bedrock_converse_messages_pt(
        messages=conversation,
        model="bedrock/eu.amazon.nova-pro-v1:0",
        llm_provider="bedrock_converse",
    )

    # The final content block of the last message should be the cache point.
    last_block = transformed[-1]["content"][-1]
    cache_point = last_block.get("cachePoint")
    assert cache_point == {"type": "default"}, f"Expected {{'type': 'default'}}, got {cache_point}"

def test_bedrock_converse_messages_pt_async():
    """
    Verify that BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async
    appends a cachePoint content block when a tool-result message carries a
    cache_control marker.
    """
    import asyncio

    from litellm.litellm_core_utils.prompt_templates.factory import BedrockConverseMessagesProcessor

    # Conversation ending in a tool result tagged for prompt caching.
    conversation = [
        {"role": "user", "content": "query the data of user_id:10000"},
        {
            "role": "assistant",
            "content": '<thinking> I need to call the function to query the PV power generation data for the user with ID 10000. </thinking> ',
            "tool_calls": [
                {
                    "id": "tooluse_Tx3Zhu0RTLawyVv7yfsAYg",
                    "type": "function",
                    "function": {
                        "name": "query_pv_power",
                        "arguments": '{"user_id":"10000"}',
                    },
                }
            ],
        },
        {
            "role": "tool",
            "tool_call_id": "tooluse_Tx3Zhu0RTLawyVv7yfsAYg",
            "content": "1000kw",
            "cache_control": {"type": "ephemeral"},
        },
    ]

    msg_processor = BedrockConverseMessagesProcessor()

    # The async transform returns a coroutine; drive it to completion.
    coro = msg_processor._bedrock_converse_messages_pt_async(
        messages=conversation,
        model="bedrock/eu.amazon.nova-pro-v1:0",
        llm_provider="bedrock_converse",
    )
    transformed = asyncio.run(coro)

    # The final content block of the last message should be the cache point.
    cache_point = transformed[-1]["content"][-1].get("cachePoint")
    assert cache_point == {"type": "default"}, f"Expected {{'type': 'default'}}, got {cache_point}"



# def test_ollama_pt_consecutive_system_messages():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_litellm/test_cost_calculation_log_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def test_batch_cost_calculation_uses_debug_level(caplog):
if batch_cost_records: # May not always log depending on the code path
for record in batch_cost_records:
assert record.levelno == logging.DEBUG, \
f"Batch cost calculation log should be DEBUG level, but was {record.levelname}"
f"Batch cost calculation log should be DEBUG level, but was {record.levelname}"
Loading