diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 24f0c8c85d61..416e8ada9f61 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,6 +18,12 @@ from mypy.checker_shared import ExpressionCheckerSharedApi from mypy.checkmember import analyze_member_access, has_operator from mypy.checkstrformat import StringFormatterChecker +from mypy.constraints import ( + SUBTYPE_OF, + Constraint, + infer_constraints, + infer_constraints_for_callable, +) from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars from mypy.errors import ErrorWatcher, report_internal_error from mypy.expandtype import ( @@ -26,7 +32,7 @@ freshen_all_functions_type_vars, freshen_function_type_vars, ) -from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments +from mypy.infer import ArgumentInferContext, infer_function_type_arguments from mypy.literals import literal from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_types, narrow_declared_type @@ -110,10 +116,12 @@ Plugin, ) from mypy.semanal_enum import ENUM_BASES +from mypy.solve import solve_constraints from mypy.state import state from mypy.subtypes import ( find_member, is_equivalent, + is_proper_subtype, is_same_type, is_subtype, non_method_protocol_members, @@ -191,12 +199,7 @@ is_named_instance, split_with_prefix_and_suffix, ) -from mypy.types_utils import ( - is_generic_instance, - is_overlapping_none, - is_self_type_like, - remove_optional, -) +from mypy.types_utils import is_generic_instance, is_self_type_like, remove_optional from mypy.typestate import type_state from mypy.typevars import fill_typevars from mypy.util import split_module_names @@ -1778,18 +1781,6 @@ def check_callable_call( isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables ) callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context(callee, context) - if need_refresh: - # Argument kinds etc. may have changed due to - # ParamSpec or TypeVarTuple variables being replaced with an arbitrary - # number of arguments; recalculate actual-to-formal map - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) callee = self.infer_function_type_arguments( callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context ) @@ -2004,9 +1995,9 @@ def infer_arg_types_in_context( assert all(tp is not None for tp in res) return cast(list[Type], res) - def infer_function_type_arguments_using_context( - self, callable: CallableType, error_context: Context - ) -> CallableType: + def infer_constraints_from_context( + self, callee: CallableType, error_context: Context + ) -> list[Constraint]: """Unify callable return type to type context to infer type vars. For example, if the return type is set[t] where 't' is a type variable @@ -2015,23 +2006,23 @@ def infer_function_type_arguments_using_context( """ ctx = self.type_context[-1] if not ctx: - return callable + return [] # The return type may have references to type metavariables that # we are inferring right now. We must consider them as indeterminate # and they are not potential results; thus we replace them with the # special ErasedType type. On the other hand, class type variables are # valid results. - erased_ctx = replace_meta_vars(ctx, ErasedType()) - ret_type = callable.ret_type - if is_overlapping_none(ret_type) and is_overlapping_none(ctx): - # If both the context and the return type are optional, unwrap the optional, - # since in 99% cases this is what a user expects. In other words, we replace - # Optional[T] <: Optional[int] - # with - # T <: int - # while the former would infer T <: Optional[int]. - ret_type = remove_optional(ret_type) - erased_ctx = remove_optional(erased_ctx) + erased_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType())) + proper_ret = get_proper_type(callee.ret_type) + if isinstance(proper_ret, UnionType) and isinstance(erased_ctx, UnionType): + # If both the context and the return type are unions, we simplify shared items + # e.g. T | None <: int | None => T <: int + # since the former would infer T <: int | None. + # whereas the latter would infer the more precise T <: int. + new_ret = [val for val in proper_ret.items if val not in erased_ctx.items] + new_ctx = [val for val in erased_ctx.items if val not in proper_ret.items] + proper_ret = make_simplified_union(new_ret) + erased_ctx = make_simplified_union(new_ctx) # # TODO: Instead of this hack and the one below, we need to use outer and # inner contexts at the same time. This is however not easy because of two @@ -2042,7 +2033,6 @@ def infer_function_type_arguments_using_context( # variables in an expression are inferred at the same time. # (And this is hard, also we need to be careful with lambdas that require # two passes.) - proper_ret = get_proper_type(ret_type) if ( isinstance(proper_ret, TypeVarType) or isinstance(proper_ret, UnionType) @@ -2072,22 +2062,9 @@ def infer_function_type_arguments_using_context( # TODO: we may want to add similar exception if all arguments are lambdas, since # in this case external context is almost everything we have. if not is_generic_instance(ctx) and not is_literal_type_like(ctx): - return callable.copy_modified() - args = infer_type_arguments( - callable.variables, ret_type, erased_ctx, skip_unsatisfied=True - ) - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in args: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - return self.apply_generic_arguments( - callable, new_args, error_context, skip_unsatisfied=True - ) + return [] + constraints = infer_constraints(proper_ret, erased_ctx, SUBTYPE_OF) + return constraints def infer_function_type_arguments( self, @@ -2126,15 +2103,131 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - inferred_args, _ = infer_function_type_arguments( - callee_type, - pass1_args, - arg_kinds, - arg_names, - formal_to_actual, - context=self.argument_infer_context(), - strict=self.chk.in_checked_function(), - ) + if True: # NEW CODE + # compute the inner constraints + _inner_constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + # HACK: convert "Literal?" constraints to their non-literal versions. + inner_constraints: list[Constraint] = [] + for constraint in _inner_constraints: + target = get_proper_type(constraint.target) + inner_constraints.append( + Constraint( + constraint.original_type_var, + constraint.op, + ( + target.copy_modified(last_known_value=None) + if isinstance(target, Instance) + else target + ), + ) + ) + + # compute the outer solution + outer_constraints = self.infer_constraints_from_context(callee_type, context) + outer_solution = solve_constraints( + callee_type.variables, + outer_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + outer_args = [ + None if has_uninhabited_component(arg) or has_erased_component(arg) else arg + for arg in outer_solution[0] + ] + outer_solution = (outer_args, outer_solution[1]) + outer_callee = self.apply_generic_arguments( + callee_type, outer_solution[0], context, skip_unsatisfied=True + ) + outer_ret_type = get_proper_type(outer_callee.ret_type) + + # compute the joint solution using both inner and outer constraints. + # NOTE: The order of constraints is important here! + # solve(outer + inner) and solve(inner + outer) may yield different results. + # we need to use outer first. + joint_constraints = outer_constraints + inner_constraints + joint_solution = solve_constraints( + callee_type.variables, + joint_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + joint_args = [ + None if has_uninhabited_component(arg) or has_erased_component(arg) else arg + for arg in joint_solution[0] + ] + joint_solution = (joint_args, joint_solution[1]) + joint_callee = self.apply_generic_arguments( + callee_type, joint_solution[0], context, skip_unsatisfied=True + ) + joint_ret_type = get_proper_type(joint_callee.ret_type) + + if ( # determine which solution to take + # no inner constraints + not inner_constraints + # no outer constraints + # or not (outer_upper + outer_lower) + # no outer_constraints + or not joint_solution[0] + # joint constraints failed to produce a complete solution + or None in joint_solution[0] + # If the outer solution is more concrete than the joint solution, prefer the outer solution. + or ( + is_subtype(outer_ret_type, joint_ret_type) + and not is_proper_subtype(joint_ret_type, outer_ret_type) + ) + ): + use_joint = False + else: + use_joint = True + + if use_joint: + inferred_args = joint_solution[0] + else: + # If we cannot use the joint solution, fallback to outer_solution + inferred_args = outer_solution[0] + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + callee_type = self.apply_generic_arguments( + callee_type, inferred_args, context, skip_unsatisfied=True + ) + if need_refresh: + # Argument kinds etc. may have changed due to + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # number of arguments; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee_type.arg_kinds, + callee_type.arg_names, + lambda i: self.accept(args[i]), + ) + + # ??? QUESTION: Do we need to recompute arg_types and pass1_args here??? + # recompute and apply inner solution. + inner_constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + inner_solution = solve_constraints( + callee_type.variables, + inner_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + inferred_args = inner_solution[0] + else: # END NEW CODE + pass if 2 in arg_pass_nums: # Second pass of type inference. diff --git a/mypy/constraints.py b/mypy/constraints.py index 6416791fa74a..89b4da91b365 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -77,11 +77,13 @@ class Constraint: """ type_var: TypeVarId + original_type_var: TypeVarLikeType op = 0 # SUBTYPE_OF or SUPERTYPE_OF target: Type def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None: self.type_var = type_var.id + self.original_type_var = type_var self.op = op # TODO: should we add "assert not isinstance(target, UnpackType)"? # UnpackType is a synthetic type, and is never valid as a constraint target. @@ -1356,7 +1358,10 @@ def visit_typeddict_type(self, template: TypedDictType) -> list[Constraint]: # NOTE: Non-matching keys are ignored. Compatibility is checked # elsewhere so this shouldn't be unsafe. for item_name, template_item_type, actual_item_type in template.zip(actual): - res.extend(infer_constraints(template_item_type, actual_item_type, self.direction)) + # Value type is invariant, so irrespective of the direction, we constraint + # both above and below. + res.extend(infer_constraints(template_item_type, actual_item_type, SUBTYPE_OF)) + res.extend(infer_constraints(template_item_type, actual_item_type, SUPERTYPE_OF)) return res elif isinstance(actual, AnyType): return self.infer_against_any(template.items.values(), actual) diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 33271a3cc04c..b1e16c17ada9 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -308,6 +308,28 @@ main:5: error: Unsupported operand types for ^ ("A" and "A") main:6: error: Unsupported operand types for << ("A" and "B") main:7: error: Unsupported operand types for >> ("A" and "A") +[case testBinaryOperatorContext] +from typing import TypeVar, Generic, Iterable, Iterator, Union + +T = TypeVar("T") +S = TypeVar("S") +IntOrStr = TypeVar("IntOrStr", bound=Union[int, str]) + +class Vec(Generic[T]): + def __init__(self, iterable: Iterable[T], /) -> None: ... + def __iter__(self) -> Iterator[T]: yield from [] + def __add__(self, value: "Vec[S]", /) -> "Vec[Union[S, T]]": return Vec([]) + +def fmt(arg: Iterable[Union[int, str]]) -> None: ... +def first(arg: Iterable[IntOrStr]) -> IntOrStr: ... + +def test_fmt(l1: Vec[int], l2: Vec[int], /) -> None: + fmt(l1 + l2) + +def test_first(l1: Vec[int], l2: Vec[int], /) -> None: + first(l1 + l2) +[builtins fixtures/list.pyi] + [case testBooleanAndOr] a: A b: bool diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 78680684f69b..cbda383afdab 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2998,7 +2998,7 @@ def lift(f: F[T]) -> F[Optional[T]]: ... def g(x: T) -> T: return x -reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union[T`1, None]" +reveal_type(lift(g)) # N: Revealed type is "__main__.F[Union[T`-1, None]]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericSplitOrder] @@ -3198,7 +3198,7 @@ def dec(f: Callable[P, Callable[[T], S]]) -> Callable[Concatenate[T, P], S]: ... def id() -> Callable[[U], U]: ... def either(x: U) -> Callable[[U], U]: ... def pair(x: U) -> Callable[[V], Tuple[V, U]]: ... -reveal_type(dec(id)) # N: Revealed type is "def [T] (T`3) -> T`3" +reveal_type(dec(id)) # N: Revealed type is "def (U`-1) -> U`-1" reveal_type(dec(either)) # N: Revealed type is "def [T] (T`6, x: T`6) -> T`6" reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`9, x: U`-1) -> tuple[T`9, U`-1]" # This is counter-intuitive but looks correct, dec matches itself only if P can be empty @@ -3657,3 +3657,43 @@ t2.foo = [B()] t2.foo = [C()] t2.foo = [1] # E: Value of type variable "T" of "foo" of "Test" cannot be "int" [builtins fixtures/property.pyi] + +[case testContextFreeConcatInvariantType] +from typing import Iterable, Iterator, TypeVar, Generic, Union + +T = TypeVar("T") +S = TypeVar("S") + +class Vec(Generic[T]): + def getitem(self, i: int) -> T: ... # ensure invariance of T + def setitem(self, i: int, v: T) -> None: ... # ensure invariance of T + def __iter__(self) -> Iterator[T]: ... + def __add__(self, other: "Vec[S]") -> "Vec[Union[T, S]]": ... + +mix: Vec[Union[int, str]] +strings: Vec[str] +mix = mix + strings +mix = strings + mix +reveal_type(mix + strings) # N: Revealed type is "__main__.Vec[Union[builtins.int, builtins.str]]" +reveal_type(strings + mix) # N: Revealed type is "__main__.Vec[Union[builtins.str, builtins.int]]" +[builtins fixtures/list.pyi] + + +[case testInContextConcatInvariantType] +# https://github.com/python/mypy/issues/3933#issuecomment-2272804302 +from typing import Iterable, Iterator, TypeVar, Generic, Union + +T = TypeVar("T") +S = TypeVar("S") + +class Vec(Generic[T]): + def getitem(self, i: int) -> T: ... # ensure invariance of T + def setitem(self, i: int, v: T) -> None: ... # ensure invariance of T + def __iter__(self) -> Iterator[T]: ... + def __add__(self, other: "Vec[S]") -> "Vec[Union[T, S]]": ... + +def identity_on_iterable(arg: Iterable[T]) -> Iterable[T]: return arg +x: Vec[str] +y: Vec[None] +reveal_type( identity_on_iterable(y + x) ) # N: Revealed type is "typing.Iterable[Union[None, builtins.str]]" +reveal_type( identity_on_iterable(x + y) ) # N: Revealed type is "typing.Iterable[Union[builtins.str, None]]" diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index ff726530cf9f..c44a3c60f0e4 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -1372,6 +1372,22 @@ x: Tuple[str, ...] = f(tuple) [builtins fixtures/tuple.pyi] [out] +[case testTypedDictWideContext] +from typing_extensions import TypedDict +from typing import TypeVar, Generic + +T = TypeVar('T') + +class A: ... +class B(A): ... + +class OverridesItem(TypedDict, Generic[T]): + tp: type[T] + +d1: dict[str, dict[str, type[A]]] = {"foo": {"bar": B}} +d2: dict[str, OverridesItem[A]] = {"foo": OverridesItem(tp=B)} +[builtins fixtures/dict.pyi] + [case testUseCovariantGenericOuterContextUserDefined] from typing import TypeVar, Callable, Generic diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 3c9290b8dbbb..b30a56bb5fb6 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1206,6 +1206,13 @@ reveal_type(a) # N: Revealed type is "builtins.dict[builtins.str, builtins.int] [builtins fixtures/dict.pyi] [out] +[case testLiteralMappingContext] +from typing import Mapping, Literal + +x: Mapping[str, Literal["sum", "mean", "max", "min"]] = {"x": "sum"} + +[builtins fixtures/dict.pyi] + [case testLiteralInferredInOverloadContextBasic] from typing import Literal, overload diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 86e9f02b5263..b94bd120932e 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -285,21 +285,33 @@ if isinstance(b[0], Sequence): [case testRecursiveAliasWithRecursiveInstance] from typing import Sequence, Union, TypeVar -class A: ... T = TypeVar("T") Nested = Sequence[Union[T, Nested[T]]] +def join(a: T, b: T) -> T: ... + +class A: ... class B(Sequence[B]): ... a: Nested[A] aa: Nested[A] b: B + a = b # OK +reveal_type(a) # N: Revealed type is "__main__.B" + a = [[b]] # OK +reveal_type(a) # N: Revealed type is "builtins.list[builtins.list[__main__.B]]" + b = aa # E: Incompatible types in assignment (expression has type "Nested[A]", variable has type "B") +reveal_type(b) # N: Revealed type is "__main__.B" + +reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" +reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" + +def test(a: Nested[A], b: B) -> None: + reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" + reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -def join(a: T, b: T) -> T: ... -reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasWithRecursiveInstanceInference] diff --git a/test-data/unit/check-varargs.test b/test-data/unit/check-varargs.test index 680021a166f2..def03f5f3ec1 100644 --- a/test-data/unit/check-varargs.test +++ b/test-data/unit/check-varargs.test @@ -629,9 +629,9 @@ from typing import TypeVar T = TypeVar('T') def f(*args: T) -> T: ... -reveal_type(f(*(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" +reveal_type(f(*(1, None))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[builtins.int, None]" [builtins fixtures/tuple.pyi]