Skip to content

Commit bc2481a

Browse files
refactor visit_conditional_expr to fix ternary behavior (#19563)
- Fixes #18817 - added unit test `testTernaryOperatorWithDefault` - Fixes #19561 - added unit tests `testLambdaTernary{, IndirectAttribute, DoubleIndirectAttribute}` - Fixes #19534 - added unit tests `test{List, Set, Dict}ComprehensionTernary` - Fixes #19998 Option ① from #19561 (comment) which does not allow `MemberExpr` elements to the nested binder, hence ternaries like `f(x.attr) if x.attr else g(x.attr)` will not consider the narrowed type of `x.attr` Option ②: #19562 The current implementation of `visit_conditional_expr` seemed to some rather complicated things, I found that if there is no context, we can simply use the union of the types produced when considering the branches context-free as an artificial context that leads to the desired behavior in the unification test cases.
1 parent 64d953d commit bc2481a

File tree

5 files changed

+176
-55
lines changed

5 files changed

+176
-55
lines changed

mypy/checkexpr.py

Lines changed: 15 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@
154154
get_type_vars,
155155
is_literal_type_like,
156156
make_simplified_union,
157-
simple_literal_type,
158157
true_only,
159158
try_expanding_sum_type_to_union,
160159
try_getting_str_literals,
@@ -5899,7 +5898,7 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None:
58995898

59005899
def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = False) -> Type:
59015900
self.accept(e.cond)
5902-
ctx = self.type_context[-1]
5901+
ctx: Type | None = self.type_context[-1]
59035902

59045903
# Gain type information from isinstance if it is there
59055904
# but only for the current expression
@@ -5910,63 +5909,26 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
59105909
elif else_map is None:
59115910
self.msg.redundant_condition_in_if(True, e.cond)
59125911

5912+
if ctx is None:
5913+
# When no context is provided, compute each branch individually, and
5914+
# use the union of the results as artificial context. Important for:
5915+
# - testUnificationDict
5916+
# - testConditionalExpressionWithEmpty
5917+
ctx_if_type = self.analyze_cond_branch(
5918+
if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return
5919+
)
5920+
ctx_else_type = self.analyze_cond_branch(
5921+
else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return
5922+
)
5923+
ctx = make_simplified_union([ctx_if_type, ctx_else_type])
5924+
59135925
if_type = self.analyze_cond_branch(
59145926
if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return
59155927
)
5916-
5917-
# we want to keep the narrowest value of if_type for union'ing the branches
5918-
# however, it would be silly to pass a literal as a type context. Pass the
5919-
# underlying fallback type instead.
5920-
if_type_fallback = simple_literal_type(get_proper_type(if_type)) or if_type
5921-
5922-
# Analyze the right branch using full type context and store the type
5923-
full_context_else_type = self.analyze_cond_branch(
5928+
else_type = self.analyze_cond_branch(
59245929
else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return
59255930
)
59265931

5927-
if not mypy.checker.is_valid_inferred_type(if_type, self.chk.options):
5928-
# Analyze the right branch disregarding the left branch.
5929-
else_type = full_context_else_type
5930-
# we want to keep the narrowest value of else_type for union'ing the branches
5931-
# however, it would be silly to pass a literal as a type context. Pass the
5932-
# underlying fallback type instead.
5933-
else_type_fallback = simple_literal_type(get_proper_type(else_type)) or else_type
5934-
5935-
# If it would make a difference, re-analyze the left
5936-
# branch using the right branch's type as context.
5937-
if ctx is None or not is_equivalent(else_type_fallback, ctx):
5938-
# TODO: If it's possible that the previous analysis of
5939-
# the left branch produced errors that are avoided
5940-
# using this context, suppress those errors.
5941-
if_type = self.analyze_cond_branch(
5942-
if_map,
5943-
e.if_expr,
5944-
context=else_type_fallback,
5945-
allow_none_return=allow_none_return,
5946-
)
5947-
5948-
elif if_type_fallback == ctx:
5949-
# There is no point re-running the analysis if if_type is equal to ctx.
5950-
# That would be an exact duplicate of the work we just did.
5951-
# This optimization is particularly important to avoid exponential blowup with nested
5952-
# if/else expressions: https://github.com/python/mypy/issues/9591
5953-
# TODO: would checking for is_proper_subtype also work and cover more cases?
5954-
else_type = full_context_else_type
5955-
else:
5956-
# Analyze the right branch in the context of the left
5957-
# branch's type.
5958-
else_type = self.analyze_cond_branch(
5959-
else_map,
5960-
e.else_expr,
5961-
context=if_type_fallback,
5962-
allow_none_return=allow_none_return,
5963-
)
5964-
5965-
# In most cases using if_type as a context for right branch gives better inferred types.
5966-
# This is however not the case for literal types, so use the full context instead.
5967-
if is_literal_type_like(full_context_else_type) and not is_literal_type_like(else_type):
5968-
else_type = full_context_else_type
5969-
59705932
res: Type = make_simplified_union([if_type, else_type])
59715933
if has_uninhabited_component(res) and not isinstance(
59725934
get_proper_type(self.type_context[-1]), UnionType

mypyc/irbuild/statement.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,9 @@ def make_entry(type: Expression) -> tuple[ValueGenFunc, int]:
598598
(make_entry(type) if type else None, var, make_handler(body))
599599
for type, var, body in zip(t.types, t.vars, t.handlers)
600600
]
601-
else_body = (lambda: builder.accept(t.else_body)) if t.else_body else None
601+
602+
_else_body = t.else_body
603+
else_body = (lambda: builder.accept(_else_body)) if _else_body else None
602604
transform_try_except(builder, body, handlers, else_body, t.line)
603605

604606

test-data/unit/check-literal.test

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2949,6 +2949,143 @@ reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word'
29492949
reveal_type(C().word) # N: Revealed type is "Literal['word']"
29502950
[builtins fixtures/tuple.pyi]
29512951

2952+
[case testStringLiteralTernary]
2953+
# https://github.com/python/mypy/issues/19534
2954+
def test(b: bool) -> None:
2955+
l = "foo" if b else "bar"
2956+
reveal_type(l) # N: Revealed type is "builtins.str"
2957+
[builtins fixtures/tuple.pyi]
2958+
2959+
[case testintLiteralTernary]
2960+
# https://github.com/python/mypy/issues/19534
2961+
def test(b: bool) -> None:
2962+
l = 0 if b else 1
2963+
reveal_type(l) # N: Revealed type is "builtins.int"
2964+
[builtins fixtures/tuple.pyi]
2965+
2966+
[case testStringIntUnionTernary]
2967+
# https://github.com/python/mypy/issues/19534
2968+
def test(b: bool) -> None:
2969+
l = 1 if b else "a"
2970+
reveal_type(l) # N: Revealed type is "Union[builtins.int, builtins.str]"
2971+
[builtins fixtures/tuple.pyi]
2972+
2973+
[case testListComprehensionTernary]
2974+
# https://github.com/python/mypy/issues/19534
2975+
def test(b: bool) -> None:
2976+
l = [1] if b else ["a"]
2977+
reveal_type(l) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]"
2978+
[builtins fixtures/list.pyi]
2979+
2980+
[case testSetComprehensionTernary]
2981+
# https://github.com/python/mypy/issues/19534
2982+
def test(b: bool) -> None:
2983+
s = {1} if b else {"a"}
2984+
reveal_type(s) # N: Revealed type is "Union[builtins.set[builtins.int], builtins.set[builtins.str]]"
2985+
[builtins fixtures/set.pyi]
2986+
2987+
[case testDictComprehensionTernary]
2988+
# https://github.com/python/mypy/issues/19534
2989+
def test(b: bool) -> None:
2990+
d = {1:1} if "" else {"a": "a"}
2991+
reveal_type(d) # N: Revealed type is "Union[builtins.dict[builtins.int, builtins.int], builtins.dict[builtins.str, builtins.str]]"
2992+
[builtins fixtures/dict.pyi]
2993+
2994+
[case testLambdaTernary]
2995+
from typing import TypeVar, Union, Callable, reveal_type
2996+
2997+
NOOP = lambda: None
2998+
class A: pass
2999+
class B:
3000+
attr: Union[A, None]
3001+
3002+
def test_static(x: Union[A, None]) -> None:
3003+
def foo(t: A) -> None: ...
3004+
3005+
l1: Callable[[], object] = (lambda: foo(x)) if x is not None else NOOP
3006+
r1: Callable[[], object] = NOOP if x is None else (lambda: foo(x))
3007+
l2 = (lambda: foo(x)) if x is not None else NOOP
3008+
r2 = NOOP if x is None else (lambda: foo(x))
3009+
reveal_type(l2) # N: Revealed type is "def ()"
3010+
reveal_type(r2) # N: Revealed type is "def ()"
3011+
3012+
def test_generic(x: Union[A, None]) -> None:
3013+
T = TypeVar("T")
3014+
def bar(t: T) -> T: return t
3015+
3016+
l1: Callable[[], None] = (lambda: bar(x)) if x is None else NOOP
3017+
r1: Callable[[], None] = NOOP if x is not None else (lambda: bar(x))
3018+
l2 = (lambda: bar(x)) if x is None else NOOP
3019+
r2 = NOOP if x is not None else (lambda: bar(x))
3020+
reveal_type(l2) # N: Revealed type is "def ()"
3021+
reveal_type(r2) # N: Revealed type is "def ()"
3022+
3023+
3024+
[case testLambdaTernaryIndirectAttribute]
3025+
# fails due to binder issue inside `check_func_def`
3026+
# https://github.com/python/mypy/issues/19561
3027+
from typing import TypeVar, Union, Callable, reveal_type
3028+
3029+
NOOP = lambda: None
3030+
class A: pass
3031+
class B:
3032+
attr: Union[A, None]
3033+
3034+
def test_static_with_attr(x: B) -> None:
3035+
def foo(t: A) -> None: ...
3036+
3037+
l1: Callable[[], None] = (lambda: foo(x.attr)) if x.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3038+
r1: Callable[[], None] = NOOP if x.attr is None else (lambda: foo(x.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3039+
l2 = (lambda: foo(x.attr)) if x.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3040+
r2 = NOOP if x.attr is None else (lambda: foo(x.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3041+
reveal_type(l2) # N: Revealed type is "def ()"
3042+
reveal_type(r2) # N: Revealed type is "def ()"
3043+
3044+
def test_generic_with_attr(x: B) -> None:
3045+
T = TypeVar("T")
3046+
def bar(t: T) -> T: return t
3047+
3048+
l1: Callable[[], None] = (lambda: bar(x.attr)) if x.attr is None else NOOP # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3049+
r1: Callable[[], None] = NOOP if x.attr is not None else (lambda: bar(x.attr)) # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3050+
l2 = (lambda: bar(x.attr)) if x.attr is None else NOOP
3051+
r2 = NOOP if x.attr is not None else (lambda: bar(x.attr))
3052+
reveal_type(l2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3053+
reveal_type(r2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3054+
3055+
[case testLambdaTernaryDoubleIndirectAttribute]
3056+
# fails due to binder issue inside `check_func_def`
3057+
# https://github.com/python/mypy/issues/19561
3058+
from typing import TypeVar, Union, Callable, reveal_type
3059+
3060+
NOOP = lambda: None
3061+
class A: pass
3062+
class B:
3063+
attr: Union[A, None]
3064+
class C:
3065+
attr: B
3066+
3067+
def test_static_with_attr(x: C) -> None:
3068+
def foo(t: A) -> None: ...
3069+
3070+
l1: Callable[[], None] = (lambda: foo(x.attr.attr)) if x.attr.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3071+
r1: Callable[[], None] = NOOP if x.attr.attr is None else (lambda: foo(x.attr.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3072+
l2 = (lambda: foo(x.attr.attr)) if x.attr.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3073+
r2 = NOOP if x.attr.attr is None else (lambda: foo(x.attr.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3074+
reveal_type(l2) # N: Revealed type is "def ()"
3075+
reveal_type(r2) # N: Revealed type is "def ()"
3076+
3077+
def test_generic_with_attr(x: C) -> None:
3078+
T = TypeVar("T")
3079+
def bar(t: T) -> T: return t
3080+
3081+
l1: Callable[[], None] = (lambda: bar(x.attr.attr)) if x.attr.attr is None else NOOP # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3082+
r1: Callable[[], None] = NOOP if x.attr.attr is not None else (lambda: bar(x.attr.attr)) # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3083+
l2 = (lambda: bar(x.attr.attr)) if x.attr.attr is None else NOOP
3084+
r2 = NOOP if x.attr.attr is not None else (lambda: bar(x.attr.attr))
3085+
reveal_type(l2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3086+
reveal_type(r2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3087+
3088+
29523089
[case testLiteralTernaryUnionNarrowing]
29533090
from typing import Literal, Optional
29543091

test-data/unit/check-optional.test

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,9 @@ reveal_type(l) # N: Revealed type is "builtins.list[typing.Generator[builtins.s
428428
[builtins fixtures/list.pyi]
429429

430430
[case testNoneListTernary]
431-
x = [None] if "" else [1] # E: List item 0 has incompatible type "int"; expected "None"
431+
# https://github.com/python/mypy/issues/19534
432+
x = [None] if "" else [1]
433+
reveal_type(x) # N: Revealed type is "Union[builtins.list[None], builtins.list[builtins.int]]"
432434
[builtins fixtures/list.pyi]
433435

434436
[case testListIncompatibleErrorMessage]

test-data/unit/check-python313.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,21 @@ class A[Y = X, X = int]: ... # E: Name "X" is not defined
319319

320320
class B[Y = X]: ... # E: Name "X" is not defined
321321
[builtins fixtures/tuple.pyi]
322+
323+
324+
[case testTernaryOperatorWithTypeVarDefault]
325+
# https://github.com/python/mypy/issues/18817
326+
327+
class Ok[T, E = None]:
328+
def __init__(self, value: T) -> None:
329+
self._value = value
330+
331+
class Err[E, T = None]:
332+
def __init__(self, value: E) -> None:
333+
self._value = value
334+
335+
type Result[T, E] = Ok[T, E] | Err[E, T]
336+
337+
class Bar[U]:
338+
def foo(data: U, cond: bool) -> Result[U, str]:
339+
return Ok(data) if cond else Err("Error")

0 commit comments

Comments
 (0)