Skip to content

[match-case] Fix narrowing of class pattern with union-argument. #19473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7996,20 +7996,47 @@ def conditional_types(
) -> tuple[Type | None, Type | None]:
"""Takes in the current type and a proposed type of an expression.

Returns a 2-tuple: The first element is the proposed type, if the expression
can be the proposed type. The second element is the type it would hold
if it was not the proposed type, if any. UninhabitedType means unreachable.
None means no new information can be inferred. If default is set it is returned
instead."""
Returns a 2-tuple:
The first element is the proposed type, if the expression can be the proposed type.
The second element is the type it would hold if it was not the proposed type, if any.
UninhabitedType means unreachable.
None means no new information can be inferred.
If default is set it is returned instead.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question as in #19471, but this will be conflicted anyway:)

"""
if proposed_type_ranges and len(proposed_type_ranges) == 1:
# expand e.g. bool -> Literal[True] | Literal[False]
target = proposed_type_ranges[0].item
target = get_proper_type(target)
if isinstance(target, LiteralType) and (
target.is_enum_literal() or isinstance(target.value, bool)
):
enum_name = target.fallback.type.fullname
current_type = try_expanding_sum_type_to_union(current_type, enum_name)

current_type = get_proper_type(current_type)
if isinstance(current_type, UnionType) and (default == current_type):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default == current_type was necessary, otherwise quite a few tests break.
Adding or (default is None) and passing default=None in the recursive call also seems to break stuff. Not entirely sure what's going on there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, this only works by coincidence and is incorrect in general case.

The problem is that you try to union and restrict multiple results here. This is only reasonable when none of the ranges are upper bounds, otherwise this is unsound:

class Base: pass

class A1(Base): ...
class A2(Base): ...
class A3(Base): ...

ref: type[Base] = A1
x: Base | int

if isinstance(x, ref):
    reveal_type(x)  # Should be Base
else:
    reveal_type(x)  # Should be Base | int (e.g. `x = A2()`)

Applying the logic from this PR, you get Never | Base == Base as yes_type (correct), but then you get (Base | int) \ Base == int as no_type, which is wrong. Since this has nothing to do with default, likely this piece just happens to work on our testcases?

I can suggest only performing this expansion when not any(tr.is_upper_bound for tr in proposed_type_ranges), but maybe there's a better general solution similar to the logic below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran your test example and it gives the expected results.

$ mypy tmp.py 
tmp.py:11: note: Revealed type is "tmp.Base"
tmp.py:13: note: Revealed type is "tmp.Base | builtins.int"

I believe it depends on from where this is called. conditional_types is called in a few places:

  • refine_identity_comparison_expression without a default
  • conditional_types_with_intersection with default, which itself is called in 7 places:
    1. TypeChecker.find_type_equals_check with current_type=self.lookup_type(expr) and default=None
    2. TypeChecker.find_isinstance_check_helper with current_type=self.lookup_type(expr) and default=None
    3. TypeChecker.infer_issubclass_maps with current_type=vartype and default=None
    4. PatternChecker.visit_as_pattern with current_type=current_typeanddefault=current_type`
    5. PatternChecker.visit_value_pattern with current_type=current_typeanddefault=get_proper_type(typ)`
    6. PatternChecker.visit_singleton_pattern with current_type=current_typeanddefault=current_type`
    7. PatternChecker.visit_sequence_pattern with current_type=inner_typeanddefault=inner_type`
    8. PatternChecker.visit_sequence_pattern with current_type=new_tuple_typeanddefault=new_tuple_type`
    9. PatternChecker.narrow_sequence_child with current_type=outer_typeanddefault=outer_type`
    10. PatternChecker.visist_class_pattern with current_type=current_typeanddefault=current_type`

So, we can see for the isinstance calls, default is None, so this branch is actually never taken currently, but may be taken for all the call-sites within PatternChecker except possibly PatternChecker.visit_value_pattern

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I try to replicate your example with match-case we get these results:

class Base: pass

class A1(Base): ...
class A2(Base): ...
class A3(Base): ...

ref: type[Base] = A1
x: Base | int

match x:
    case ref() as y: # E: expected type in class pattern; found "type[tmp.Base]"
        reveal_type(y)  # E: Statement is unreachable
    case other:
        reveal_type(other)  # Base | int

match x:
    case Base() as y:
        reveal_type(y)  # Base
    case other:
        reveal_type(other)  # int

match x:
    case A1() as y:
        reveal_type(y)  # A1
    case other:
        reveal_type(other)  # Base | int

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sterliakov I integrated your suggestion into a separate branch #19517, and was able to relax the condition to

    if (
        isinstance(current_type, UnionType)
        and not any(tr.is_upper_bound for tr in proposed_type_ranges)
        and (default in (current_type, None))
    ):

so that it also is applied with regular isinstance branches when default=None

# factorize over union types
# if we try to narrow A|B to C, we instead narrow A to C and B to C, and
# return the union of the results
result: list[tuple[Type | None, Type | None]] = [
conditional_types(
union_item,
proposed_type_ranges,
default=union_item,
consider_runtime_isinstance=consider_runtime_isinstance,
)
for union_item in get_proper_types(current_type.items)
]
# separate list of tuples into two lists
yes_types, no_types = zip(*result)
yes_type = make_simplified_union([t for t in yes_types if t is not None])
no_type = restrict_subtype_away(
current_type, yes_type, consider_runtime_isinstance=consider_runtime_isinstance
)

return yes_type, no_type

if proposed_type_ranges:
if len(proposed_type_ranges) == 1:
target = proposed_type_ranges[0].item
target = get_proper_type(target)
if isinstance(target, LiteralType) and (
target.is_enum_literal() or isinstance(target.value, bool)
):
enum_name = target.fallback.type.fullname
current_type = try_expanding_sum_type_to_union(current_type, enum_name)
proposed_items = [type_range.item for type_range in proposed_type_ranges]
proposed_type = make_simplified_union(proposed_items)
if isinstance(proposed_type, AnyType):
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,22 @@ def union(x: str | bool) -> None:
reveal_type(x) # N: Revealed type is "Union[builtins.str, Literal[False]]"
[builtins fixtures/tuple.pyi]

[case testMatchNarrowDownUnionUsingClassPattern]

class Foo: ...
class Bar(Foo): ...

def test_1(bar: Bar) -> None:
match bar:
case Foo() as foo:
reveal_type(foo) # N: Revealed type is "__main__.Bar"

def test_2(bar: Bar | str) -> None:
match bar:
case Foo() as foo:
reveal_type(foo) # N: Revealed type is "__main__.Bar"


[case testMatchAssertFalseToSilenceFalsePositives]
class C:
a: int | str
Expand Down