Skip to content

Commit e8871ab

Browse files
authored
Better scoring metric for deriving safeHead (#545)
This PR tweaks the scoring metric to heavily penalize not using top-level function arguments when defining functions. Presumably if they were added to the type sig, someone had intention behind it. Note that this doesn't prevent us from deriving const, since we have no better alternatives in that case. Furthermore, this fixes a bug where recursive calls were added to the jLocalHypothesis rather than jAmbientHypothesis. The former is for locally introduced variables, for which usage is rewarded. The result was that we were accidentally rewarding recursive calls! Instead we'd like to penalize them, so this PR adds a field which counts recursive calls and penalizes them. Fixes #539
1 parent 9223599 commit e8871ab

File tree

8 files changed

+69
-4
lines changed

8 files changed

+69
-4
lines changed

plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
{-# LANGUAGE DataKinds #-}
2+
{-# LANGUAGE TypeApplications #-}
13
{-# LANGUAGE FlexibleContexts #-}
24
{-# LANGUAGE TupleSections #-}
35
{-# LANGUAGE ViewPatterns #-}
46

57
module Ide.Plugin.Tactic.CodeGen where
68

9+
import Control.Lens ((+~), (%~), (<>~))
710
import Control.Monad.Except
811
import Control.Monad.State (MonadState)
912
import Control.Monad.State.Class (modify)
13+
import Data.Generics.Product (field)
1014
import Data.List
1115
import qualified Data.Map as M
1216
import qualified Data.Set as S
@@ -31,10 +35,25 @@ useOccName :: MonadState TacticState m => Judgement -> OccName -> m ()
3135
useOccName jdg name =
3236
-- Only score points if this is in the local hypothesis
3337
case M.lookup name $ jLocalHypothesis jdg of
34-
Just{} -> modify $ withUsedVals $ S.insert name
38+
Just{} -> modify
39+
$ (withUsedVals $ S.insert name)
40+
. (field @"ts_unused_top_vals" %~ S.delete name)
3541
Nothing -> pure ()
3642

3743

44+
------------------------------------------------------------------------------
45+
-- | Doing recursion incurs a small penalty in the score.
46+
penalizeRecursion :: MonadState TacticState m => m ()
47+
penalizeRecursion = modify $ field @"ts_recursion_penality" +~ 1
48+
49+
50+
------------------------------------------------------------------------------
51+
-- | Insert some values into the unused top values field. These are
52+
-- subsequently removed via 'useOccName'.
53+
addUnusedTopVals :: MonadState TacticState m => S.Set OccName -> m ()
54+
addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals
55+
56+
3857
destructMatches
3958
:: (DataCon -> Judgement -> Rule)
4059
-- ^ How to construct each match

plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ introducing ns =
7575
field @"_jHypothesis" <>~ M.fromList ns
7676

7777

78+
------------------------------------------------------------------------------
79+
-- | Add some terms to the ambient hypothesis
80+
introducingAmbient :: [(OccName, a)] -> Judgement' a -> Judgement' a
81+
introducingAmbient ns =
82+
field @"_jAmbientHypothesis" <>~ M.fromList ns
83+
84+
7885
filterPosition :: OccName -> Int -> Judgement -> Judgement
7986
filterPosition defn pos jdg =
8087
withHypothesis (M.filterWithKey go) jdg

plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ import Control.Monad.State.Class (gets, modify)
2727
import Control.Monad.State.Strict (StateT (..))
2828
import Data.Coerce
2929
import Data.Either
30+
import Data.Foldable
3031
import Data.Functor ((<&>))
3132
import Data.Generics (mkQ, everything, gcount)
32-
import Data.List (sortBy)
33+
import Data.List (nub, sortBy)
3334
import Data.Ord (comparing, Down(..))
3435
import qualified Data.Set as S
3536
import Development.IDE.GHC.Compat
@@ -71,7 +72,12 @@ runTactic
7172
-> Either [TacticError] RunTacticResults
7273
runTactic ctx jdg t =
7374
let skolems = tyCoVarsOfTypeWellScoped $ unCType $ jGoal jdg
74-
tacticState = defaultTacticState { ts_skolems = skolems }
75+
unused_topvals = nub $ join $ join $ toList $ _jPositionMaps jdg
76+
tacticState =
77+
defaultTacticState
78+
{ ts_skolems = skolems
79+
, ts_unused_top_vals = S.fromList unused_topvals
80+
}
7581
in case partitionEithers
7682
. flip runReader ctx
7783
. unExtractM
@@ -126,21 +132,31 @@ setRecursionFrameData b = do
126132
[] -> []
127133

128134

135+
------------------------------------------------------------------------------
136+
-- | Given the results of running a tactic, score the solutions by
137+
-- desirability.
138+
--
139+
-- TODO(sandy): This function is completely unprincipled and was just hacked
140+
-- together to produce the right test results.
129141
scoreSolution
130142
:: LHsExpr GhcPs
131143
-> TacticState
132144
-> [Judgement]
133145
-> ( Penalize Int -- number of holes
134146
, Reward Bool -- all bindings used
147+
, Penalize Int -- unused top-level bindings
135148
, Penalize Int -- number of introduced bindings
136149
, Reward Int -- number used bindings
150+
, Penalize Int -- number of recursive calls
137151
, Penalize Int -- size of extract
138152
)
139153
scoreSolution ext TacticState{..} holes
140154
= ( Penalize $ length holes
141155
, Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals
156+
, Penalize $ S.size ts_unused_top_vals
142157
, Penalize $ S.size ts_intro_vals
143158
, Reward $ S.size ts_used_vals
159+
, Penalize $ ts_recursion_penality
144160
, Penalize $ solutionSize ext
145161
)
146162

plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,9 @@ recursion = requireConcreteHole $ tracing "recursion" $ do
6969
defs <- getCurrentDefinitions
7070
attemptOn (const $ fmap fst defs) $ \name -> do
7171
modify $ withRecursionStack (False :)
72+
penalizeRecursion
7273
ensure recursiveCleanup (withRecursionStack tail) $ do
73-
(localTactic (apply name) $ introducing defs)
74+
(localTactic (apply name) $ introducingAmbient defs)
7475
<@> fmap (localTactic assumption . filterPosition name) [0..]
7576

7677

@@ -88,6 +89,7 @@ intros = rule $ \jdg -> do
8889
let jdg' = introducing (zip vs $ coerce as)
8990
$ withNewGoal (CType b) jdg
9091
modify $ withIntroducedVals $ mappend $ S.fromList vs
92+
when (isTopHole jdg) $ addUnusedTopVals $ S.fromList vs
9193
(tr, sg)
9294
<- newSubgoal
9395
$ bool

plugins/tactics/src/Ide/Plugin/Tactic/Types.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,21 @@ instance Show DataCon where
7474
------------------------------------------------------------------------------
7575
data TacticState = TacticState
7676
{ ts_skolems :: !([TyVar])
77+
-- ^ The known skolems.
7778
, ts_unifier :: !(TCvSubst)
79+
-- ^ The current substitution of univars.
7880
, ts_used_vals :: !(Set OccName)
81+
-- ^ Set of values used by tactics.
7982
, ts_intro_vals :: !(Set OccName)
83+
-- ^ Set of values introduced by tactics.
84+
, ts_unused_top_vals :: !(Set OccName)
85+
-- ^ Set of currently unused arguments to the function being defined.
8086
, ts_recursion_stack :: ![Bool]
87+
-- ^ Stack for tracking whether or not the current recursive call has
88+
-- used at least one smaller pat val. Recursive calls for which this
89+
-- value is 'False' are guaranteed to loop, and must be pruned.
90+
, ts_recursion_penality :: !Int
91+
-- ^ Number of calls to recursion. We penalize each.
8192
, ts_unique_gen :: !UniqSupply
8293
} deriving stock (Show, Generic)
8394

@@ -100,7 +111,9 @@ defaultTacticState =
100111
, ts_unifier = emptyTCvSubst
101112
, ts_used_vals = mempty
102113
, ts_intro_vals = mempty
114+
, ts_unused_top_vals = mempty
103115
, ts_recursion_stack = mempty
116+
, ts_recursion_penality = 0
104117
, ts_unique_gen = unsafeDefaultUniqueSupply
105118
}
106119

test/functional/Tactic.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ tests = testGroup
109109
, goldenTest "GoldenSuperclass.hs" 7 8 Auto ""
110110
, ignoreTestBecause "It is unreliable in circleci builds"
111111
$ goldenTest "GoldenApplicativeThen.hs" 2 11 Auto ""
112+
, goldenTest "GoldenSafeHead.hs" 2 12 Auto ""
112113
]
113114

114115

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
safeHead :: [x] -> Maybe x
2+
safeHead = _
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
safeHead :: [x] -> Maybe x
2+
safeHead = (\ l_x
3+
-> case l_x of
4+
[] -> Nothing
5+
(x : l_x2) -> Just x)

0 commit comments

Comments
 (0)