diff --git a/plugins/tactics/hls-tactics-plugin.cabal b/plugins/tactics/hls-tactics-plugin.cabal index 7fecc860c4..b95140fe49 100644 --- a/plugins/tactics/hls-tactics-plugin.cabal +++ b/plugins/tactics/hls-tactics-plugin.cabal @@ -84,6 +84,7 @@ test-suite tests main-is: Main.hs other-modules: AutoTupleSpec + UnificationSpec hs-source-dirs: test ghc-options: -Wall -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N diff --git a/plugins/tactics/src/Ide/Plugin/Tactic.hs b/plugins/tactics/src/Ide/Plugin/Tactic.hs index 5750835ef7..25874bf242 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic.hs @@ -265,9 +265,11 @@ judgementForHole state nfp range = do $ getDefiningBindings binds rss) tcg hyps = hypothesisFromBindings rss binds + ambient = M.fromList $ contextMethodHypothesis ctx pure ( resulting_range , mkFirstJudgement hyps + ambient (isRhsHole rss tcs) (maybe mempty diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs index 89947e1443..28a3bf8274 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs @@ -29,7 +29,8 @@ import Type hiding (Var) useOccName :: MonadState TacticState m => Judgement -> OccName -> m () useOccName jdg name = - case M.lookup name $ jHypothesis jdg of + -- Only score points if this is in the local hypothesis + case M.lookup name $ jLocalHypothesis jdg of Just{} -> modify $ withUsedVals $ S.insert name Nothing -> pure () diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs index 2c8b48227a..a9a56d290d 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs @@ -10,16 +10,50 @@ import Development.IDE.GHC.Compat import Ide.Plugin.Tactic.Types import OccName import TcRnTypes +import Ide.Plugin.Tactic.GHC (tacticsThetaTy) +import Ide.Plugin.Tactic.Machinery (methodHypothesis) +import Data.Maybe (mapMaybe) +import Data.List +import TcType (substTy, tcSplitSigmaTy) +import Unify (tcUnifyTy) mkContext :: [(OccName, CType)] -> TcGblEnv -> Context -mkContext locals - = Context locals - . fmap splitId - . (getFunBindId =<<) - . fmap unLoc - . bagToList - . tcg_binds +mkContext locals tcg = Context + { ctxDefiningFuncs = locals + , ctxModuleFuncs = fmap splitId + . (getFunBindId =<<) + . fmap unLoc + . bagToList + $ tcg_binds tcg + } + + +------------------------------------------------------------------------------ +-- | Find all of the class methods that exist from the givens in the context. +contextMethodHypothesis :: Context -> [(OccName, CType)] +contextMethodHypothesis ctx + = join + . concatMap + ( mapMaybe methodHypothesis + . tacticsThetaTy + . unCType + ) + . mapMaybe (definedThetaType ctx) + . fmap fst + $ ctxDefiningFuncs ctx + + +------------------------------------------------------------------------------ +-- | Given the name of a function that exists in 'ctxDefiningFuncs', get its +-- theta type. +definedThetaType :: Context -> OccName -> Maybe CType +definedThetaType ctx name = do + (_, CType mono) <- find ((== name) . fst) $ ctxDefiningFuncs ctx + (_, CType poly) <- find ((== name) . fst) $ ctxModuleFuncs ctx + let (_, _, poly') = tcSplitSigmaTy poly + subst <- tcUnifyTy poly' mono + pure $ CType $ substTy subst $ snd $ splitForAllTys poly splitId :: Id -> (OccName, CType) diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/GHC.hs b/plugins/tactics/src/Ide/Plugin/Tactic/GHC.hs index 7f89e4c0c9..3b66956257 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/GHC.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/GHC.hs @@ -1,18 +1,24 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ViewPatterns #-} module Ide.Plugin.Tactic.GHC where -import Data.Maybe (isJust) -import Development.IDE.GHC.Compat -import OccName -import TcType -import TyCoRep -import Type -import TysWiredIn (intTyCon, floatTyCon, doubleTyCon, charTyCon) -import Unique -import Var +import Control.Monad.State +import qualified Data.Map as M +import Data.Maybe (isJust) +import Data.Traversable +import Development.IDE.GHC.Compat +import Generics.SYB (mkT, everywhere) +import Ide.Plugin.Tactic.Types +import OccName +import TcType +import TyCoRep +import Type +import TysWiredIn (intTyCon, floatTyCon, doubleTyCon, charTyCon) +import Unique +import Var tcTyVar_maybe :: Type -> Maybe Var @@ -43,8 +49,44 @@ cloneTyVar t = ------------------------------------------------------------------------------ -- | Is this a function type? isFunction :: Type -> Bool -isFunction (tcSplitFunTys -> ((_:_), _)) = True -isFunction _ = False +isFunction (tacticsSplitFunTy -> (_, _, [], _)) = False +isFunction _ = True + + +------------------------------------------------------------------------------ +-- | Split a function, also splitting out its quantified variables and theta +-- context. +tacticsSplitFunTy :: Type -> ([TyVar], ThetaType, [Type], Type) +tacticsSplitFunTy t + = let (vars, theta, t') = tcSplitSigmaTy t + (args, res) = tcSplitFunTys t' + in (vars, theta, args, res) + + +------------------------------------------------------------------------------ +-- | Rip the theta context out of a regular type. +tacticsThetaTy :: Type -> ThetaType +tacticsThetaTy (tcSplitSigmaTy -> (_, theta, _)) = theta + + +------------------------------------------------------------------------------ +-- | Instantiate all of the quantified type variables in a type with fresh +-- skolems. +freshTyvars :: MonadState TacticState m => Type -> m Type +freshTyvars t = do + let (tvs, _, _, _) = tacticsSplitFunTy t + reps <- fmap M.fromList + $ for tvs $ \tv -> do + uniq <- freshUnique + pure $ (tv, setTyVarUnique tv uniq) + pure $ + everywhere + (mkT $ \tv -> + case M.lookup tv reps of + Just tv' -> tv' + Nothing -> tv + ) t + ------------------------------------------------------------------------------ -- | Is this an algebraic type? diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs index 32ad70bc2e..743448dc64 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs @@ -162,8 +162,17 @@ disallowing ns = field @"_jHypothesis" %~ flip M.withoutKeys (S.fromList ns) +------------------------------------------------------------------------------ +-- | The hypothesis, consisting of local terms and the ambient environment +-- (includes and class methods.) jHypothesis :: Judgement' a -> Map OccName a -jHypothesis = _jHypothesis +jHypothesis = _jHypothesis <> _jAmbientHypothesis + + +------------------------------------------------------------------------------ +-- | Just the local hypothesis. +jLocalHypothesis :: Judgement' a -> Map OccName a +jLocalHypothesis = _jHypothesis isPatVal :: Judgement' a -> OccName -> Bool @@ -191,13 +200,15 @@ substJdg :: TCvSubst -> Judgement -> Judgement substJdg subst = fmap $ coerce . substTy subst . coerce mkFirstJudgement - :: M.Map OccName CType + :: M.Map OccName CType -- ^ local hypothesis + -> M.Map OccName CType -- ^ ambient hypothesis -> Bool -- ^ are we in the top level rhs hole? -> M.Map OccName [[OccName]] -- ^ existing pos vals -> Type -> Judgement' CType -mkFirstJudgement hy top posvals goal = Judgement +mkFirstJudgement hy ambient top posvals goal = Judgement { _jHypothesis = hy + , _jAmbientHypothesis = ambient , _jDestructed = mempty , _jPatternVals = mempty , _jBlacklistDestruct = False diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs index f34aff5abd..94850fa4e0 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} @@ -17,6 +18,7 @@ module Ide.Plugin.Tactic.Machinery ( module Ide.Plugin.Tactic.Machinery ) where +import Class (Class(classTyVars)) import Control.Arrow import Control.Monad.Error.Class import Control.Monad.Reader @@ -25,12 +27,15 @@ import Control.Monad.State.Class (gets, modify) import Control.Monad.State.Strict (StateT (..)) import Data.Coerce import Data.Either -import Data.List (intercalate, sortBy) +import Data.Functor ((<&>)) +import Data.Generics (mkQ, everything, gcount) +import Data.List (sortBy) import Data.Ord (comparing, Down(..)) import qualified Data.Set as S import Development.IDE.GHC.Compat import Ide.Plugin.Tactic.Judgements import Ide.Plugin.Tactic.Types +import OccName (HasOccName(occName)) import Refinery.ProofState import Refinery.Tactic import Refinery.Tactic.Internal @@ -74,7 +79,8 @@ runTactic ctx jdg t = (errs, []) -> Left $ take 50 $ errs (_, fmap assoc23 -> solns) -> do let sorted = - sortBy (comparing $ Down . uncurry scoreSolution . snd) solns + flip sortBy solns $ comparing $ \((_, ext), (jdg, holes)) -> + Down $ scoreSolution ext jdg holes case sorted of (((tr, ext), _) : _) -> Right @@ -121,21 +127,32 @@ setRecursionFrameData b = do scoreSolution - :: TacticState + :: LHsExpr GhcPs + -> TacticState -> [Judgement] -> ( Penalize Int -- number of holes , Reward Bool -- all bindings used , Penalize Int -- number of introduced bindings , Reward Int -- number used bindings + , Penalize Int -- size of extract ) -scoreSolution TacticState{..} holes +scoreSolution ext TacticState{..} holes = ( Penalize $ length holes - , Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals + , Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals , Penalize $ S.size ts_intro_vals - , Reward $ S.size ts_used_vals + , Reward $ S.size ts_used_vals + , Penalize $ solutionSize ext ) +------------------------------------------------------------------------------ +-- | Compute the number of 'LHsExpr' nodes; used as a rough metric for code +-- size. +solutionSize :: LHsExpr GhcPs -> Int +solutionSize = everything (+) $ gcount $ mkQ False $ \case + (_ :: LHsExpr GhcPs) -> True + + newtype Penalize a = Penalize a deriving (Eq, Ord, Show) via (Down a) @@ -143,23 +160,22 @@ newtype Reward a = Reward a deriving (Eq, Ord, Show) via a +------------------------------------------------------------------------------ +-- | Like 'tcUnifyTy', but takes a list of skolems to prevent unification of. +tryUnifyUnivarsButNotSkolems :: [TyVar] -> CType -> CType -> Maybe TCvSubst +tryUnifyUnivarsButNotSkolems skolems goal inst = + case tcUnifyTysFG (skolemsOf skolems) [unCType inst] [unCType goal] of + Unifiable subst -> pure subst + _ -> Nothing + ------------------------------------------------------------------------------ --- | We need to make sure that we don't try to unify any skolems. --- To see why, consider the case: --- --- uhh :: (Int -> Int) -> a --- uhh f = _ --- --- If we were to apply 'f', then we would try to unify 'Int' and 'a'. --- This is fine from the perspective of 'tcUnifyTy', but will cause obvious --- type errors in our use case. Therefore, we need to ensure that our --- 'TCvSubst' doesn't try to unify skolems. -checkSkolemUnification :: CType -> CType -> TCvSubst -> RuleM () -checkSkolemUnification t1 t2 subst = do - skolems <- gets ts_skolems - unless (all (flip notElemTCvSubst subst) skolems) $ - throwError (UnificationError t1 t2) +-- | Helper method for 'tryUnifyUnivarsButNotSkolems' +skolemsOf :: [TyVar] -> TyVar -> BindFlag +skolemsOf tvs tv = + case elem tv tvs of + True -> Skolem + False -> BindMe ------------------------------------------------------------------------------ @@ -167,10 +183,41 @@ checkSkolemUnification t1 t2 subst = do unify :: CType -- ^ The goal type -> CType -- ^ The type we are trying unify the goal type with -> RuleM () -unify goal inst = - case tcUnifyTy (unCType inst) (unCType goal) of - Just subst -> do - checkSkolemUnification inst goal subst - modify (\s -> s { ts_unifier = unionTCvSubst subst (ts_unifier s) }) - Nothing -> throwError (UnificationError inst goal) +unify goal inst = do + skolems <- gets ts_skolems + case tryUnifyUnivarsButNotSkolems skolems goal inst of + Just subst -> + modify (\s -> s { ts_unifier = unionTCvSubst subst (ts_unifier s) }) + Nothing -> throwError (UnificationError inst goal) + + +------------------------------------------------------------------------------ +-- | Get the class methods of a 'PredType', correctly dealing with +-- instantiation of quantified class types. +methodHypothesis :: PredType -> Maybe [(OccName, CType)] +methodHypothesis ty = do + (tc, apps) <- splitTyConApp_maybe ty + cls <- tyConClass_maybe tc + let methods = classMethods cls + tvs = classTyVars cls + subst = zipTvSubst tvs apps + sc_methods <- fmap join + $ traverse (methodHypothesis . substTy subst) + $ classSCTheta cls + pure $ mappend sc_methods $ methods <&> \method -> + let (_, _, ty) = tcSplitSigmaTy $ idType method + in (occName method, CType $ substTy subst ty) + + +------------------------------------------------------------------------------ +-- | Run the given tactic iff the current hole contains no univars. Skolems and +-- already decided univars are OK though. +requireConcreteHole :: TacticsM a -> TacticsM a +requireConcreteHole m = do + jdg <- goal + skolems <- gets $ S.fromList . ts_skolems + let vars = S.fromList $ tyCoVarsOfTypeWellScoped $ unCType $ jGoal jdg + case S.size $ vars S.\\ skolems of + 0 -> m + _ -> throwError TooPolymorphic diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs index f00a1087cb..4a6389ec9f 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs @@ -14,6 +14,7 @@ module Ide.Plugin.Tactic.Tactics , runTactic ) where +import Control.Monad (when) import Control.Monad.Except (throwError) import Control.Monad.Reader.Class (MonadReader(ask)) import Control.Monad.State.Class @@ -35,7 +36,7 @@ import Ide.Plugin.Tactic.Judgements import Ide.Plugin.Tactic.Machinery import Ide.Plugin.Tactic.Naming import Ide.Plugin.Tactic.Types -import Name (nameOccName, occNameString) +import Name (occNameString) import Refinery.Tactic import Refinery.Tactic.Internal import TcType @@ -54,26 +55,22 @@ assume :: OccName -> TacticsM () assume name = rule $ \jdg -> do let g = jGoal jdg case M.lookup name $ jHypothesis jdg of - Just ty -> - case ty == jGoal jdg of - True -> do - case M.member name (jPatHypothesis jdg) of - True -> setRecursionFrameData True - False -> pure () - useOccName jdg name - pure $ (tracePrim $ "assume " <> occNameString name, ) $ noLoc $ var' name - False -> throwError $ GoalMismatch "assume" g + Just ty -> do + unify ty $ jGoal jdg + when (M.member name $ jPatHypothesis jdg) $ + setRecursionFrameData True + useOccName jdg name + pure $ (tracePrim $ "assume " <> occNameString name, ) $ noLoc $ var' name Nothing -> throwError $ UndefinedHypothesis name - recursion :: TacticsM () -recursion = tracing "recursion" $ do +recursion = requireConcreteHole $ tracing "recursion" $ do defs <- getCurrentDefinitions attemptOn (const $ fmap fst defs) $ \name -> do modify $ withRecursionStack (False :) ensure recursiveCleanup (withRecursionStack tail) $ do - (localTactic (apply' (const id) name) $ introducing defs) + (localTactic (apply name) $ introducing defs) <@> fmap (localTactic assumption . filterPosition name) [0..] @@ -109,7 +106,7 @@ intros = rule $ \jdg -> do ------------------------------------------------------------------------------ -- | Case split, and leave holes in the matches. destructAuto :: OccName -> TacticsM () -destructAuto name = tracing "destruct(auto)" $ do +destructAuto name = requireConcreteHole $ tracing "destruct(auto)" $ do jdg <- goal case hasDestructed jdg name of True -> throwError $ AlreadyDestructed name @@ -129,7 +126,7 @@ destructAuto name = tracing "destruct(auto)" $ do ------------------------------------------------------------------------------ -- | Case split, and leave holes in the matches. destruct :: OccName -> TacticsM () -destruct name = tracing "destruct(user)" $ do +destruct name = requireConcreteHole $ tracing "destruct(user)" $ do jdg <- goal case hasDestructed jdg name of True -> throwError $ AlreadyDestructed name @@ -139,7 +136,7 @@ destruct name = tracing "destruct(user)" $ do ------------------------------------------------------------------------------ -- | Case split, using the same data constructor in the matches. homo :: OccName -> TacticsM () -homo = tracing "homo" . rule . destruct' (\dc jdg -> +homo = requireConcreteHole . tracing "homo" . rule . destruct' (\dc jdg -> buildDataCon jdg dc $ snd $ splitAppTys $ unCType $ jGoal jdg) @@ -152,40 +149,42 @@ destructLambdaCase = tracing "destructLambdaCase" $ rule $ destructLambdaCase' ( ------------------------------------------------------------------------------ -- | LambdaCase split, using the same data constructor in the matches. homoLambdaCase :: TacticsM () -homoLambdaCase = tracing "homoLambdaCase" $ rule $ destructLambdaCase' (\dc jdg -> - buildDataCon jdg dc $ snd $ splitAppTys $ unCType $ jGoal jdg) +homoLambdaCase = + tracing "homoLambdaCase" $ + rule $ destructLambdaCase' $ \dc jdg -> + buildDataCon jdg dc + . snd + . splitAppTys + . unCType + $ jGoal jdg apply :: OccName -> TacticsM () -apply = apply' (const id) - - -apply' :: (Int -> Judgement -> Judgement) -> OccName -> TacticsM () -apply' f func = tracing ("apply' " <> show func) $ do - rule $ \jdg -> do - let hy = jHypothesis jdg - g = jGoal jdg - case M.lookup func hy of - Just (CType ty) -> do - let (args, ret) = splitFunTys ty - unify g (CType ret) - useOccName jdg func - (tr, sgs) - <- fmap unzipTrace - $ traverse ( \(i, t) -> - newSubgoal - . f i - . blacklistingDestruct - . flip withNewGoal jdg - $ CType t - ) $ zip [0..] args - pure - . (tr, ) - . noLoc - . foldl' (@@) (var' func) - $ fmap unLoc sgs - Nothing -> do - throwError $ GoalMismatch "apply" g +apply func = requireConcreteHole $ tracing ("apply' " <> show func) $ do + jdg <- goal + let hy = jHypothesis jdg + g = jGoal jdg + case M.lookup func hy of + Just (CType ty) -> do + ty' <- freshTyvars ty + let (_, _, args, ret) = tacticsSplitFunTy ty' + requireNewHoles $ rule $ \jdg -> do + unify g (CType ret) + useOccName jdg func + (tr, sgs) + <- fmap unzipTrace + $ traverse ( newSubgoal + . blacklistingDestruct + . flip withNewGoal jdg + . CType + ) args + pure + . (tr, ) + . noLoc + . foldl' (@@) (var' func) + $ fmap unLoc sgs + Nothing -> do + throwError $ GoalMismatch "apply" g ------------------------------------------------------------------------------ @@ -206,7 +205,7 @@ split = tracing "split(user)" $ do -- 'split' because it won't split a data con if it doesn't result in any new -- goals. splitAuto :: TacticsM () -splitAuto = tracing "split(auto)" $ do +splitAuto = requireConcreteHole $ tracing "split(auto)" $ do jdg <- goal let g = jGoal jdg case splitTyConApp_maybe $ unCType g of @@ -216,23 +215,34 @@ splitAuto = tracing "split(auto)" $ do case isSplitWhitelisted jdg of True -> choice $ fmap splitDataCon dcs False -> do - choice $ flip fmap dcs $ \dc -> pruning (splitDataCon dc) $ \jdgs -> - case null jdgs || any (/= jGoal jdg) (fmap jGoal jdgs) of - True -> Nothing - False -> Just $ UnhelpfulSplit $ nameOccName $ dataConName dc + choice $ flip fmap dcs $ \dc -> requireNewHoles $ + splitDataCon dc + + +------------------------------------------------------------------------------ +-- | Allow the given tactic to proceed if and only if it introduces holes that +-- have a different goal than current goal. +requireNewHoles :: TacticsM () -> TacticsM () +requireNewHoles m = do + jdg <- goal + pruning m $ \jdgs -> + case null jdgs || any (/= jGoal jdg) (fmap jGoal jdgs) of + True -> Nothing + False -> Just NoProgress ------------------------------------------------------------------------------ -- | Attempt to instantiate the given data constructor to solve the goal. splitDataCon :: DataCon -> TacticsM () -splitDataCon dc = tracing ("splitDataCon:" <> show dc) $ rule $ \jdg -> do - let g = jGoal jdg - case splitTyConApp_maybe $ unCType g of - Just (tc, apps) -> do - case elem dc $ tyConDataCons tc of - True -> buildDataCon (unwhitelistingSplit jdg) dc apps - False -> throwError $ IncorrectDataCon dc - Nothing -> throwError $ GoalMismatch "splitDataCon" g +splitDataCon dc = + requireConcreteHole $ tracing ("splitDataCon:" <> show dc) $ rule $ \jdg -> do + let g = jGoal jdg + case splitTyConApp_maybe $ unCType g of + Just (tc, apps) -> do + case elem dc $ tyConDataCons tc of + True -> buildDataCon (unwhitelistingSplit jdg) dc apps + False -> throwError $ IncorrectDataCon dc + Nothing -> throwError $ GoalMismatch "splitDataCon" g ------------------------------------------------------------------------------ diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs index 4d1b802697..2d7299a380 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE DeriveFunctor #-} @@ -21,20 +22,24 @@ module Ide.Plugin.Tactic.Types ) where import Control.Lens hiding (Context) -import Data.Generics.Product (field) import Control.Monad.Reader +import Control.Monad.State +import Data.Coerce import Data.Function +import Data.Generics.Product (field) import Data.Map (Map) import Data.Set (Set) +import Data.Tree import Development.IDE.GHC.Compat hiding (Node) import Development.IDE.Types.Location import GHC.Generics import Ide.Plugin.Tactic.Debug import OccName import Refinery.Tactic +import System.IO.Unsafe (unsafePerformIO) import Type -import Data.Tree -import Data.Coerce +import UniqSupply (takeUniqFromSupply, mkSplitUniqSupply, UniqSupply) +import Unique (Unique) ------------------------------------------------------------------------------ @@ -73,12 +78,40 @@ data TacticState = TacticState , ts_used_vals :: !(Set OccName) , ts_intro_vals :: !(Set OccName) , ts_recursion_stack :: ![Bool] + , ts_unique_gen :: !UniqSupply } deriving stock (Show, Generic) +instance Show UniqSupply where + show _ = "" + + +------------------------------------------------------------------------------ +-- | A 'UniqSupply' to use in 'defaultTacticState' +unsafeDefaultUniqueSupply :: UniqSupply +unsafeDefaultUniqueSupply = + unsafePerformIO $ mkSplitUniqSupply '🚒' +{-# NOINLINE unsafeDefaultUniqueSupply #-} + defaultTacticState :: TacticState defaultTacticState = - TacticState mempty emptyTCvSubst mempty mempty mempty + TacticState + { ts_skolems = mempty + , ts_unifier = emptyTCvSubst + , ts_used_vals = mempty + , ts_intro_vals = mempty + , ts_recursion_stack = mempty + , ts_unique_gen = unsafeDefaultUniqueSupply + } + + +------------------------------------------------------------------------------ +-- | Generate a new 'Unique' +freshUnique :: MonadState TacticState m => m Unique +freshUnique = do + (uniq, supply) <- gets $ takeUniqFromSupply . ts_unique_gen + modify' $! field @"ts_unique_gen" .~ supply + pure uniq withRecursionStack @@ -102,6 +135,9 @@ withIntroducedVals f = -- | The current bindings and goal for a hole to be filled by refinery. data Judgement' a = Judgement { _jHypothesis :: !(Map OccName a) + , _jAmbientHypothesis :: !(Map OccName a) + -- ^ Things in the hypothesis that were imported. Solutions don't get + -- points for using the ambient hypothesis. , _jDestructed :: !(Set OccName) -- ^ These should align with keys of _jHypothesis , _jPatternVals :: !(Set OccName) @@ -141,6 +177,7 @@ data TacticError | RecursionOnWrongParam OccName Int OccName | UnhelpfulDestruct OccName | UnhelpfulSplit OccName + | TooPolymorphic deriving stock (Eq) instance Show TacticError where @@ -177,6 +214,8 @@ instance Show TacticError where "Destructing patval " <> show n <> " leads to no new types" show (UnhelpfulSplit n) = "Splitting constructor " <> show n <> " leads to no new goals" + show TooPolymorphic = + "The tactic isn't applicable because the goal is too polymorphic" ------------------------------------------------------------------------------ @@ -198,6 +237,12 @@ data Context = Context deriving stock (Eq, Ord) +------------------------------------------------------------------------------ +-- | An empty context +emptyContext :: Context +emptyContext = Context mempty mempty + + newtype Rose a = Rose (Tree a) deriving stock (Eq, Functor, Generic) diff --git a/plugins/tactics/test/AutoTupleSpec.hs b/plugins/tactics/test/AutoTupleSpec.hs index efe37bf09a..9b73c7c2f9 100644 --- a/plugins/tactics/test/AutoTupleSpec.hs +++ b/plugins/tactics/test/AutoTupleSpec.hs @@ -43,6 +43,7 @@ spec = describe "auto for tuple" $ do (Context [] []) (mkFirstJudgement (M.singleton (mkVarOcc "x") $ CType in_type) + mempty True mempty out_type) diff --git a/plugins/tactics/test/UnificationSpec.hs b/plugins/tactics/test/UnificationSpec.hs new file mode 100644 index 0000000000..9351725036 --- /dev/null +++ b/plugins/tactics/test/UnificationSpec.hs @@ -0,0 +1,64 @@ +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module UnificationSpec where + +import Control.Arrow +import Data.Bool (bool) +import Data.Functor ((<&>)) +import Data.Maybe (mapMaybe) +import Data.Traversable +import Data.Tuple (swap) +import Ide.Plugin.Tactic.Debug +import Ide.Plugin.Tactic.Machinery +import Ide.Plugin.Tactic.Types +import TcType (tcGetTyVar_maybe, substTy) +import Test.Hspec +import Test.QuickCheck +import Type (mkTyVarTy) +import TysPrim (alphaTyVars) +import TysWiredIn (mkBoxedTupleTy) + + +instance Show Type where + show = unsafeRender + + +spec :: Spec +spec = describe "unification" $ do + it "should be able to unify univars with skolems on either side of the equality" $ do + property $ do + -- Pick some number of unification vars and skolem + n <- choose (1, 1) + let (skolems, take n -> univars) = splitAt n $ fmap mkTyVarTy alphaTyVars + -- Randomly pair them + skolem_uni_pairs <- + for (zip skolems univars) randomSwap + let (lhs, rhs) + = mkBoxedTupleTy *** mkBoxedTupleTy + $ unzip skolem_uni_pairs + pure $ + counterexample (show skolems) $ + counterexample (show lhs) $ + counterexample (show rhs) $ + case tryUnifyUnivarsButNotSkolems + (mapMaybe tcGetTyVar_maybe skolems) + (CType lhs) + (CType rhs) of + Just subst -> + -- For each pair, running the unification over the univar should + -- result in the skolem + conjoin $ zip univars skolems <&> \(uni, skolem) -> + let substd = substTy subst uni + in counterexample (show substd) $ + counterexample (show skolem) $ + CType substd === CType skolem + Nothing -> True === False + + +randomSwap :: (a, a) -> Gen (a, a) +randomSwap ab = do + which <- arbitrary + pure $ bool swap id which ab + + diff --git a/test/functional/Tactic.hs b/test/functional/Tactic.hs index cf6e564493..b722646336 100644 --- a/test/functional/Tactic.hs +++ b/test/functional/Tactic.hs @@ -102,6 +102,11 @@ tests = testGroup , goldenTest "GoldenGADTAuto.hs" 7 13 Auto "" , goldenTest "GoldenSwapMany.hs" 2 12 Auto "" , goldenTest "GoldenBigTuple.hs" 4 12 Auto "" + , goldenTest "GoldenShow.hs" 2 10 Auto "" + , goldenTest "GoldenShowCompose.hs" 2 15 Auto "" + , goldenTest "GoldenShowMapChar.hs" 2 8 Auto "" + , goldenTest "GoldenSuperclass.hs" 7 8 Auto "" + , goldenTest "GoldenApplicativeThen.hs" 2 11 Auto "" ] diff --git a/test/testdata/tactic/GoldenApplicativeThen.hs b/test/testdata/tactic/GoldenApplicativeThen.hs new file mode 100644 index 0000000000..29ce9f5132 --- /dev/null +++ b/test/testdata/tactic/GoldenApplicativeThen.hs @@ -0,0 +1,2 @@ +useThen :: Applicative f => f Int -> f a -> f a +useThen = _ diff --git a/test/testdata/tactic/GoldenApplicativeThen.hs.expected b/test/testdata/tactic/GoldenApplicativeThen.hs.expected new file mode 100644 index 0000000000..fc7816581b --- /dev/null +++ b/test/testdata/tactic/GoldenApplicativeThen.hs.expected @@ -0,0 +1,2 @@ +useThen :: Applicative f => f Int -> f a -> f a +useThen = (\ x x8 -> (*>) x x8) diff --git a/test/testdata/tactic/GoldenShow.hs b/test/testdata/tactic/GoldenShow.hs new file mode 100644 index 0000000000..9ec5e27bcf --- /dev/null +++ b/test/testdata/tactic/GoldenShow.hs @@ -0,0 +1,2 @@ +showMe :: Show a => a -> String +showMe = _ diff --git a/test/testdata/tactic/GoldenShow.hs.expected b/test/testdata/tactic/GoldenShow.hs.expected new file mode 100644 index 0000000000..05ba83e9fe --- /dev/null +++ b/test/testdata/tactic/GoldenShow.hs.expected @@ -0,0 +1,2 @@ +showMe :: Show a => a -> String +showMe = show diff --git a/test/testdata/tactic/GoldenShowCompose.hs b/test/testdata/tactic/GoldenShowCompose.hs new file mode 100644 index 0000000000..c99768e4e5 --- /dev/null +++ b/test/testdata/tactic/GoldenShowCompose.hs @@ -0,0 +1,2 @@ +showCompose :: Show a => (b -> a) -> b -> String +showCompose = _ diff --git a/test/testdata/tactic/GoldenShowCompose.hs.expected b/test/testdata/tactic/GoldenShowCompose.hs.expected new file mode 100644 index 0000000000..373ea6af91 --- /dev/null +++ b/test/testdata/tactic/GoldenShowCompose.hs.expected @@ -0,0 +1,2 @@ +showCompose :: Show a => (b -> a) -> b -> String +showCompose = (\ fba b -> show (fba b)) diff --git a/test/testdata/tactic/GoldenShowMapChar.hs b/test/testdata/tactic/GoldenShowMapChar.hs new file mode 100644 index 0000000000..8e6e5eae6b --- /dev/null +++ b/test/testdata/tactic/GoldenShowMapChar.hs @@ -0,0 +1,2 @@ +test :: Show a => a -> (String -> b) -> b +test = _ diff --git a/test/testdata/tactic/GoldenShowMapChar.hs.expected b/test/testdata/tactic/GoldenShowMapChar.hs.expected new file mode 100644 index 0000000000..8750e4e1f4 --- /dev/null +++ b/test/testdata/tactic/GoldenShowMapChar.hs.expected @@ -0,0 +1,2 @@ +test :: Show a => a -> (String -> b) -> b +test = (\ a fl_cb -> fl_cb (show a)) diff --git a/test/testdata/tactic/GoldenSuperclass.hs b/test/testdata/tactic/GoldenSuperclass.hs new file mode 100644 index 0000000000..86a9fed7bc --- /dev/null +++ b/test/testdata/tactic/GoldenSuperclass.hs @@ -0,0 +1,8 @@ +class Super a where + super :: a + +class Super a => Sub a + +blah :: Sub a => a +blah = _ + diff --git a/test/testdata/tactic/GoldenSuperclass.hs.expected b/test/testdata/tactic/GoldenSuperclass.hs.expected new file mode 100644 index 0000000000..e0a5dbb565 --- /dev/null +++ b/test/testdata/tactic/GoldenSuperclass.hs.expected @@ -0,0 +1,8 @@ +class Super a where + super :: a + +class Super a => Sub a + +blah :: Sub a => a +blah = super +