Skip to content

Commit 9e0da9e

Browse files
committed
Generalize with variance
1 parent f95e2c2 commit 9e0da9e

File tree

3 files changed

+191
-56
lines changed

3 files changed

+191
-56
lines changed

chalk-solve/src/infer.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,22 @@ impl<I: Interner> InferenceTable<I> {
146146
.map(|p| p.assert_const_ref(interner).clone())
147147
}
148148

149+
pub fn ty_root(&mut self, interner: &I, leaf: &Ty<I>) -> Option<Ty<I>> {
150+
Some(
151+
self.unify
152+
.find(leaf.inference_var(interner)?)
153+
.to_ty(interner),
154+
)
155+
}
156+
157+
pub fn lifetime_root(&mut self, interner: &I, leaf: &Lifetime<I>) -> Option<Lifetime<I>> {
158+
Some(
159+
self.unify
160+
.find(leaf.inference_var(interner)?)
161+
.to_lifetime(interner),
162+
)
163+
}
164+
149165
/// Finds the root inference var for the given variable.
150166
///
151167
/// The returned variable will be exactly equivalent to the given

chalk-solve/src/infer/unify.rs

Lines changed: 151 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,24 @@ impl<'t, I: Interner> Unifier<'t, I> {
7474
T: ?Sized + Zip<I>,
7575
{
7676
Zip::zip_with(&mut self, variance, a, b)?;
77-
Ok(RelationResult { goals: self.goals })
77+
let interner = self.interner();
78+
let mut goals = self.goals;
79+
let table = self.table;
80+
// Sometimes we'll produce a lifetime outlives goal which we later solve by unification
81+
// Technically, these *will* get canonicalized to the same bound var and so that will end up
82+
// as a goal like `^0.0 <: ^0.0`, which is trivially true. But, we remove those *here*, which
83+
// might help caching.
84+
goals.retain(|g| match g.goal.data(interner) {
85+
GoalData::SubtypeGoal(SubtypeGoal { a, b }) => {
86+
let n_a = table.ty_root(interner, a);
87+
let n_b = table.ty_root(interner, b);
88+
let a = n_a.as_ref().unwrap_or(a);
89+
let b = n_b.as_ref().unwrap_or(b);
90+
a != b
91+
}
92+
_ => true,
93+
});
94+
Ok(RelationResult { goals })
7895
}
7996

8097
/// Relate `a`, `b` with the variance such that if `variance = Covariant`, `a` is
@@ -473,85 +490,128 @@ impl<'t, I: Interner> Unifier<'t, I> {
473490
}
474491

475492
#[instrument(level = "debug", skip(self))]
476-
fn generalize_ty(&mut self, ty: &Ty<I>, universe_index: UniverseIndex) -> Ty<I> {
493+
fn generalize_ty(
494+
&mut self,
495+
ty: &Ty<I>,
496+
universe_index: UniverseIndex,
497+
variance: Variance,
498+
) -> Ty<I> {
477499
let interner = self.interner;
478500
match ty.kind(interner) {
479-
TyKind::Adt(id, substitution) => TyKind::Adt(
480-
*id,
481-
self.generalize_substitution(substitution, universe_index),
482-
)
483-
.intern(interner),
501+
TyKind::Adt(id, substitution) => {
502+
let variances = if matches!(variance, Variance::Invariant) {
503+
None
504+
} else {
505+
Some(self.unification_database().adt_variance(*id))
506+
};
507+
let get_variance = |i| {
508+
variances
509+
.as_ref()
510+
.map(|v| v.as_slice(interner)[i])
511+
.unwrap_or(Variance::Invariant)
512+
};
513+
TyKind::Adt(
514+
*id,
515+
self.generalize_substitution(substitution, universe_index, get_variance),
516+
)
517+
.intern(interner)
518+
}
484519
TyKind::AssociatedType(id, substitution) => TyKind::AssociatedType(
485520
*id,
486-
self.generalize_substitution(substitution, universe_index),
521+
self.generalize_substitution(substitution, universe_index, |_| variance),
487522
)
488523
.intern(interner),
489524
TyKind::Scalar(scalar) => TyKind::Scalar(*scalar).intern(interner),
490525
TyKind::Str => TyKind::Str.intern(interner),
491526
TyKind::Tuple(arity, substitution) => TyKind::Tuple(
492527
*arity,
493-
self.generalize_substitution(substitution, universe_index),
528+
self.generalize_substitution(substitution, universe_index, |_| variance),
494529
)
495530
.intern(interner),
496531
TyKind::OpaqueType(id, substitution) => TyKind::OpaqueType(
497532
*id,
498-
self.generalize_substitution(substitution, universe_index),
533+
self.generalize_substitution(substitution, universe_index, |_| variance),
499534
)
500535
.intern(interner),
501536
TyKind::Slice(ty) => {
502-
TyKind::Slice(self.generalize_ty(ty, universe_index)).intern(interner)
537+
TyKind::Slice(self.generalize_ty(ty, universe_index, variance)).intern(interner)
538+
}
539+
TyKind::FnDef(id, substitution) => {
540+
let variances = if matches!(variance, Variance::Invariant) {
541+
None
542+
} else {
543+
Some(self.unification_database().fn_def_variance(*id))
544+
};
545+
let get_variance = |i| {
546+
variances
547+
.as_ref()
548+
.map(|v| v.as_slice(interner)[i])
549+
.unwrap_or(Variance::Invariant)
550+
};
551+
TyKind::FnDef(
552+
*id,
553+
self.generalize_substitution(substitution, universe_index, get_variance),
554+
)
555+
.intern(interner)
556+
}
557+
TyKind::Ref(mutability, lifetime, ty) => {
558+
let lifetime_variance = variance.xform(Variance::Contravariant);
559+
let ty_variance = match mutability {
560+
Mutability::Not => Variance::Covariant,
561+
Mutability::Mut => Variance::Invariant,
562+
};
563+
TyKind::Ref(
564+
*mutability,
565+
self.generalize_lifetime(lifetime, universe_index, lifetime_variance),
566+
self.generalize_ty(ty, universe_index, ty_variance),
567+
)
568+
.intern(interner)
503569
}
504-
TyKind::FnDef(id, substitution) => TyKind::FnDef(
505-
*id,
506-
self.generalize_substitution(substitution, universe_index),
507-
)
508-
.intern(interner),
509-
TyKind::Ref(mutability, lifetime, ty) => TyKind::Ref(
510-
*mutability,
511-
self.generalize_lifetime(lifetime, universe_index),
512-
self.generalize_ty(ty, universe_index),
513-
)
514-
.intern(interner),
515570
TyKind::Raw(mutability, ty) => {
516-
TyKind::Raw(*mutability, self.generalize_ty(ty, universe_index)).intern(interner)
571+
let ty_variance = match mutability {
572+
Mutability::Not => Variance::Covariant,
573+
Mutability::Mut => Variance::Invariant,
574+
};
575+
TyKind::Raw(
576+
*mutability,
577+
self.generalize_ty(ty, universe_index, ty_variance),
578+
)
579+
.intern(interner)
517580
}
518581
TyKind::Never => TyKind::Never.intern(interner),
519582
TyKind::Array(ty, const_) => TyKind::Array(
520-
self.generalize_ty(ty, universe_index),
583+
self.generalize_ty(ty, universe_index, variance),
521584
self.generalize_const(const_, universe_index),
522585
)
523586
.intern(interner),
524587
TyKind::Closure(id, substitution) => TyKind::Closure(
525588
*id,
526-
self.generalize_substitution(substitution, universe_index),
589+
self.generalize_substitution(substitution, universe_index, |_| variance),
527590
)
528591
.intern(interner),
529592
TyKind::Generator(id, substitution) => TyKind::Generator(
530593
*id,
531-
self.generalize_substitution(substitution, universe_index),
594+
self.generalize_substitution(substitution, universe_index, |_| variance),
532595
)
533596
.intern(interner),
534597
TyKind::GeneratorWitness(id, substitution) => TyKind::GeneratorWitness(
535598
*id,
536-
self.generalize_substitution(substitution, universe_index),
599+
self.generalize_substitution(substitution, universe_index, |_| variance),
537600
)
538601
.intern(interner),
539602
TyKind::Foreign(id) => TyKind::Foreign(*id).intern(interner),
540603
TyKind::Error => TyKind::Error.intern(interner),
541604
TyKind::Dyn(dyn_ty) => {
542-
let DynTy {
543-
bounds,
544-
lifetime: _,
545-
} = dyn_ty;
546-
let lifetime_var = self.table.new_variable(universe_index);
547-
let lifetime = lifetime_var.to_lifetime(interner);
605+
let DynTy { bounds, lifetime } = dyn_ty;
606+
let lifetime = self.generalize_lifetime(
607+
lifetime,
608+
universe_index,
609+
variance.xform(Variance::Contravariant),
610+
);
548611

549612
let bounds = bounds.map_ref(|value| {
550-
//let universe_index = universe_index.next();
551613
let iter = value.iter(interner).map(|sub_var| {
552614
sub_var.map_ref(|clause| {
553-
//let universe_index = universe_index.next();
554-
// let universe_index = self.table.new_universe();
555615
match clause {
556616
WhereClause::Implemented(trait_ref) => {
557617
let TraitRef {
@@ -561,6 +621,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
561621
let substitution = self.generalize_substitution_skip_self(
562622
substitution,
563623
universe_index,
624+
|_| Some(variance),
564625
);
565626
WhereClause::Implemented(TraitRef {
566627
substitution,
@@ -578,6 +639,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
578639
let substitution = self.generalize_substitution(
579640
substitution,
580641
universe_index,
642+
|_| variance,
581643
);
582644
AliasTy::Opaque(OpaqueTy {
583645
substitution,
@@ -598,6 +660,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
598660
let substitution = self.generalize_substitution(
599661
substitution,
600662
universe_index,
663+
|_| variance,
601664
);
602665
AliasTy::Projection(ProjectionTy {
603666
substitution,
@@ -637,8 +700,25 @@ impl<'t, I: Interner> Unifier<'t, I> {
637700
ref substitution,
638701
} = *fn_ptr;
639702

640-
let substitution =
641-
FnSubst(self.generalize_substitution(&substitution.0, universe_index));
703+
let len = substitution.0.len(interner);
704+
let vars = substitution.0.iter(interner).enumerate().map(|(i, var)| {
705+
if i < len - 1 {
706+
self.generalize_generic_var(
707+
var,
708+
universe_index,
709+
variance.xform(Variance::Contravariant),
710+
)
711+
} else {
712+
self.generalize_generic_var(
713+
substitution.0.as_slice(interner).last().unwrap(),
714+
universe_index,
715+
variance,
716+
)
717+
}
718+
});
719+
720+
let substitution = FnSubst(Substitution::from_iter(interner, vars));
721+
642722
TyKind::Function(FnPointer {
643723
num_binders,
644724
sig,
@@ -660,7 +740,9 @@ impl<'t, I: Interner> Unifier<'t, I> {
660740
if matches!(kind, TyVariableKind::Integer | TyVariableKind::Float) {
661741
ty.clone()
662742
} else if let Some(ty) = self.table.normalize_ty_shallow(interner, ty) {
663-
self.generalize_ty(&ty, universe_index)
743+
self.generalize_ty(&ty, universe_index, variance)
744+
} else if matches!(variance, Variance::Invariant) {
745+
ty.clone()
664746
} else {
665747
let ena_var = self.table.new_variable(universe_index);
666748
ena_var.to_ty(interner)
@@ -674,6 +756,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
674756
&mut self,
675757
lifetime: &Lifetime<I>,
676758
universe_index: UniverseIndex,
759+
variance: Variance,
677760
) -> Lifetime<I> {
678761
let interner = self.interner;
679762
match lifetime.data(&interner) {
@@ -706,13 +789,16 @@ impl<'t, I: Interner> Unifier<'t, I> {
706789
&mut self,
707790
sub_var: &GenericArg<I>,
708791
universe_index: UniverseIndex,
792+
variance: Variance,
709793
) -> GenericArg<I> {
710794
let interner = self.interner;
711795
(match sub_var.data(interner) {
712-
GenericArgData::Ty(ty) => GenericArgData::Ty(self.generalize_ty(ty, universe_index)),
713-
GenericArgData::Lifetime(lifetime) => {
714-
GenericArgData::Lifetime(self.generalize_lifetime(lifetime, universe_index))
796+
GenericArgData::Ty(ty) => {
797+
GenericArgData::Ty(self.generalize_ty(ty, universe_index, variance))
715798
}
799+
GenericArgData::Lifetime(lifetime) => GenericArgData::Lifetime(
800+
self.generalize_lifetime(lifetime, universe_index, variance),
801+
),
716802
GenericArgData::Const(const_value) => {
717803
GenericArgData::Const(self.generalize_const(const_value, universe_index))
718804
}
@@ -721,32 +807,37 @@ impl<'t, I: Interner> Unifier<'t, I> {
721807
}
722808

723809
/// Generalizes all but the first
724-
#[instrument(level = "debug", skip(self))]
725-
fn generalize_substitution_skip_self(
810+
#[instrument(level = "debug", skip(self, get_variance))]
811+
fn generalize_substitution_skip_self<F: Fn(usize) -> Option<Variance>>(
726812
&mut self,
727813
substitution: &Substitution<I>,
728814
universe_index: UniverseIndex,
815+
get_variance: F,
729816
) -> Substitution<I> {
730817
let interner = self.interner;
731-
let vars = substitution.iter(interner).take(1).cloned().chain(
732-
substitution
733-
.iter(interner)
734-
.skip(1)
735-
.map(|sub_var| self.generalize_generic_var(sub_var, universe_index)),
736-
);
818+
let vars = substitution.iter(interner).enumerate().map(|(i, sub_var)| {
819+
if i == 0 {
820+
sub_var.clone()
821+
} else {
822+
let variance = get_variance(i).unwrap_or(Variance::Invariant);
823+
self.generalize_generic_var(sub_var, universe_index, variance)
824+
}
825+
});
737826
Substitution::from_iter(interner, vars)
738827
}
739828

740-
#[instrument(level = "debug", skip(self))]
741-
fn generalize_substitution(
829+
#[instrument(level = "debug", skip(self, get_variance))]
830+
fn generalize_substitution<F: Fn(usize) -> Variance>(
742831
&mut self,
743832
substitution: &Substitution<I>,
744833
universe_index: UniverseIndex,
834+
get_variance: F,
745835
) -> Substitution<I> {
746836
let interner = self.interner;
747-
let vars = substitution
748-
.iter(interner)
749-
.map(|sub_var| self.generalize_generic_var(sub_var, universe_index));
837+
let vars = substitution.iter(interner).enumerate().map(|(i, sub_var)| {
838+
let variance = get_variance(i);
839+
self.generalize_generic_var(sub_var, universe_index, variance)
840+
});
750841

751842
Substitution::from_iter(interner, vars)
752843
}
@@ -822,7 +913,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
822913
// this, we create two new vars `'0` and `1`. Then we relate `var` with
823914
// `&'0 1` and `&'0 1` with `&'x SomeType`. The second relation will
824915
// recurse, and we'll end up relating `'0` with `'x` and `1` with `SomeType`.
825-
let generalized_val = self.generalize_ty(&ty1, universe_index);
916+
let generalized_val = self.generalize_ty(&ty1, universe_index, variance);
826917

827918
debug!("var {:?} generalized to {:?}", var, generalized_val);
828919

@@ -1259,6 +1350,10 @@ where
12591350
// become the value of).
12601351
InferenceValue::Unbound(ui) => {
12611352
if self.unifier.table.unify.unioned(var, self.var) {
1353+
debug!(
1354+
"OccursCheck aborting because {:?} unioned with {:?}",
1355+
var, self.var,
1356+
);
12621357
return Err(NoSolution);
12631358
}
12641359

0 commit comments

Comments
 (0)