Skip to content

Commit 0f07efc

Browse files
authored
Add a known tactic for writing arbitrary instances (#695)
Christmas comes early for QuickCheck users! This PR adds support for generating arbitrary --- including the tricky business of ensuring termination. It can be run by calling Attempt to fill hole on anything of the form arbitrary :: Gen A for some type A.
1 parent cc23521 commit 0f07efc

File tree

7 files changed

+220
-2
lines changed

7 files changed

+220
-2
lines changed

plugins/tactics/hls-tactics-plugin.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ library
3030
Ide.Plugin.Tactic.GHC
3131
Ide.Plugin.Tactic.Judgements
3232
Ide.Plugin.Tactic.KnownStrategies
33+
Ide.Plugin.Tactic.KnownStrategies.QuickCheck
3334
Ide.Plugin.Tactic.Machinery
3435
Ide.Plugin.Tactic.Naming
3536
Ide.Plugin.Tactic.Range

plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,27 @@ var' = var . fromString . occNameString
236236
bvar' :: BVar a => OccName -> a
237237
bvar' = bvar . fromString . occNameString
238238

239+
240+
------------------------------------------------------------------------------
241+
-- | Get an HsExpr corresponding to a function name.
242+
mkFunc :: String -> HsExpr GhcPs
243+
mkFunc = var' . mkVarOcc
244+
245+
246+
------------------------------------------------------------------------------
247+
-- | Get an HsExpr corresponding to a value name.
248+
mkVal :: String -> HsExpr GhcPs
249+
mkVal = var' . mkVarOcc
250+
251+
252+
------------------------------------------------------------------------------
253+
-- | Like 'op', but easier to call.
254+
infixCall :: String -> HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs
255+
infixCall s = flip op (fromString s)
256+
257+
258+
------------------------------------------------------------------------------
259+
-- | Like '(@@)', but uses a dollar instead of parentheses.
260+
appDollar :: HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs
261+
appDollar = infixCall "$"
262+

plugins/tactics/src/Ide/Plugin/Tactic/KnownStrategies.hs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ import Ide.Plugin.Tactic.Types
99
import OccName (mkVarOcc)
1010
import Refinery.Tactic
1111
import Ide.Plugin.Tactic.Machinery (tracing)
12+
import Ide.Plugin.Tactic.KnownStrategies.QuickCheck (deriveArbitrary)
1213

1314

1415
knownStrategies :: TacticsM ()
1516
knownStrategies = choice
16-
[ deriveFmap
17+
[ known "fmap" deriveFmap
18+
, known "arbitrary" deriveArbitrary
1719
]
1820

1921

@@ -26,7 +28,7 @@ known name t = do
2628

2729

2830
deriveFmap :: TacticsM ()
29-
deriveFmap = known "fmap" $ do
31+
deriveFmap = do
3032
try intros
3133
overAlgebraicTerms homo
3234
choice
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
{-# LANGUAGE ViewPatterns #-}
2+
{-# LANGUAGE LambdaCase #-}
3+
4+
module Ide.Plugin.Tactic.KnownStrategies.QuickCheck where
5+
6+
import Control.Monad.Except (MonadError(throwError))
7+
import Data.Bool (bool)
8+
import Data.List (partition)
9+
import DataCon ( DataCon, dataConName )
10+
import Development.IDE.GHC.Compat (HsExpr, GhcPs, noLoc)
11+
import GHC.Exts ( IsString(fromString) )
12+
import GHC.List ( foldl' )
13+
import GHC.SourceGen (int)
14+
import GHC.SourceGen.Binds ( match, valBind )
15+
import GHC.SourceGen.Expr ( case', lambda, let' )
16+
import GHC.SourceGen.Overloaded ( App((@@)), HasList(list) )
17+
import GHC.SourceGen.Pat ( conP )
18+
import Ide.Plugin.Tactic.CodeGen
19+
import Ide.Plugin.Tactic.Judgements (jGoal)
20+
import Ide.Plugin.Tactic.Machinery (tracePrim)
21+
import Ide.Plugin.Tactic.Types
22+
import OccName (occNameString, mkVarOcc, HasOccName(occName) )
23+
import Refinery.Tactic (goal, rule )
24+
import TyCon (tyConName, TyCon, tyConDataCons )
25+
import Type ( splitTyConApp_maybe )
26+
import Data.Generics (mkQ, everything)
27+
28+
29+
------------------------------------------------------------------------------
30+
-- | Known tactic for deriving @arbitrary :: Gen a@. This tactic splits the
31+
-- type's data cons into terminal and inductive cases, and generates code that
32+
-- produces a terminal if the QuickCheck size parameter is <=1, or any data con
33+
-- otherwise. It correctly scales recursive parameters, ensuring termination.
34+
deriveArbitrary :: TacticsM ()
35+
deriveArbitrary = do
36+
ty <- jGoal <$> goal
37+
case splitTyConApp_maybe $ unCType ty of
38+
Just (gen_tc, [splitTyConApp_maybe -> Just (tc, apps)])
39+
| occNameString (occName $ tyConName gen_tc) == "Gen" -> do
40+
rule $ \_ -> do
41+
let dcs = tyConDataCons tc
42+
(terminal, big) = partition ((== 0) . genRecursiveCount)
43+
$ fmap (mkGenerator tc apps) dcs
44+
terminal_expr = mkVal "terminal"
45+
oneof_expr = mkVal "oneof"
46+
pure
47+
( tracePrim "deriveArbitrary"
48+
, noLoc $
49+
let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $
50+
appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $
51+
case' (infixCall "<=" (mkVal "n") (int 1))
52+
[ match [conP (fromString "True") []] $
53+
oneof_expr @@ terminal_expr
54+
, match [conP (fromString "False") []] $
55+
appDollar oneof_expr $
56+
infixCall "<>"
57+
(list $ fmap genExpr big)
58+
terminal_expr
59+
]
60+
)
61+
_ -> throwError $ GoalMismatch "deriveArbitrary" ty
62+
63+
64+
------------------------------------------------------------------------------
65+
-- | Helper data type for the generator of a specific data con.
66+
data Generator = Generator
67+
{ genRecursiveCount :: Integer
68+
, genExpr :: HsExpr GhcPs
69+
}
70+
71+
72+
------------------------------------------------------------------------------
73+
-- | Make a 'Generator' for a given tycon instantiated with the given @[Type]@.
74+
mkGenerator :: TyCon -> [Type] -> DataCon -> Generator
75+
mkGenerator tc apps dc = do
76+
let dc_expr = var' $ occName $ dataConName dc
77+
args = dataConInstOrigArgTys' dc apps
78+
num_recursive_calls = sum $ fmap (bool 0 1 . doesTypeContain tc) args
79+
mkArbitrary = mkArbitraryCall tc num_recursive_calls
80+
Generator num_recursive_calls $ case args of
81+
[] -> mkFunc "pure" @@ dc_expr
82+
(a : as) ->
83+
foldl'
84+
(infixCall "<*>")
85+
(infixCall "<$>" dc_expr $ mkArbitrary a)
86+
(fmap mkArbitrary as)
87+
88+
89+
------------------------------------------------------------------------------
90+
-- | Check if the given 'TyCon' exists anywhere in the 'Type'.
91+
doesTypeContain :: TyCon -> Type -> Bool
92+
doesTypeContain recursive_tc =
93+
everything (||) $ mkQ False (== recursive_tc)
94+
95+
96+
------------------------------------------------------------------------------
97+
-- | Generate the correct sort of call to @arbitrary@. For recursive calls, we
98+
-- need to scale down the size parameter, either by a constant factor of 1 if
99+
-- it's the only recursive parameter, or by @`div` n@ where n is the number of
100+
-- recursive parameters. For all other types, just call @arbitrary@ directly.
101+
mkArbitraryCall :: TyCon -> Integer -> Type -> HsExpr GhcPs
102+
mkArbitraryCall recursive_tc n ty =
103+
let arbitrary = mkFunc "arbitrary"
104+
in case doesTypeContain recursive_tc ty of
105+
True ->
106+
mkFunc "scale"
107+
@@ bool (mkFunc "flip" @@ mkFunc "div" @@ int n)
108+
(mkFunc "subtract" @@ int 1)
109+
(n == 1)
110+
@@ arbitrary
111+
False -> arbitrary
112+

test/functional/Tactic.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ tests = testGroup
116116
$ goldenTest "GoldenApplicativeThen.hs" 2 11 Auto ""
117117
, goldenTest "GoldenSafeHead.hs" 2 12 Auto ""
118118
, expectFail "GoldenFish.hs" 5 18 Auto ""
119+
, goldenTest "GoldenArbitrary.hs" 25 13 Auto ""
119120
]
120121

121122

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
-- Emulate a quickcheck import; deriveArbitrary works on any type with the
2+
-- right name and kind
3+
data Gen a
4+
5+
data Obj
6+
= Square Int Int
7+
| Circle Int
8+
| Polygon [(Int, Int)]
9+
| Rotate2 Double Obj
10+
| Empty
11+
| Full
12+
| Complement Obj
13+
| UnionR Double [Obj]
14+
| DifferenceR Double Obj [Obj]
15+
| IntersectR Double [Obj]
16+
| Translate Double Double Obj
17+
| Scale Double Double Obj
18+
| Mirror Double Double Obj
19+
| Outset Double Obj
20+
| Shell Double Obj
21+
| WithRounding Double Obj
22+
23+
24+
arbitrary :: Gen Obj
25+
arbitrary = _
26+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
-- Emulate a quickcheck import; deriveArbitrary works on any type with the
2+
-- right name and kind
3+
data Gen a
4+
5+
data Obj
6+
= Square Int Int
7+
| Circle Int
8+
| Polygon [(Int, Int)]
9+
| Rotate2 Double Obj
10+
| Empty
11+
| Full
12+
| Complement Obj
13+
| UnionR Double [Obj]
14+
| DifferenceR Double Obj [Obj]
15+
| IntersectR Double [Obj]
16+
| Translate Double Double Obj
17+
| Scale Double Double Obj
18+
| Mirror Double Double Obj
19+
| Outset Double Obj
20+
| Shell Double Obj
21+
| WithRounding Double Obj
22+
23+
24+
arbitrary :: Gen Obj
25+
arbitrary = (let
26+
terminal
27+
= [(Square <$> arbitrary) <*> arbitrary, Circle <$> arbitrary,
28+
Polygon <$> arbitrary, pure Empty, pure Full]
29+
in
30+
sized
31+
$ (\ n
32+
-> case n <= 1 of
33+
True -> oneof terminal
34+
False
35+
-> oneof
36+
$ ([(Rotate2 <$> arbitrary) <*> scale (subtract 1) arbitrary,
37+
Complement <$> scale (subtract 1) arbitrary,
38+
(UnionR <$> arbitrary) <*> scale (subtract 1) arbitrary,
39+
((DifferenceR <$> arbitrary) <*> scale (flip div 2) arbitrary)
40+
<*> scale (flip div 2) arbitrary,
41+
(IntersectR <$> arbitrary) <*> scale (subtract 1) arbitrary,
42+
((Translate <$> arbitrary) <*> arbitrary)
43+
<*> scale (subtract 1) arbitrary,
44+
((Scale <$> arbitrary) <*> arbitrary)
45+
<*> scale (subtract 1) arbitrary,
46+
((Mirror <$> arbitrary) <*> arbitrary)
47+
<*> scale (subtract 1) arbitrary,
48+
(Outset <$> arbitrary) <*> scale (subtract 1) arbitrary,
49+
(Shell <$> arbitrary) <*> scale (subtract 1) arbitrary,
50+
(WithRounding <$> arbitrary) <*> scale (subtract 1) arbitrary]
51+
<> terminal)))
52+

0 commit comments

Comments
 (0)