diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 581008d3f1b9d..c8f4520272393 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -1291,6 +1291,26 @@ def _get_request_payload(
     ) -> dict:
         messages = self._convert_input(input_).to_messages()
         system, formatted_messages = _format_messages(messages)
+
+        # If cache_control is provided in kwargs, add it to last message
+        # and content block.
+        if "cache_control" in kwargs and formatted_messages:
+            if isinstance(formatted_messages[-1]["content"], list):
+                formatted_messages[-1]["content"][-1]["cache_control"] = kwargs.pop(
+                    "cache_control"
+                )
+            elif isinstance(formatted_messages[-1]["content"], str):
+                formatted_messages[-1]["content"] = [
+                    {
+                        "type": "text",
+                        "text": formatted_messages[-1]["content"],
+                        "cache_control": kwargs.pop("cache_control"),
+                    }
+                ]
+            else:
+                pass
+        _ = kwargs.pop("cache_control", None)
+
         payload = {
             "model": self.model,
             "max_tokens": self.max_tokens,
diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
index 2d418b6bddda6..963935272bd91 100644
--- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -1056,3 +1056,50 @@ def mock_create(*args: Any, **kwargs: Any) -> Message:
     # Test headers are correctly propagated to request
     payload = llm._get_request_payload([input_message])
     assert payload["mcp_servers"][0]["authorization_token"] == "PLACEHOLDER"
+
+
+def test_cache_control_kwarg() -> None:
+    llm = ChatAnthropic(model="claude-3-5-haiku-latest")
+
+    messages = [HumanMessage("foo"), AIMessage("bar"), HumanMessage("baz")]
+    payload = llm._get_request_payload(messages)
+    assert payload["messages"] == [
+        {"role": "user", "content": "foo"},
+        {"role": "assistant", "content": "bar"},
+        {"role": "user", "content": "baz"},
+    ]
+
+    payload = llm._get_request_payload(messages, cache_control={"type": "ephemeral"})
+    assert payload["messages"] == [
+        {"role": "user", "content": "foo"},
+        {"role": "assistant", "content": "bar"},
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "baz", "cache_control": {"type": "ephemeral"}}
+            ],
+        },
+    ]
+
+    messages = [
+        HumanMessage("foo"),
+        AIMessage("bar"),
+        HumanMessage(
+            content=[
+                {"type": "text", "text": "baz"},
+                {"type": "text", "text": "qux"},
+            ]
+        ),
+    ]
+    payload = llm._get_request_payload(messages, cache_control={"type": "ephemeral"})
+    assert payload["messages"] == [
+        {"role": "user", "content": "foo"},
+        {"role": "assistant", "content": "bar"},
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "baz"},
+                {"type": "text", "text": "qux", "cache_control": {"type": "ephemeral"}},
+            ],
+        },
+    ]
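
What the change enables, in practice: a `cache_control` kwarg can now ride along with a normal chat call, and `_get_request_payload` attaches it to the last content block of the last message. This matches how Anthropic prompt caching marks a cache breakpoint, where everything up to and including the annotated block becomes cacheable. Below is a minimal sketch mirroring the new unit test; it assumes `ANTHROPIC_API_KEY` is set (a placeholder value works, since `_get_request_payload` builds the request dict locally without sending anything):

```python
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage

llm = ChatAnthropic(model="claude-3-5-haiku-latest")

# With the kwarg, string content is promoted to a single text block and
# the cache_control marker lands on the final block of the final message.
payload = llm._get_request_payload(
    [HumanMessage("foo"), AIMessage("bar"), HumanMessage("baz")],
    cache_control={"type": "ephemeral"},
)
assert payload["messages"][-1]["content"] == [
    {"type": "text", "text": "baz", "cache_control": {"type": "ephemeral"}}
]
```

Annotating only the final block is a reasonable simplification: a single trailing breakpoint covers the whole conversation prefix, which is the common caching pattern. Callers who need finer breakpoint placement can still format content blocks by hand and skip the kwarg.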