Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions conformance/third_party/conformance.exp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
{
"code": -2,
"column": 9,
"concise_description": "`TypeAlias[GoodTypeAlias3, type[list[int | None]]]` is not subscriptable",
"description": "`TypeAlias[GoodTypeAlias3, type[list[int | None]]]` is not subscriptable",
"concise_description": "`TypeAlias[GoodTypeAlias3, type[list[GoodTypeAlias2]]]` is not subscriptable",
"description": "`TypeAlias[GoodTypeAlias3, type[list[GoodTypeAlias2]]]` is not subscriptable",
"line": 68,
"name": "unsupported-operation",
"severity": "error",
Expand Down Expand Up @@ -270,8 +270,8 @@
{
"code": -2,
"column": 9,
"concise_description": "`type[list[int | None]]` is not subscriptable",
"description": "`type[list[int | None]]` is not subscriptable",
"concise_description": "`type[list[GoodTypeAlias2]]` is not subscriptable",
"description": "`type[list[GoodTypeAlias2]]` is not subscriptable",
"line": 77,
"name": "unsupported-operation",
"severity": "error",
Expand Down
33 changes: 23 additions & 10 deletions crates/pyrefly_types/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use crate::types::SuperObj;
use crate::types::TArgs;
use crate::types::TParam;
use crate::types::Type;
use crate::types::Union;

/// Information about the qnames we have seen.
/// Set to None to indicate we have seen different values, or Some if they are all the same.
Expand Down Expand Up @@ -442,17 +443,21 @@ impl<'a> TypeDisplayContext<'a> {
self.maybe_fmt_with_module("typing", "NoReturn", output)
}
Type::Never(NeverStyle::Never) => self.maybe_fmt_with_module("typing", "Never", output),
Type::Union(types) if types.is_empty() => {
Type::Union(box Union { members: types, .. }) if types.is_empty() => {
self.maybe_fmt_with_module("typing", "Never", output)
}
Type::Union(types) => {
Type::Union(box Union {
display_name: Some(name),
..
}) if !is_toplevel => output.write_str(name),
Type::Union(box Union { members, .. }) => {
let mut literal_idx = None;
let mut literals = Vec::new();
let mut union_members: Vec<&Type> = Vec::new();
// Track seen types to deduplicate (mainly to prettify types for functions with different names but the same signature)
let mut seen_types = SmallSet::new();

for t in types.iter() {
for t in members.iter() {
match t {
Type::Literal(lit) => {
if literal_idx.is_none() {
Expand Down Expand Up @@ -1099,7 +1104,7 @@ pub mod tests {
assert_eq!(ctx.display(&int_type).to_string(), "int");
}

let union_foo_int = Type::Union(vec![foo_type, int_type]);
let union_foo_int = Type::union(vec![foo_type, int_type]);

{
let mut ctx = TypeDisplayContext::new(&[&union_foo_int]);
Expand All @@ -1115,11 +1120,11 @@ pub mod tests {
let t3 = fake_tyvar("qux", "bar", 2);

assert_eq!(
Type::Union(vec![t1.to_type(), t2.to_type()]).to_string(),
Type::union(vec![t1.to_type(), t2.to_type()]).to_string(),
"TypeVar[bar.foo@1:2] | TypeVar[bar.foo@1:3]"
);
assert_eq!(
Type::Union(vec![t1.to_type(), t3.to_type()]).to_string(),
Type::union(vec![t1.to_type(), t3.to_type()]).to_string(),
"TypeVar[foo] | TypeVar[qux]"
);
}
Expand Down Expand Up @@ -1159,13 +1164,21 @@ pub mod tests {
let nonlit2 = Type::LiteralString;

assert_eq!(
Type::Union(vec![nonlit1.clone(), nonlit2.clone()]).to_string(),
Type::union(vec![nonlit1.clone(), nonlit2.clone()]).to_string(),
"None | LiteralString"
);
assert_eq!(
Type::Union(vec![nonlit1, lit1, nonlit2, lit2]).to_string(),
Type::union(vec![nonlit1.clone(), lit1, nonlit2.clone(), lit2]).to_string(),
"None | Literal[True, 'test'] | LiteralString"
);
assert_eq!(
Type::type_form(Type::Union(Box::new(Union {
members: vec![nonlit1, nonlit2],
display_name: Some("MyUnion".to_owned())
})))
.to_string(),
"type[MyUnion]"
);
}

#[test]
Expand Down Expand Up @@ -1575,7 +1588,7 @@ def overloaded_func[T](

#[test]
fn test_union_of_intersection() {
let x = Type::Union(vec![
let x = Type::union(vec![
Type::Intersect(Box::new((
vec![Type::any_explicit(), Type::LiteralString],
Type::any_implicit(),
Expand Down Expand Up @@ -1679,7 +1692,7 @@ def overloaded_func[T](
let foo2 = fake_class("Foo", "mod.ule", 8);
let t1 = Type::ClassType(ClassType::new(foo1, TArgs::default()));
let t2 = Type::ClassType(ClassType::new(foo2, TArgs::default()));
let union = Type::Union(vec![t1.clone(), t2.clone()]);
let union = Type::union(vec![t1.clone(), t2.clone()]);
let ctx = TypeDisplayContext::new(&[&union]);

let parts1 = ctx.get_types_with_location(&t1, false).parts().to_vec();
Expand Down
10 changes: 8 additions & 2 deletions crates/pyrefly_types/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ use crate::literal::Lit;
use crate::stdlib::Stdlib;
use crate::tuple::Tuple;
use crate::types::Type;
use crate::types::Union;

/// Turn unions of unions into a flattened list for one union, and return the deduped list.
fn flatten_and_dedup(xs: Vec<Type>) -> Vec<Type> {
fn flatten(xs: Vec<Type>, res: &mut Vec<Type>) {
for x in xs {
match x {
Type::Union(xs) => flatten(xs, res),
Type::Union(box Union { members, .. }) => flatten(members, res),
Type::Never(_) => {}
_ => res.push(x),
}
Expand Down Expand Up @@ -76,7 +77,12 @@ fn unions_internal(
}
collapse_tuple_unions_with_empty(&mut res);
// `res` is collapsible again if `flatten_and_dedup` drops `xs` to 0 or 1 elements
try_collapse(res).unwrap_or_else(Type::Union)
try_collapse(res).unwrap_or_else(|members| {
Type::Union(Box::new(Union {
members,
display_name: None,
}))
})
})
}

Expand Down
2 changes: 1 addition & 1 deletion crates/pyrefly_types/src/type_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ mod tests {

let int_type = Type::ClassType(ClassType::new(int_class, TArgs::default()));
let str_type = Type::ClassType(ClassType::new(str_class, TArgs::default()));
let union_type = Type::Union(vec![int_type, str_type, Type::None]);
let union_type = Type::union(vec![int_type, str_type, Type::None]);

let ctx = TypeDisplayContext::new(&[&union_type]);
let mut output = OutputWithLocations::new(&ctx);
Expand Down
81 changes: 67 additions & 14 deletions crates/pyrefly_types/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::borrow::Cow;
use std::cmp::Ordering;
use std::fmt;
use std::fmt::Display;
use std::hash::Hash;
use std::hash::Hasher;
use std::sync::Arc;

use dupe::Dupe;
Expand Down Expand Up @@ -611,6 +613,40 @@ pub enum SuperObj {
Class(ClassType),
}

#[derive(Debug, Clone, Eq, TypeEq, PartialOrd, Ord)]
pub struct Union {
pub members: Vec<Type>,
pub display_name: Option<String>,
}

impl PartialEq for Union {
fn eq(&self, other: &Self) -> bool {
self.members == other.members
}
}

impl Hash for Union {
fn hash<H: Hasher>(&self, state: &mut H) {
self.members.hash(state)
}
}

impl Visit<Type> for Union {
fn recurse<'a>(&'a self, f: &mut dyn FnMut(&'a Type)) {
for member in &self.members {
member.visit(f);
}
}
}

impl VisitMut<Type> for Union {
fn recurse_mut(&mut self, f: &mut dyn FnMut(&mut Type)) {
for member in &mut self.members {
member.visit_mut(f);
}
}
}

// Note: The fact that Literal and LiteralString are at the front is important for
// optimisations in `unions_with_literals`.
#[derive(Debug, Clone, PartialEq, Eq, TypeEq, PartialOrd, Ord, Hash)]
Expand All @@ -626,7 +662,8 @@ pub enum Type {
BoundMethod(Box<BoundMethod>),
/// An overloaded function.
Overload(Overload),
Union(Vec<Type>),
/// Unions will hold an optional name to use when displaying the type
Union(Box<Union>),
/// Our intersection support is partial, so we store a fallback type that we use for operations
/// that are not yet supported on intersections.
Intersect(Box<(Vec<Type>, Type)>),
Expand Down Expand Up @@ -1392,7 +1429,7 @@ impl Type {
/// type[a | b] -> type[a] | type[b]
pub fn distribute_type_over_union(self) -> Self {
self.transform(&mut |ty| {
if let Type::Type(box Type::Union(members)) = ty {
if let Type::Type(box Type::Union(box Union { members, .. })) = ty {
*ty = unions(members.drain(..).map(Type::type_form).collect());
}
})
Expand Down Expand Up @@ -1435,10 +1472,15 @@ impl Type {
})
}

pub fn sort_unions(self) -> Self {
pub fn sort_unions_and_drop_names(self) -> Self {
self.transform(&mut |ty| {
if let Type::Union(ts) = ty {
if let Type::Union(box Union {
members: ts,
display_name,
}) = ty
{
ts.sort();
*display_name = None;
}
})
}
Expand Down Expand Up @@ -1488,27 +1530,30 @@ impl Type {

pub fn into_unions(self) -> Vec<Type> {
match self {
Type::Union(types) => types,
Type::Union(box Union { members: types, .. }) => types,
_ => vec![self],
}
}

/// Create an optional type (union with None).
pub fn optional(x: Self) -> Self {
// We would like the resulting type not nested, and well sorted.
if let Type::Union(mut xs) = x {
if let Type::Union(box Union {
members: mut xs, ..
}) = x
{
match xs.binary_search(&Type::None) {
Ok(_) => Type::Union(xs),
Ok(_) => Type::union(xs),
Err(i) => {
xs.insert(i, Type::None);
Type::Union(xs)
Type::union(xs)
}
}
} else {
match x.cmp(&Type::None) {
Ordering::Equal => Type::None,
Ordering::Less => Type::Union(vec![x, Type::None]),
Ordering::Greater => Type::Union(vec![Type::None, x]),
Ordering::Less => Type::union(vec![x, Type::None]),
Ordering::Greater => Type::union(vec![Type::None, x]),
}
}
}
Expand Down Expand Up @@ -1538,9 +1583,9 @@ impl Type {
Type::Literal(Lit::Str(x)) => Some(!x.is_empty()),
Type::None => Some(false),
Type::Tuple(Tuple::Concrete(elements)) => Some(!elements.is_empty()),
Type::Union(options) => {
Type::Union(box Union { members, .. }) => {
let mut answer = None;
for option in options {
for option in members {
let option_bool = option.as_bool();
option_bool?;
if answer.is_none() {
Expand Down Expand Up @@ -1590,6 +1635,14 @@ impl Type {
})
})
}

/// Creates a union from the provided types without simplifying
pub fn union(members: Vec<Type>) -> Self {
Type::Union(Box::new(Union {
members,
display_name: None,
}))
}
}

#[cfg(test)]
Expand All @@ -1616,8 +1669,8 @@ mod tests {
let false_lit = Type::Literal(Lit::Bool(false));
let none = Type::None;

let str_opt = Type::Union(vec![s, none.clone()]);
let false_opt = Type::Union(vec![false_lit, none]);
let str_opt = Type::union(vec![s, none.clone()]);
let false_opt = Type::union(vec![false_lit, none]);

assert_eq!(str_opt.as_bool(), None);
assert_eq!(false_opt.as_bool(), Some(false));
Expand Down
9 changes: 5 additions & 4 deletions pyrefly/lib/alt/answers_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use dupe::IterDupedExt;
use itertools::Either;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::module_path::ModulePath;
use pyrefly_types::types::Union;
use pyrefly_util::display::DisplayWithCtx;
use pyrefly_util::display::commas_iter;
use pyrefly_util::recurser::Guard;
Expand Down Expand Up @@ -797,12 +798,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
fn go(&mut self, ty: &Type, in_type: bool) {
match ty {
Type::Never(_) if !in_type => (),
Type::Union(tys) => {
Type::Union(box Union { members, .. }) => {
self.seen_union = true;
tys.iter().for_each(|ty| self.go(ty, in_type))
members.iter().for_each(|ty| self.go(ty, in_type))
}
Type::Type(box Type::Union(tys)) if !in_type => {
tys.iter().for_each(|ty| self.go(ty, true))
Type::Type(box Type::Union(box Union { members, .. })) if !in_type => {
members.iter().for_each(|ty| self.go(ty, true))
}
Type::Var(v) if let Some(_guard) = self.me.recurse(*v) => {
self.go(&self.me.solver().force_var(*v), in_type)
Expand Down
5 changes: 3 additions & 2 deletions pyrefly/lib/alt/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pyrefly_types::special_form::SpecialForm;
use pyrefly_types::types::Forall;
use pyrefly_types::types::Forallable;
use pyrefly_types::types::TArgs;
use pyrefly_types::types::Union;
use pyrefly_types::types::Var;
use ruff_python_ast::helpers::is_dunder;
use ruff_python_ast::name::Name;
Expand Down Expand Up @@ -1795,12 +1796,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Type::SuperInstance(box (cls, obj)) => {
acc.push(AttributeBase1::SuperInstance(cls, obj))
}
Type::Union(members) => {
Type::Union(box Union { members, .. }) => {
for ty in members {
self.as_attribute_base1(ty, acc)
}
}
Type::Type(box Type::Union(members)) => {
Type::Type(box Type::Union(box Union { members, .. })) => {
for ty in members {
self.as_attribute_base1(Type::type_form(ty), acc)
}
Expand Down
5 changes: 3 additions & 2 deletions pyrefly/lib/alt/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use pyrefly_types::quantified::Quantified;
use pyrefly_types::types::CalleeKind;
use pyrefly_types::types::TArgs;
use pyrefly_types::types::TParams;
use pyrefly_types::types::Union;
use pyrefly_util::prelude::SliceExt;
use pyrefly_util::prelude::VecExt;
use ruff_python_ast::Arguments;
Expand Down Expand Up @@ -260,7 +261,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Type::Var(v) if let Some(_guard) = self.recurse(v) => {
self.as_call_target_impl(self.solver().force_var(v), quantified, dunder_call)
}
Type::Union(xs) => {
Type::Union(box Union { members: xs, .. }) => {
let xs_length = xs.len();
let targets = xs
.into_iter()
Expand Down Expand Up @@ -315,7 +316,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Type::Quantified(q) if q.is_type_var() => match q.restriction() {
Restriction::Unrestricted => CallTargetLookup::Error(vec![]),
Restriction::Bound(bound) => match bound {
Type::Union(members) => {
Type::Union(box Union { members, .. }) => {
let mut targets = Vec::new();
for member in members {
if let CallTargetLookup::Ok(target) = self.as_call_target_impl(
Expand Down
Loading