diff --git a/plugins/tactics/hls-tactics-plugin.cabal b/plugins/tactics/hls-tactics-plugin.cabal index bc5a4b03f1..aa1256c02c 100644 --- a/plugins/tactics/hls-tactics-plugin.cabal +++ b/plugins/tactics/hls-tactics-plugin.cabal @@ -30,6 +30,7 @@ library Ide.Plugin.Tactic.GHC Ide.Plugin.Tactic.Judgements Ide.Plugin.Tactic.KnownStrategies + Ide.Plugin.Tactic.KnownStrategies.QuickCheck Ide.Plugin.Tactic.Machinery Ide.Plugin.Tactic.Naming Ide.Plugin.Tactic.Range diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs index 7ee099c20d..3a3785971c 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs @@ -236,3 +236,27 @@ var' = var . fromString . occNameString bvar' :: BVar a => OccName -> a bvar' = bvar . fromString . occNameString + +------------------------------------------------------------------------------ +-- | Get an HsExpr corresponding to a function name. +mkFunc :: String -> HsExpr GhcPs +mkFunc = var' . mkVarOcc + + +------------------------------------------------------------------------------ +-- | Get an HsExpr corresponding to a value name. +mkVal :: String -> HsExpr GhcPs +mkVal = var' . mkVarOcc + + +------------------------------------------------------------------------------ +-- | Like 'op', but easier to call. +infixCall :: String -> HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs +infixCall s = flip op (fromString s) + + +------------------------------------------------------------------------------ +-- | Like '(@@)', but uses a dollar instead of parentheses. +appDollar :: HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs +appDollar = infixCall "$" + diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies.hs b/plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies.hs index 610740aba3..ca42e15ac5 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies.hs @@ -9,11 +9,13 @@ import Ide.Plugin.Tactic.Types import OccName (mkVarOcc) import Refinery.Tactic import Ide.Plugin.Tactic.Machinery (tracing) +import Ide.Plugin.Tactic.KnownStrategies.QuickCheck (deriveArbitrary) knownStrategies :: TacticsM () knownStrategies = choice - [ deriveFmap + [ known "fmap" deriveFmap + , known "arbitrary" deriveArbitrary ] @@ -26,7 +28,7 @@ known name t = do deriveFmap :: TacticsM () -deriveFmap = known "fmap" $ do +deriveFmap = do try intros overAlgebraicTerms homo choice diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs b/plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs new file mode 100644 index 0000000000..c29c1d58d8 --- /dev/null +++ b/plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs @@ -0,0 +1,112 @@ +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE LambdaCase #-} + +module Ide.Plugin.Tactic.KnownStrategies.QuickCheck where + +import Control.Monad.Except (MonadError(throwError)) +import Data.Bool (bool) +import Data.List (partition) +import DataCon ( DataCon, dataConName ) +import Development.IDE.GHC.Compat (HsExpr, GhcPs, noLoc) +import GHC.Exts ( IsString(fromString) ) +import GHC.List ( foldl' ) +import GHC.SourceGen (int) +import GHC.SourceGen.Binds ( match, valBind ) +import GHC.SourceGen.Expr ( case', lambda, let' ) +import GHC.SourceGen.Overloaded ( App((@@)), HasList(list) ) +import GHC.SourceGen.Pat ( conP ) +import Ide.Plugin.Tactic.CodeGen +import Ide.Plugin.Tactic.Judgements (jGoal) +import Ide.Plugin.Tactic.Machinery (tracePrim) +import Ide.Plugin.Tactic.Types +import OccName (occNameString, mkVarOcc, HasOccName(occName) ) +import Refinery.Tactic (goal, rule ) +import TyCon (tyConName, TyCon, tyConDataCons ) +import Type ( splitTyConApp_maybe ) +import Data.Generics (mkQ, everything) + + +------------------------------------------------------------------------------ +-- | Known tactic for deriving @arbitrary :: Gen a@. This tactic splits the +-- type's data cons into terminal and inductive cases, and generates code that +-- produces a terminal if the QuickCheck size parameter is <=1, or any data con +-- otherwise. It correctly scales recursive parameters, ensuring termination. +deriveArbitrary :: TacticsM () +deriveArbitrary = do + ty <- jGoal <$> goal + case splitTyConApp_maybe $ unCType ty of + Just (gen_tc, [splitTyConApp_maybe -> Just (tc, apps)]) + | occNameString (occName $ tyConName gen_tc) == "Gen" -> do + rule $ \_ -> do + let dcs = tyConDataCons tc + (terminal, big) = partition ((== 0) . genRecursiveCount) + $ fmap (mkGenerator tc apps) dcs + terminal_expr = mkVal "terminal" + oneof_expr = mkVal "oneof" + pure + ( tracePrim "deriveArbitrary" + , noLoc $ + let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $ + appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $ + case' (infixCall "<=" (mkVal "n") (int 1)) + [ match [conP (fromString "True") []] $ + oneof_expr @@ terminal_expr + , match [conP (fromString "False") []] $ + appDollar oneof_expr $ + infixCall "<>" + (list $ fmap genExpr big) + terminal_expr + ] + ) + _ -> throwError $ GoalMismatch "deriveArbitrary" ty + + +------------------------------------------------------------------------------ +-- | Helper data type for the generator of a specific data con. +data Generator = Generator + { genRecursiveCount :: Integer + , genExpr :: HsExpr GhcPs + } + + +------------------------------------------------------------------------------ +-- | Make a 'Generator' for a given tycon instantiated with the given @[Type]@. +mkGenerator :: TyCon -> [Type] -> DataCon -> Generator +mkGenerator tc apps dc = do + let dc_expr = var' $ occName $ dataConName dc + args = dataConInstOrigArgTys' dc apps + num_recursive_calls = sum $ fmap (bool 0 1 . doesTypeContain tc) args + mkArbitrary = mkArbitraryCall tc num_recursive_calls + Generator num_recursive_calls $ case args of + [] -> mkFunc "pure" @@ dc_expr + (a : as) -> + foldl' + (infixCall "<*>") + (infixCall "<$>" dc_expr $ mkArbitrary a) + (fmap mkArbitrary as) + + +------------------------------------------------------------------------------ +-- | Check if the given 'TyCon' exists anywhere in the 'Type'. +doesTypeContain :: TyCon -> Type -> Bool +doesTypeContain recursive_tc = + everything (||) $ mkQ False (== recursive_tc) + + +------------------------------------------------------------------------------ +-- | Generate the correct sort of call to @arbitrary@. For recursive calls, we +-- need to scale down the size parameter, either by a constant factor of 1 if +-- it's the only recursive parameter, or by @`div` n@ where n is the number of +-- recursive parameters. For all other types, just call @arbitrary@ directly. +mkArbitraryCall :: TyCon -> Integer -> Type -> HsExpr GhcPs +mkArbitraryCall recursive_tc n ty = + let arbitrary = mkFunc "arbitrary" + in case doesTypeContain recursive_tc ty of + True -> + mkFunc "scale" + @@ bool (mkFunc "flip" @@ mkFunc "div" @@ int n) + (mkFunc "subtract" @@ int 1) + (n == 1) + @@ arbitrary + False -> arbitrary + diff --git a/test/functional/Tactic.hs b/test/functional/Tactic.hs index 7eeac215d8..0850fc741e 100644 --- a/test/functional/Tactic.hs +++ b/test/functional/Tactic.hs @@ -116,6 +116,7 @@ tests = testGroup $ goldenTest "GoldenApplicativeThen.hs" 2 11 Auto "" , goldenTest "GoldenSafeHead.hs" 2 12 Auto "" , expectFail "GoldenFish.hs" 5 18 Auto "" + , goldenTest "GoldenArbitrary.hs" 25 13 Auto "" ] diff --git a/test/testdata/tactic/GoldenArbitrary.hs b/test/testdata/tactic/GoldenArbitrary.hs new file mode 100644 index 0000000000..f45d2d1fea --- /dev/null +++ b/test/testdata/tactic/GoldenArbitrary.hs @@ -0,0 +1,26 @@ +-- Emulate a quickcheck import; deriveArbitrary works on any type with the +-- right name and kind +data Gen a + +data Obj + = Square Int Int + | Circle Int + | Polygon [(Int, Int)] + | Rotate2 Double Obj + | Empty + | Full + | Complement Obj + | UnionR Double [Obj] + | DifferenceR Double Obj [Obj] + | IntersectR Double [Obj] + | Translate Double Double Obj + | Scale Double Double Obj + | Mirror Double Double Obj + | Outset Double Obj + | Shell Double Obj + | WithRounding Double Obj + + +arbitrary :: Gen Obj +arbitrary = _ + diff --git a/test/testdata/tactic/GoldenArbitrary.hs.expected b/test/testdata/tactic/GoldenArbitrary.hs.expected new file mode 100644 index 0000000000..a3f677d1a1 --- /dev/null +++ b/test/testdata/tactic/GoldenArbitrary.hs.expected @@ -0,0 +1,52 @@ +-- Emulate a quickcheck import; deriveArbitrary works on any type with the +-- right name and kind +data Gen a + +data Obj + = Square Int Int + | Circle Int + | Polygon [(Int, Int)] + | Rotate2 Double Obj + | Empty + | Full + | Complement Obj + | UnionR Double [Obj] + | DifferenceR Double Obj [Obj] + | IntersectR Double [Obj] + | Translate Double Double Obj + | Scale Double Double Obj + | Mirror Double Double Obj + | Outset Double Obj + | Shell Double Obj + | WithRounding Double Obj + + +arbitrary :: Gen Obj +arbitrary = (let + terminal + = [(Square <$> arbitrary) <*> arbitrary, Circle <$> arbitrary, + Polygon <$> arbitrary, pure Empty, pure Full] + in + sized + $ (\ n + -> case n <= 1 of + True -> oneof terminal + False + -> oneof + $ ([(Rotate2 <$> arbitrary) <*> scale (subtract 1) arbitrary, + Complement <$> scale (subtract 1) arbitrary, + (UnionR <$> arbitrary) <*> scale (subtract 1) arbitrary, + ((DifferenceR <$> arbitrary) <*> scale (flip div 2) arbitrary) + <*> scale (flip div 2) arbitrary, + (IntersectR <$> arbitrary) <*> scale (subtract 1) arbitrary, + ((Translate <$> arbitrary) <*> arbitrary) + <*> scale (subtract 1) arbitrary, + ((Scale <$> arbitrary) <*> arbitrary) + <*> scale (subtract 1) arbitrary, + ((Mirror <$> arbitrary) <*> arbitrary) + <*> scale (subtract 1) arbitrary, + (Outset <$> arbitrary) <*> scale (subtract 1) arbitrary, + (Shell <$> arbitrary) <*> scale (subtract 1) arbitrary, + (WithRounding <$> arbitrary) <*> scale (subtract 1) arbitrary] + <> terminal))) +