Skip to content

Commit 6c14163

Browse files
authored
Refactor tactics to track hypothesis provenance (#557)
This PR greatly simplifies the Judgement type, removing several pieces of denormalized state that were a headache to keep synchronized (and, as it happens, weren't!). The core change here is a new type Provenance, which tracks where a term came from, and thus what we're allowed to do with it. The result is a significantly more maintainable implementation.
1 parent b7a390f commit 6c14163

File tree

10 files changed

+475
-230
lines changed

10 files changed

+475
-230
lines changed

plugins/tactics/src/Ide/Plugin/Tactic.hs

+32-20
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ import Control.Monad.Trans
2222
import Control.Monad.Trans.Maybe
2323
import Data.Aeson
2424
import Data.Coerce
25+
import Data.Functor ((<&>))
2526
import Data.Generics.Aliases (mkQ)
2627
import Data.Generics.Schemes (everything)
2728
import Data.List
29+
import Data.Map (Map)
2830
import qualified Data.Map as M
2931
import Data.Maybe
3032
import Data.Monoid
@@ -214,7 +216,7 @@ filterBindingType
214216
filterBindingType p tp dflags plId uri range jdg =
215217
let hy = jHypothesis jdg
216218
g = jGoal jdg
217-
in fmap join $ for (M.toList hy) $ \(occ, CType ty) ->
219+
in fmap join $ for (M.toList hy) $ \(occ, hi_type -> CType ty) ->
218220
case p (unCType g) ty of
219221
True -> tp occ ty dflags plId uri range jdg
220222
False -> pure []
@@ -264,23 +266,28 @@ judgementForHole state nfp range = do
264266
(mapMaybe (sequenceA . (occName *** coerce))
265267
$ getDefiningBindings binds rss)
266268
tcg
267-
hyps = hypothesisFromBindings rss binds
268-
ambient = M.fromList $ contextMethodHypothesis ctx
269+
top_provs = getRhsPosVals rss tcs
270+
local_hy = spliceProvenance top_provs
271+
$ hypothesisFromBindings rss binds
272+
cls_hy = contextMethodHypothesis ctx
269273
pure ( resulting_range
270274
, mkFirstJudgement
271-
hyps
272-
ambient
275+
(local_hy <> cls_hy)
273276
(isRhsHole rss tcs)
274-
(maybe
275-
mempty
276-
(uncurry M.singleton . fmap pure)
277-
$ getRhsPosVals rss tcs)
278277
goal
279278
, ctx
280279
, dflags
281280
)
282281

283282

283+
spliceProvenance
284+
:: Map OccName Provenance
285+
-> Map OccName (HyInfo a)
286+
-> Map OccName (HyInfo a)
287+
spliceProvenance provs =
288+
M.mapWithKey $ \name hi ->
289+
overProvenance (maybe id const $ M.lookup name provs) hi
290+
284291

285292
tacticCmd :: (OccName -> TacticsM ()) -> CommandFunction TacticParams
286293
tacticCmd tac lf state (TacticParams uri range var_name)
@@ -334,17 +341,22 @@ isRhsHole rss tcs = everything (||) (mkQ False $ \case
334341

335342
------------------------------------------------------------------------------
336343
-- | Compute top-level position vals of a function
337-
getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Maybe (OccName, [OccName])
338-
getRhsPosVals rss tcs = getFirst $ everything (<>) (mkQ mempty $ \case
339-
TopLevelRHS name ps
340-
(L (RealSrcSpan span) -- body with no guards and a single defn
341-
(HsVar _ (L _ hole)))
342-
| containsSpan rss span -- which contains our span
343-
, isHole $ occName hole -- and the span is a hole
344-
-> First $ do
345-
patnames <- traverse getPatName ps
346-
pure (occName name, patnames)
347-
_ -> mempty
344+
getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Map OccName Provenance
345+
getRhsPosVals rss tcs
346+
= M.fromList
347+
$ join
348+
$ maybeToList
349+
$ getFirst
350+
$ everything (<>) (mkQ mempty $ \case
351+
TopLevelRHS name ps
352+
(L (RealSrcSpan span) -- body with no guards and a single defn
353+
(HsVar _ (L _ hole)))
354+
| containsSpan rss span -- which contains our span
355+
, isHole $ occName hole -- and the span is a hole
356+
-> First $ do
357+
patnames <- traverse getPatName ps
358+
pure $ zip patnames $ [0..] <&> TopLevelArgPrv name
359+
_ -> mempty
348360
) tcs
349361

350362

plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ auto = do
2323
commit knownStrategies
2424
. tracing "auto"
2525
. localTactic (auto' 4)
26-
. disallowing
26+
. disallowing RecursiveCall
2727
$ fmap fst current
2828

plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs

+12-17
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ useOccName jdg name =
4343

4444
------------------------------------------------------------------------------
4545
-- | Doing recursion incurs a small penalty in the score.
46-
penalizeRecursion :: MonadState TacticState m => m ()
47-
penalizeRecursion = modify $ field @"ts_recursion_penality" +~ 1
46+
countRecursiveCall :: TacticState -> TacticState
47+
countRecursiveCall = field @"ts_recursion_count" +~ 1
4848

4949

5050
------------------------------------------------------------------------------
@@ -57,14 +57,14 @@ addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals
5757
destructMatches
5858
:: (DataCon -> Judgement -> Rule)
5959
-- ^ How to construct each match
60-
-> ([(OccName, CType)] -> Judgement -> Judgement)
61-
-- ^ How to derive each match judgement
60+
-> Maybe OccName
61+
-- ^ Scrutinee
6262
-> CType
6363
-- ^ Type being destructed
6464
-> Judgement
6565
-> RuleM (Trace, [RawMatch])
66-
destructMatches f f2 t jdg = do
67-
let hy = jHypothesis jdg
66+
destructMatches f scrut t jdg = do
67+
let hy = jEntireHypothesis jdg
6868
g = jGoal jdg
6969
case splitTyConApp_maybe $ unCType t of
7070
Nothing -> throwError $ GoalMismatch "destruct" g
@@ -76,11 +76,7 @@ destructMatches f f2 t jdg = do
7676
let args = dataConInstOrigArgTys' dc apps
7777
names <- mkManyGoodNames hy args
7878
let hy' = zip names $ coerce args
79-
dcon_name = nameOccName $ dataConName dc
80-
81-
let j = f2 hy'
82-
$ withPositionMapping dcon_name names
83-
$ introducingPat hy'
79+
j = introducingPat scrut dc hy'
8480
$ withNewGoal g jdg
8581
(tr, sg) <- f dc j
8682
modify $ withIntroducedVals $ mappend $ S.fromList names
@@ -142,14 +138,14 @@ destruct' f term jdg = do
142138
let hy = jHypothesis jdg
143139
case find ((== term) . fst) $ toList hy of
144140
Nothing -> throwError $ UndefinedHypothesis term
145-
Just (_, t) -> do
141+
Just (_, hi_type -> t) -> do
146142
useOccName jdg term
147143
(tr, ms)
148144
<- destructMatches
149145
f
150-
(\cs -> setParents term (fmap fst cs) . destructing term)
146+
(Just term)
151147
t
152-
jdg
148+
$ disallowing AlreadyDestructed [term] jdg
153149
pure ( rose ("destruct " <> show term) $ pure tr
154150
, noLoc $ case' (var' term) ms
155151
)
@@ -165,7 +161,7 @@ destructLambdaCase' f jdg = do
165161
case splitFunTy_maybe (unCType g) of
166162
Just (arg, _) | isAlgType arg ->
167163
fmap (fmap noLoc $ lambdaCase) <$>
168-
destructMatches f (const id) (CType arg) jdg
164+
destructMatches f Nothing (CType arg) jdg
169165
_ -> throwError $ GoalMismatch "destructLambdaCase'" g
170166

171167

@@ -178,12 +174,11 @@ buildDataCon
178174
-> RuleM (Trace, LHsExpr GhcPs)
179175
buildDataCon jdg dc apps = do
180176
let args = dataConInstOrigArgTys' dc apps
181-
dcon_name = nameOccName $ dataConName dc
182177
(tr, sgs)
183178
<- fmap unzipTrace
184179
$ traverse ( \(arg, n) ->
185180
newSubgoal
186-
. filterSameTypeFromOtherPositions dcon_name n
181+
. filterSameTypeFromOtherPositions dc n
187182
. blacklistingDestruct
188183
. flip withNewGoal jdg
189184
$ CType arg

plugins/tactics/src/Ide/Plugin/Tactic/Context.hs

+6-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import Bag
77
import Control.Arrow
88
import Control.Monad.Reader
99
import Data.List
10+
import Data.Map (Map)
11+
import qualified Data.Map as M
1012
import Data.Maybe (mapMaybe)
1113
import Data.Set (Set)
1214
import qualified Data.Set as S
@@ -33,9 +35,10 @@ mkContext locals tcg = Context
3335

3436
------------------------------------------------------------------------------
3537
-- | Find all of the class methods that exist from the givens in the context.
36-
contextMethodHypothesis :: Context -> [(OccName, CType)]
38+
contextMethodHypothesis :: Context -> Map OccName (HyInfo CType)
3739
contextMethodHypothesis ctx
38-
= excludeForbiddenMethods
40+
= M.fromList
41+
. excludeForbiddenMethods
3942
. join
4043
. concatMap
4144
( mapMaybe methodHypothesis
@@ -51,7 +54,7 @@ contextMethodHypothesis ctx
5154
-- | Many operations are defined in typeclasses for performance reasons, rather
5255
-- than being a true part of the class. This function filters out those, in
5356
-- order to keep our hypothesis space small.
54-
excludeForbiddenMethods :: [(OccName, CType)] -> [(OccName, CType)]
57+
excludeForbiddenMethods :: [(OccName, a)] -> [(OccName, a)]
5558
excludeForbiddenMethods = filter (not . flip S.member forbiddenMethods . fst)
5659
where
5760
forbiddenMethods :: Set OccName

0 commit comments

Comments
 (0)