Skip to content

Commit b16dda7

Browse files
committed
core: Fix some missing generic types
1 parent bc1b5ff commit b16dda7

30 files changed

+227
-226
lines changed

libs/core/langchain_core/document_loaders/langsmith.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def lazy_load(self) -> Iterator[Document]:
125125
yield Document(content_str, metadata=metadata)
126126

127127

128-
def _stringify(x: Union[str, dict]) -> str:
128+
def _stringify(x: Union[str, dict[str, Any]]) -> str:
129129
if isinstance(x, str):
130130
return x
131131
try:

libs/core/langchain_core/messages/utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,17 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
202202

203203

204204
MessageLikeRepresentation = Union[
205-
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
205+
BaseMessage,
206+
list[str],
207+
tuple[str, Union[str, list[Union[str, dict[str, Any]]]]],
208+
str,
209+
dict[str, Any],
206210
]
207211

208212

209213
def _create_message_from_message_type(
210214
message_type: str,
211-
content: str,
215+
content: Union[str, list[Union[str, dict[str, Any]]]],
212216
name: Optional[str] = None,
213217
tool_call_id: Optional[str] = None,
214218
tool_calls: Optional[list[dict[str, Any]]] = None,
@@ -218,13 +222,13 @@ def _create_message_from_message_type(
218222
"""Create a message from a message type and content string.
219223
220224
Args:
221-
message_type: (str) the type of the message (e.g., "human", "ai", etc.).
222-
content: (str) the content string.
223-
name: (str) the name of the message. Default is None.
224-
tool_call_id: (str) the tool call id. Default is None.
225-
tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
226-
id: (str) the id of the message. Default is None.
227-
additional_kwargs: (dict[str, Any]) additional keyword arguments.
225+
message_type: the type of the message (e.g., "human", "ai", etc.).
226+
content: the content string.
227+
name: the name of the message. Default is None.
228+
tool_call_id: the tool call id. Default is None.
229+
tool_calls: the tool calls. Default is None.
230+
id: the id of the message. Default is None.
231+
**additional_kwargs: additional keyword arguments.
228232
229233
Returns:
230234
a message of the appropriate type.

libs/core/langchain_core/output_parsers/openai_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class Dog(BaseModel):
224224

225225
@model_validator(mode="before")
226226
@classmethod
227-
def validate_schema(cls, values: dict) -> Any:
227+
def validate_schema(cls, values: dict[str, Any]) -> Any:
228228
"""Validate the pydantic schema.
229229
230230
Args:

libs/core/langchain_core/prompts/chat.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6+
from collections.abc import Sequence
67
from pathlib import Path
78
from typing import (
8-
TYPE_CHECKING,
99
Annotated,
1010
Any,
1111
Optional,
@@ -51,9 +51,6 @@
5151
from langchain_core.utils import get_colored_text
5252
from langchain_core.utils.interactive_env import is_interactive_env
5353

54-
if TYPE_CHECKING:
55-
from collections.abc import Sequence
56-
5754

5855
class MessagesPlaceholder(BaseMessagePromptTemplate):
5956
"""Prompt template that assumes variable is already list of messages.
@@ -771,7 +768,7 @@ def pretty_print(self) -> None:
771768
MessageLike,
772769
tuple[
773770
Union[str, type],
774-
Union[str, list[dict], list[object]],
771+
Union[str, Sequence[dict], Sequence[object]],
775772
],
776773
str,
777774
dict[str, Any],

libs/core/langchain_core/tracers/evaluation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(
100100
)
101101
else:
102102
self.executor = None
103-
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
103+
self.futures: weakref.WeakSet[Future[None]] = weakref.WeakSet()
104104
self.skip_unfinished = skip_unfinished
105105
self.project_name = project_name
106106
self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {}

libs/core/tests/unit_tests/example_selectors/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self) -> None:
1010
def add_example(self, example: dict[str, str]) -> None:
1111
self.example = example
1212

13-
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
13+
def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
1414
return [input_variables]
1515

1616

libs/core/tests/unit_tests/fake/callbacks.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,9 @@ def on_retriever_error(
276276
self.on_retriever_error_common()
277277

278278
# Overriding since BaseModel has __deepcopy__ method as well
279-
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore[override]
279+
def __deepcopy__(
280+
self, memo: Union[dict[int, Any], None] = None
281+
) -> "FakeCallbackHandler":
280282
return self
281283

282284

@@ -426,5 +428,7 @@ async def on_text(
426428
self.on_text_common()
427429

428430
# Overriding since BaseModel has __deepcopy__ method as well
429-
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore[override]
431+
def __deepcopy__(
432+
self, memo: Union[dict[int, Any], None] = None
433+
) -> "FakeAsyncCallbackHandler":
430434
return self

libs/core/tests/unit_tests/language_models/chat_models/test_base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,22 @@
4141

4242

4343
@pytest.fixture
44-
def messages() -> list:
44+
def messages() -> list[BaseMessage]:
4545
return [
4646
SystemMessage(content="You are a test user."),
4747
HumanMessage(content="Hello, I am a test user."),
4848
]
4949

5050

5151
@pytest.fixture
52-
def messages_2() -> list:
52+
def messages_2() -> list[BaseMessage]:
5353
return [
5454
SystemMessage(content="You are a test user."),
5555
HumanMessage(content="Hello, I not a test user."),
5656
]
5757

5858

59-
def test_batch_size(messages: list, messages_2: list) -> None:
59+
def test_batch_size(messages: list[BaseMessage], messages_2: list[BaseMessage]) -> None:
6060
# The base endpoint doesn't support native batching,
6161
# so we expect batch_size to always be 1
6262
llm = FakeListChatModel(responses=[str(i) for i in range(100)])
@@ -80,7 +80,9 @@ def test_batch_size(messages: list, messages_2: list) -> None:
8080
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
8181

8282

83-
async def test_async_batch_size(messages: list, messages_2: list) -> None:
83+
async def test_async_batch_size(
84+
messages: list[BaseMessage], messages_2: list[BaseMessage]
85+
) -> None:
8486
llm = FakeListChatModel(responses=[str(i) for i in range(100)])
8587
# The base endpoint doesn't support native batching,
8688
# so we expect batch_size to always be 1
@@ -262,7 +264,7 @@ def _llm_type(self) -> str:
262264
class FakeTracer(BaseTracer):
263265
def __init__(self) -> None:
264266
super().__init__()
265-
self.traced_run_ids: list = []
267+
self.traced_run_ids: list[uuid.UUID] = []
266268

267269
def _persist_run(self, run: Run) -> None:
268270
"""Persist a run."""
@@ -411,7 +413,7 @@ async def test_disable_streaming_no_streaming_model_async(
411413
class FakeChatModelStartTracer(FakeTracer):
412414
def __init__(self) -> None:
413415
super().__init__()
414-
self.messages: list = []
416+
self.messages: list[list[list[BaseMessage]]] = []
415417

416418
def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:
417419
_, messages = args

libs/core/tests/unit_tests/messages/test_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ToolMessage,
1818
)
1919
from langchain_core.messages.utils import (
20+
MessageLikeRepresentation,
2021
convert_to_messages,
2122
convert_to_openai_messages,
2223
count_tokens_approximately,
@@ -152,7 +153,7 @@ def test_merge_messages_tool_messages() -> None:
152153
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
153154
],
154155
)
155-
def test_filter_message(filters: dict) -> None:
156+
def test_filter_message(filters: dict[str, Any]) -> None:
156157
messages = [
157158
SystemMessage("foo", name="blah", id="1"),
158159
HumanMessage("bar", name="blur", id="2"),
@@ -672,7 +673,7 @@ def get_num_tokens_from_messages(
672673

673674

674675
def test_convert_to_messages() -> None:
675-
message_like: list = [
676+
message_like: list[MessageLikeRepresentation] = [
676677
# BaseMessage
677678
SystemMessage("1"),
678679
SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}),
@@ -1178,7 +1179,7 @@ def test_convert_to_openai_messages_mixed_content_types() -> None:
11781179

11791180

11801181
def test_convert_to_openai_messages_developer() -> None:
1181-
messages: list = [
1182+
messages: list[MessageLikeRepresentation] = [
11821183
SystemMessage("a", additional_kwargs={"__openai_role__": "developer"}),
11831184
{"role": "developer", "content": "a"},
11841185
]

libs/core/tests/unit_tests/output_parsers/test_openai_tools.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
IS_PYDANTIC_V2,
2222
)
2323

24-
STREAMED_MESSAGES: list = [
24+
STREAMED_MESSAGES: list[AIMessageChunk] = [
2525
AIMessageChunk(content=""),
2626
AIMessageChunk(
2727
content="",
@@ -335,7 +335,7 @@
335335
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
336336

337337

338-
EXPECTED_STREAMED_JSON = [
338+
EXPECTED_STREAMED_JSON: list[dict[str, Any]] = [
339339
{},
340340
{"names": ["suz"]},
341341
{"names": ["suzy"]},
@@ -396,7 +396,7 @@ def test_partial_json_output_parser() -> None:
396396
chain = input_iter | JsonOutputToolsParser()
397397

398398
actual = list(chain.stream(None))
399-
expected: list = [[]] + [
399+
expected: list[list[dict[str, Any]]] = [[]] + [
400400
[{"type": "NameCollector", "args": chunk}]
401401
for chunk in EXPECTED_STREAMED_JSON
402402
]
@@ -409,7 +409,7 @@ async def test_partial_json_output_parser_async() -> None:
409409
chain = input_iter | JsonOutputToolsParser()
410410

411411
actual = [p async for p in chain.astream(None)]
412-
expected: list = [[]] + [
412+
expected: list[list[dict[str, Any]]] = [[]] + [
413413
[{"type": "NameCollector", "args": chunk}]
414414
for chunk in EXPECTED_STREAMED_JSON
415415
]
@@ -422,7 +422,7 @@ def test_partial_json_output_parser_return_id() -> None:
422422
chain = input_iter | JsonOutputToolsParser(return_id=True)
423423

424424
actual = list(chain.stream(None))
425-
expected: list = [[]] + [
425+
expected: list[list[dict[str, Any]]] = [[]] + [
426426
[
427427
{
428428
"type": "NameCollector",
@@ -441,7 +441,9 @@ def test_partial_json_output_key_parser() -> None:
441441
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
442442

443443
actual = list(chain.stream(None))
444-
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
444+
expected: list[list[dict[str, Any]]] = [[]] + [
445+
[chunk] for chunk in EXPECTED_STREAMED_JSON
446+
]
445447
assert actual == expected
446448

447449

@@ -452,7 +454,9 @@ async def test_partial_json_output_parser_key_async() -> None:
452454
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
453455

454456
actual = [p async for p in chain.astream(None)]
455-
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
457+
expected: list[list[dict[str, Any]]] = [
458+
[chunk] for chunk in EXPECTED_STREAMED_JSON
459+
]
456460
assert actual == expected
457461

458462

libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,7 @@ class TestModel(BaseModel):
141141

142142
def test_pydantic_output_parser() -> None:
143143
"""Test PydanticOutputParser."""
144-
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
145-
pydantic_object=TestModel
146-
)
144+
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
147145

148146
result = pydantic_parser.parse(DEF_RESULT)
149147
assert result == DEF_EXPECTED_RESULT
@@ -152,9 +150,7 @@ def test_pydantic_output_parser() -> None:
152150

153151
def test_pydantic_output_parser_fail() -> None:
154152
"""Test PydanticOutputParser where completion result fails schema validation."""
155-
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
156-
pydantic_object=TestModel
157-
)
153+
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
158154

159155
with pytest.raises(
160156
OutputParserException, match="Failed to parse TestModel from completion"

libs/core/tests/unit_tests/outputs/test_chat_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Any, Union
22

33
import pytest
44

@@ -19,14 +19,14 @@
1919
],
2020
],
2121
)
22-
def test_msg_with_text(content: Union[str, list]) -> None:
22+
def test_msg_with_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
2323
expected = "foo"
2424
actual = ChatGeneration(message=AIMessage(content=content)).text
2525
assert actual == expected
2626

2727

2828
@pytest.mark.parametrize("content", [[], [{"tool_use": {}, "type": "tool_use"}]])
29-
def test_msg_no_text(content: Union[str, list]) -> None:
29+
def test_msg_no_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
3030
expected = ""
3131
actual = ChatGeneration(message=AIMessage(content=content)).text
3232
assert actual == expected

0 commit comments

Comments (0)