@@ -68,6 +68,7 @@ data Term (t :: Type) where
6868
6969 Return :: Expr t -> Term t
7070 Throw :: Expr a -> Term t
71+ Catch :: Term t -> Term t -> Term t
7172 Retry :: Term t
7273
7374 ReadTVar :: Name (TyVar t ) -> Term t
@@ -297,7 +298,7 @@ deriving instance Show (NfTerm t)
297298-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
298299--
299300-- Compare the implementation of this against the operational semantics in
300- -- Figure 4 in the paper. Note that @catch@ is not included .
301+ -- Figure 4 in the paper including the `Catch` semantics from the Appendix A .
301302--
302303evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t , Heap , Allocs )
303304evalTerm ! env ! heap ! allocs term = case term of
@@ -310,6 +311,30 @@ evalTerm !env !heap !allocs term = case term of
310311 where
311312 e' = evalExpr env e
312313
314+ -- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
315+ -- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
316+ Catch t1 t2 ->
317+ let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of
318+
319+ -- Rule XSTM1
320+ -- M; heap, {} => return P; heap', allocs'
321+ -- --------------------------------------------------------
322+ -- S[catch M N]; heap, allocs => S[return P]; heap', allocs'
323+ NfReturn v -> (NfReturn v, heap', allocs <> allocs')
324+
325+ -- Rule XSTM2
326+ -- M; heap, {} => throw P; heap', allocs'
327+ -- --------------------------------------------------------
328+ -- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
329+ NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2
330+
331+ -- Rule XSTM3
332+ -- M; heap, {} => retry; heap', allocs'
333+ -- --------------------------------------------------------
334+ -- S[catch M N]; heap, allocs => S[retry]; heap, allocs
335+ NfRetry -> (NfRetry , heap, allocs)
336+
337+
313338 Retry -> (NfRetry , heap, allocs)
314339
315340 -- Rule READ
@@ -438,7 +463,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =
438463
439464-- | Execute an STM 'Term' in the 'STM' monad.
440465--
441- execTerm :: (MonadSTM m , MonadThrow (STM m ))
466+ execTerm :: (MonadSTM m , MonadCatch (STM m ))
442467 => ExecEnv m
443468 -> Term t
444469 -> STM m (ExecValue m t )
@@ -452,6 +477,8 @@ execTerm env t =
452477 let e' = execExpr env e
453478 throwSTM =<< snapshotExecValue e'
454479
480+ Catch t1 t2 -> execTerm env t1 `catch` \ (_ :: ImmValue ) -> execTerm env t2
481+
455482 Retry -> retry
456483
457484 ReadTVar n -> do
@@ -492,7 +519,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
492519snapshotExecValue (ExecValVar v _) = fmap ImmValVar
493520 (snapshotExecValue =<< readTVar v)
494521
495- execAtomically :: forall m t . (MonadSTM m , MonadThrow (STM m ), MonadCatch m )
522+ execAtomically :: forall m t . (MonadSTM m , MonadCatch (STM m ), MonadCatch m )
496523 => Term t -> m TxResult
497524execAtomically t =
498525 toTxResult <$> try (atomically action')
@@ -658,7 +685,7 @@ genTerm env tyrep =
658685 Nothing )
659686 ]
660687
661- binTerm = frequency [ (2 , bindTerm), (1 , orElseTerm)]
688+ binTerm = frequency [ (2 , bindTerm), (1 , orElseTerm), ( 1 , catchTerm) ]
662689
663690 bindTerm =
664691 sized $ \ sz -> do
@@ -672,10 +699,15 @@ genTerm env tyrep =
672699 return (Bind t1 name t2)
673700
674701 orElseTerm =
675- sized $ \ sz -> resize (sz `div` 2 ) $
702+ scale ( `div` 2 ) $
676703 OrElse <$> genTerm env tyrep
677704 <*> genTerm env tyrep
678705
706+ catchTerm =
707+ scale (`div` 2 ) $
708+ Catch <$> genTerm env tyrep
709+ <*> genTerm env tyrep
710+
679711genSomeExpr :: GenEnv -> Gen SomeExpr
680712genSomeExpr env =
681713 oneof'
@@ -714,6 +746,8 @@ shrinkTerm t =
714746 case t of
715747 Return e -> [Return e' | e' <- shrinkExpr e]
716748 Throw e -> [Throw e' | e' <- shrinkExpr e]
749+ Catch t1 t2 -> [t1, t2]
750+ ++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
717751 Retry -> []
718752 ReadTVar _ -> []
719753
@@ -722,12 +756,10 @@ shrinkTerm t =
722756 NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]
723757
724758 Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
725- ++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
726- ++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
759+ ++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
727760
728761 OrElse t1 t2 -> [t1, t2]
729- ++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
730- ++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
762+ ++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
731763
732764shrinkExpr :: Expr t -> [Expr t ]
733765shrinkExpr ExprUnit = []
@@ -739,6 +771,11 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
739771freeNamesTerm :: Term t -> Set NameId
740772freeNamesTerm (Return e) = freeNamesExpr e
741773freeNamesTerm (Throw e) = freeNamesExpr e
774+ -- A catch handler should actually have an argument, and then the implementation should
775+ -- handle it. But since current implementation of catch never binds the variable, the following
776+ -- implementation is correct as of now. It needs to be tackled once nested exceptions are handled
777+ -- TODO: Correctly handle free names when the handler also binds a variable
778+ freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
742779freeNamesTerm Retry = Set. empty
743780freeNamesTerm (ReadTVar n) = Set. singleton (nameId n)
744781freeNamesTerm (WriteTVar n e) = Set. singleton (nameId n) <> freeNamesExpr e
@@ -769,6 +806,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
769806termSize :: Term a -> Int
770807termSize Return {} = 1
771808termSize Throw {} = 1
809+ termSize (Catch a b) = 1 + termSize a + termSize b
772810termSize Retry {} = 1
773811termSize ReadTVar {} = 1
774812termSize WriteTVar {} = 1
@@ -779,6 +817,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
779817termDepth :: Term a -> Int
780818termDepth Return {} = 1
781819termDepth Throw {} = 1
820+ termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
782821termDepth Retry {} = 1
783822termDepth ReadTVar {} = 1
784823termDepth WriteTVar {} = 1
@@ -791,6 +830,9 @@ showTerm p (Return e) = showParen (p > 10) $
791830 showString " return " . showExpr 11 e
792831showTerm p (Throw e) = showParen (p > 10 ) $
793832 showString " throwSTM " . showExpr 11 e
833+ showTerm p (Catch t1 t2) = showParen (p > 9 ) $
834+ showTerm 10 t1 . showString " `catch` "
835+ . showTerm 10 t2
794836showTerm _ Retry = showString " retry"
795837showTerm p (ReadTVar n) = showParen (p > 10 ) $
796838 showString " readTVar " . showName n
0 commit comments