Skip to content

Commit ec854c7

Browse files
authored
[ty] Fix subtyping with type[T] and unions (#21740)
## Summary Resolves #21685 (comment).
1 parent edc6ed5 commit ec854c7

File tree

2 files changed

+81
-54
lines changed

2 files changed

+81
-54
lines changed

crates/ty_python_semantic/resources/mdtest/type_of/generics.md

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,11 @@ class A:
123123
A class `A` is a subtype of `type[T]` if any instance of `A` is a subtype of `T`.
124124

125125
```py
126-
from typing import Callable, Protocol
126+
from typing import Any, Callable, Protocol
127127
from ty_extensions import is_assignable_to, is_subtype_of, is_disjoint_from, static_assert
128128

129-
class IntCallback(Protocol):
130-
def __call__(self, *args, **kwargs) -> int: ...
129+
class Callback[T](Protocol):
130+
def __call__(self, *args, **kwargs) -> T: ...
131131

132132
def _[T](_: T):
133133
static_assert(not is_subtype_of(type[T], T))
@@ -141,8 +141,11 @@ def _[T](_: T):
141141
static_assert(is_assignable_to(type[T], Callable[..., T]))
142142
static_assert(not is_disjoint_from(type[T], Callable[..., T]))
143143

144-
static_assert(not is_assignable_to(type[T], IntCallback))
145-
static_assert(not is_disjoint_from(type[T], IntCallback))
144+
static_assert(is_assignable_to(type[T], Callable[..., T] | Callable[..., Any]))
145+
static_assert(not is_disjoint_from(type[T], Callable[..., T] | Callable[..., Any]))
146+
147+
static_assert(not is_assignable_to(type[T], Callback[int]))
148+
static_assert(not is_disjoint_from(type[T], Callback[int]))
146149

147150
def _[T: int](_: T):
148151
static_assert(not is_subtype_of(type[T], T))
@@ -157,14 +160,23 @@ def _[T: int](_: T):
157160
static_assert(is_subtype_of(type[T], type[int]))
158161
static_assert(not is_disjoint_from(type[T], type[int]))
159162

163+
static_assert(is_subtype_of(type[T], type[int] | None))
164+
static_assert(not is_disjoint_from(type[T], type[int] | None))
165+
160166
static_assert(is_subtype_of(type[T], type[T]))
161167
static_assert(not is_disjoint_from(type[T], type[T]))
162168

163169
static_assert(is_assignable_to(type[T], Callable[..., T]))
164170
static_assert(not is_disjoint_from(type[T], Callable[..., T]))
165171

166-
static_assert(is_assignable_to(type[T], IntCallback))
167-
static_assert(not is_disjoint_from(type[T], IntCallback))
172+
static_assert(is_assignable_to(type[T], Callable[..., T] | Callable[..., Any]))
173+
static_assert(not is_disjoint_from(type[T], Callable[..., T] | Callable[..., Any]))
174+
175+
static_assert(is_assignable_to(type[T], Callback[int]))
176+
static_assert(not is_disjoint_from(type[T], Callback[int]))
177+
178+
static_assert(is_assignable_to(type[T], Callback[int] | Callback[Any]))
179+
static_assert(not is_disjoint_from(type[T], Callback[int] | Callback[Any]))
168180

169181
static_assert(is_subtype_of(type[T], type[T] | None))
170182
static_assert(not is_disjoint_from(type[T], type[T] | None))
@@ -183,8 +195,14 @@ def _[T: (int, str)](_: T):
183195
static_assert(is_assignable_to(type[T], Callable[..., T]))
184196
static_assert(not is_disjoint_from(type[T], Callable[..., T]))
185197

186-
static_assert(not is_assignable_to(type[T], IntCallback))
187-
static_assert(not is_disjoint_from(type[T], IntCallback))
198+
static_assert(is_assignable_to(type[T], Callable[..., T] | Callable[..., Any]))
199+
static_assert(not is_disjoint_from(type[T], Callable[..., T] | Callable[..., Any]))
200+
201+
static_assert(not is_assignable_to(type[T], Callback[int]))
202+
static_assert(not is_disjoint_from(type[T], Callback[int]))
203+
204+
static_assert(is_assignable_to(type[T], Callback[int | str]))
205+
static_assert(not is_disjoint_from(type[T], Callback[int] | Callback[str]))
188206

189207
static_assert(is_subtype_of(type[T], type[T] | None))
190208
static_assert(not is_disjoint_from(type[T], type[T] | None))

crates/ty_python_semantic/src/types.rs

Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2089,18 +2089,25 @@ impl<'db> Type<'db> {
20892089
// `type[T]` is a subtype of the class object `A` if every instance of `T` is a subtype of an instance
20902090
// of `A`, and vice versa.
20912091
(Type::SubclassOf(subclass_of), _)
2092-
if subclass_of.is_type_var()
2093-
&& !matches!(target, Type::Callable(_) | Type::ProtocolInstance(_)) =>
2092+
if !subclass_of
2093+
.into_type_var()
2094+
.zip(target.to_instance(db))
2095+
.when_some_and(|(this_instance, other_instance)| {
2096+
Type::TypeVar(this_instance).has_relation_to_impl(
2097+
db,
2098+
other_instance,
2099+
inferable,
2100+
relation,
2101+
relation_visitor,
2102+
disjointness_visitor,
2103+
)
2104+
})
2105+
.is_never_satisfied(db) =>
20942106
{
2107+
// TODO: The repetition here isn't great, but we really need the fallthrough logic,
2108+
// where this arm only engages if it returns true.
20952109
let this_instance = Type::TypeVar(subclass_of.into_type_var().unwrap());
2096-
let other_instance = match target {
2097-
Type::Union(union) => Some(
2098-
union.map(db, |element| element.to_instance(db).unwrap_or(Type::Never)),
2099-
),
2100-
_ => target.to_instance(db),
2101-
};
2102-
2103-
other_instance.when_some_and(|other_instance| {
2110+
target.to_instance(db).when_some_and(|other_instance| {
21042111
this_instance.has_relation_to_impl(
21052112
db,
21062113
other_instance,
@@ -2111,6 +2118,7 @@ impl<'db> Type<'db> {
21112118
)
21122119
})
21132120
}
2121+
21142122
(_, Type::SubclassOf(subclass_of)) if subclass_of.is_type_var() => {
21152123
let other_instance = Type::TypeVar(subclass_of.into_type_var().unwrap());
21162124
self.to_instance(db).when_some_and(|this_instance| {
@@ -2647,6 +2655,10 @@ impl<'db> Type<'db> {
26472655
disjointness_visitor,
26482656
),
26492657

2658+
(Type::SubclassOf(subclass_of), _) if subclass_of.is_type_var() => {
2659+
ConstraintSet::from(false)
2660+
}
2661+
26502662
// `Literal[<class 'C'>]` is a subtype of `type[B]` if `C` is a subclass of `B`,
26512663
// since `type[B]` describes all possible runtime subclasses of the class object `B`.
26522664
(Type::ClassLiteral(class), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty
@@ -3081,8 +3093,7 @@ impl<'db> Type<'db> {
30813093
ConstraintSet::from(false)
30823094
}
30833095

3084-
// `type[T]` is disjoint from a callable or protocol instance if its upper bound or
3085-
// constraints are.
3096+
// `type[T]` is disjoint from a callable or protocol instance if its upper bound or constraints are.
30863097
(Type::SubclassOf(subclass_of), Type::Callable(_) | Type::ProtocolInstance(_))
30873098
| (Type::Callable(_) | Type::ProtocolInstance(_), Type::SubclassOf(subclass_of))
30883099
if subclass_of.is_type_var() =>
@@ -3104,13 +3115,14 @@ impl<'db> Type<'db> {
31043115

31053116
// `type[T]` is disjoint from a class object `A` if every instance of `T` is disjoint from an instance of `A`.
31063117
(Type::SubclassOf(subclass_of), other) | (other, Type::SubclassOf(subclass_of))
3107-
if subclass_of.is_type_var() =>
3118+
if subclass_of.is_type_var()
3119+
&& (other.to_instance(db).is_some()
3120+
|| other.as_typevar().is_some_and(|type_var| {
3121+
type_var.typevar(db).bound_or_constraints(db).is_none()
3122+
})) =>
31083123
{
31093124
let this_instance = Type::TypeVar(subclass_of.into_type_var().unwrap());
31103125
let other_instance = match other {
3111-
Type::Union(union) => Some(
3112-
union.map(db, |element| element.to_instance(db).unwrap_or(Type::Never)),
3113-
),
31143126
// An unbounded typevar `U` may have instances of type `object` if specialized to
31153127
// an instance of `type`.
31163128
Type::TypeVar(typevar)
@@ -3464,6 +3476,12 @@ impl<'db> Type<'db> {
34643476
})
34653477
}
34663478

3479+
(Type::SubclassOf(subclass_of_ty), _) | (_, Type::SubclassOf(subclass_of_ty))
3480+
if subclass_of_ty.is_type_var() =>
3481+
{
3482+
ConstraintSet::from(true)
3483+
}
3484+
34673485
(Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b))
34683486
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
34693487
match subclass_of_ty.subclass_of() {
@@ -3493,31 +3511,27 @@ impl<'db> Type<'db> {
34933511
// for `type[Any]`/`type[Unknown]`/`type[Todo]`, we know the type cannot be any larger than `type`,
34943512
// so although the type is dynamic we can still determine disjointedness in some situations
34953513
(Type::SubclassOf(subclass_of_ty), other)
3496-
| (other, Type::SubclassOf(subclass_of_ty))
3497-
if !subclass_of_ty.is_type_var() =>
3498-
{
3499-
match subclass_of_ty.subclass_of() {
3500-
SubclassOfInner::Dynamic(_) => {
3501-
KnownClass::Type.to_instance(db).is_disjoint_from_impl(
3502-
db,
3503-
other,
3504-
inferable,
3505-
disjointness_visitor,
3506-
relation_visitor,
3507-
)
3508-
}
3509-
SubclassOfInner::Class(class) => {
3510-
class.metaclass_instance_type(db).is_disjoint_from_impl(
3511-
db,
3512-
other,
3513-
inferable,
3514-
disjointness_visitor,
3515-
relation_visitor,
3516-
)
3517-
}
3518-
SubclassOfInner::TypeVar(_) => unreachable!(),
3514+
| (other, Type::SubclassOf(subclass_of_ty)) => match subclass_of_ty.subclass_of() {
3515+
SubclassOfInner::Dynamic(_) => {
3516+
KnownClass::Type.to_instance(db).is_disjoint_from_impl(
3517+
db,
3518+
other,
3519+
inferable,
3520+
disjointness_visitor,
3521+
relation_visitor,
3522+
)
35193523
}
3520-
}
3524+
SubclassOfInner::Class(class) => {
3525+
class.metaclass_instance_type(db).is_disjoint_from_impl(
3526+
db,
3527+
other,
3528+
inferable,
3529+
disjointness_visitor,
3530+
relation_visitor,
3531+
)
3532+
}
3533+
SubclassOfInner::TypeVar(_) => unreachable!(),
3534+
},
35213535

35223536
(Type::SpecialForm(special_form), Type::NominalInstance(instance))
35233537
| (Type::NominalInstance(instance), Type::SpecialForm(special_form)) => {
@@ -3779,11 +3793,6 @@ impl<'db> Type<'db> {
37793793
relation_visitor,
37803794
)
37813795
}
3782-
3783-
(Type::SubclassOf(_), _) | (_, Type::SubclassOf(_)) => {
3784-
// All cases should have been handled above.
3785-
unreachable!()
3786-
}
37873796
}
37883797
}
37893798

0 commit comments

Comments
 (0)