@@ -67,6 +67,7 @@ data Term (t :: Type) where
6767
6868 Return :: Expr t -> Term t
6969 Throw :: Expr a -> Term t
70+ Catch :: Term t -> Term t -> Term t
7071 Retry :: Term t
7172
7273 ReadTVar :: Name (TyVar t ) -> Term t
@@ -296,7 +297,7 @@ deriving instance Show (NfTerm t)
296297-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
297298--
298299-- Compare the implementation of this against the operational semantics in
299- -- Figure 4 in the paper. Note that @catch@ is not included .
300+ -- Figure 4 in the paper including the `Catch` semantics from the Appendix A .
300301--
301302evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t , Heap , Allocs )
302303evalTerm ! env ! heap ! allocs term = case term of
@@ -309,6 +310,30 @@ evalTerm !env !heap !allocs term = case term of
309310 where
310311 e' = evalExpr env e
311312
313+ -- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
314+ -- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
315+ Catch t1 t2 ->
316+ let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of
317+
318+ -- Rule XSTM1
319+ -- M; heap, {} => return P; heap', allocs'
320+ -- --------------------------------------------------------
321+ -- S[catch M N]; heap, allocs => S[return P]; heap', allocs U allocs'
322+ NfReturn v -> (NfReturn v, heap', allocs <> allocs')
323+
324+ -- Rule XSTM2
325+ -- M; heap, {} => throw P; heap', allocs'
326+ -- --------------------------------------------------------
327+ -- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
328+ NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2
329+
330+ -- Rule XSTM3
331+ -- M; heap, {} => retry; heap', allocs'
332+ -- --------------------------------------------------------
333+ -- S[catch M N]; heap, allocs => S[retry]; heap, allocs
334+ NfRetry -> (NfRetry , heap, allocs)
335+
336+
312337 Retry -> (NfRetry , heap, allocs)
313338
314339 -- Rule READ
@@ -437,7 +462,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =
437462
438463-- | Execute an STM 'Term' in the 'STM' monad.
439464--
440- execTerm :: (MonadSTM m , MonadThrow (STM m ))
465+ execTerm :: (MonadSTM m , MonadCatch (STM m ))
441466 => ExecEnv m
442467 -> Term t
443468 -> STM m (ExecValue m t )
@@ -451,6 +476,8 @@ execTerm env t =
451476 let e' = execExpr env e
452477 throwSTM =<< snapshotExecValue e'
453478
479+ Catch t1 t2 -> execTerm env t1 `catch` \ (_ :: ImmValue ) -> execTerm env t2
480+
454481 Retry -> retry
455482
456483 ReadTVar n -> do
@@ -491,7 +518,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
491518snapshotExecValue (ExecValVar v _) = fmap ImmValVar
492519 (snapshotExecValue =<< readTVar v)
493520
494- execAtomically :: forall m t . (MonadSTM m , MonadThrow (STM m ), MonadCatch m )
521+ execAtomically :: forall m t . (MonadSTM m , MonadCatch (STM m ), MonadCatch m )
495522 => Term t -> m TxResult
496523execAtomically t =
497524 toTxResult <$> try (atomically action')
@@ -657,7 +684,7 @@ genTerm env tyrep =
657684 Nothing )
658685 ]
659686
660- binTerm = frequency [ (2 , bindTerm), (1 , orElseTerm)]
687+ binTerm = frequency [ (2 , bindTerm), (1 , orElseTerm), ( 1 , catchTerm) ]
661688
662689 bindTerm =
663690 sized $ \ sz -> do
@@ -671,10 +698,15 @@ genTerm env tyrep =
671698 return (Bind t1 name t2)
672699
673700 orElseTerm =
674- sized $ \ sz -> resize (sz `div` 2 ) $
701+ scale ( `div` 2 ) $
675702 OrElse <$> genTerm env tyrep
676703 <*> genTerm env tyrep
677704
705+ catchTerm =
706+ scale (`div` 2 ) $
707+ Catch <$> genTerm env tyrep
708+ <*> genTerm env tyrep
709+
678710genSomeExpr :: GenEnv -> Gen SomeExpr
679711genSomeExpr env =
680712 oneof'
@@ -713,6 +745,8 @@ shrinkTerm t =
713745 case t of
714746 Return e -> [Return e' | e' <- shrinkExpr e]
715747 Throw e -> [Throw e' | e' <- shrinkExpr e]
748+ Catch t1 t2 -> [t1, t2]
749+ ++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
716750 Retry -> []
717751 ReadTVar _ -> []
718752
@@ -721,12 +755,10 @@ shrinkTerm t =
721755 NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]
722756
723757 Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
724- ++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
725- ++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
758+ ++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
726759
727760 OrElse t1 t2 -> [t1, t2]
728- ++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
729- ++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
761+ ++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
730762
731763shrinkExpr :: Expr t -> [Expr t ]
732764shrinkExpr ExprUnit = []
@@ -738,6 +770,10 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
738770freeNamesTerm :: Term t -> Set NameId
739771freeNamesTerm (Return e) = freeNamesExpr e
740772freeNamesTerm (Throw e) = freeNamesExpr e
773+ -- The current generator of catch term ignores the argument passed to the
774+ -- handler.
775+ -- TODO: Correctly handle free names when the handler also binds a variable.
776+ freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
741777freeNamesTerm Retry = Set. empty
742778freeNamesTerm (ReadTVar n) = Set. singleton (nameId n)
743779freeNamesTerm (WriteTVar n e) = Set. singleton (nameId n) <> freeNamesExpr e
@@ -768,6 +804,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
768804termSize :: Term a -> Int
769805termSize Return {} = 1
770806termSize Throw {} = 1
807+ termSize (Catch a b) = 1 + termSize a + termSize b
771808termSize Retry {} = 1
772809termSize ReadTVar {} = 1
773810termSize WriteTVar {} = 1
@@ -778,6 +815,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
778815termDepth :: Term a -> Int
779816termDepth Return {} = 1
780817termDepth Throw {} = 1
818+ termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
781819termDepth Retry {} = 1
782820termDepth ReadTVar {} = 1
783821termDepth WriteTVar {} = 1
@@ -790,6 +828,9 @@ showTerm p (Return e) = showParen (p > 10) $
790828 showString " return " . showExpr 11 e
791829showTerm p (Throw e) = showParen (p > 10 ) $
792830 showString " throwSTM " . showExpr 11 e
831+ showTerm p (Catch t1 t2) = showParen (p > 9 ) $
832+ showTerm 10 t1 . showString " `catch` "
833+ . showTerm 10 t2
793834showTerm _ Retry = showString " retry"
794835showTerm p (ReadTVar n) = showParen (p > 10 ) $
795836 showString " readTVar " . showName n
0 commit comments