Skip to content

Commit a68ed16

Browse files
Add catch semantics to STM
- Add support for Catch in IOSim and IOSimPOR - Add support for Catch in Test/STM.hs
1 parent e243439 commit a68ed16

File tree

6 files changed

+172
-46
lines changed

6 files changed

+172
-46
lines changed

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -926,19 +926,45 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
926926

927927
ThrowStm e ->
928928
{-# SCC "execAtomically.go.ThrowStm" #-} do
929-
-- Revert all the TVar writes
929+
-- Rollback `TVar`s written since catch handler was installed
930930
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
931-
k0 $ StmTxAborted [] (toException e)
931+
case ctl of
932+
AtomicallyFrame -> do
933+
k0 $ StmTxAborted (Map.elems read) (toException e)
934+
935+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
936+
{-# SCC "execAtomically.go.BranchFrame" #-} do
937+
-- Execute the left side in a new frame with an empty written set.
938+
-- but preserve ones that were set prior to it, as specified in the
939+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
940+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
941+
go ctl'' read Map.empty [] [] nextVid (h e)
942+
943+
BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
944+
{-# SCC "execAtomically.go.BranchFrame" #-} do
945+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
946+
947+
BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
948+
{-# SCC "execAtomically.go.BranchFrame" #-} do
949+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
950+
951+
CatchStm a h k ->
952+
{-# SCC "execAtomically.go.ThrowStm" #-} do
953+
-- Execute the catch handler with an empty written set.
954+
-- but preserve ones that were set prior to it, as specified in the
955+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
956+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
957+
go ctl' read Map.empty [] [] nextVid a
958+
932959

933960
Retry ->
934-
{-# SCC "execAtomically.go.Retry" #-}
935-
do
961+
{-# SCC "execAtomically.go.Retry" #-} do
936962
-- Always revert all the TVar writes for the retry
937963
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
938964
case ctl of
939965
AtomicallyFrame -> do
940966
-- Return vars read, so the thread can block on them
941-
k0 $! StmTxBlocked $! (Map.elems read)
967+
k0 $! StmTxBlocked $! Map.elems read
942968

943969
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
944970
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do

io-sim/src/Control/Monad/IOSim/STM.hs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,10 @@ writeTBQueueDefault (TBQueue queue _size) a = do
171171

172172
isEmptyTBQueueDefault :: MonadSTM m => TBQueueDefault m a -> STM m Bool
173173
isEmptyTBQueueDefault (TBQueue queue _size) = do
174-
(xs, _, ys, _) <- readTVar queue
174+
(xs, _, _, _) <- readTVar queue
175175
case xs of
176176
_:_ -> return False
177-
[] -> case ys of
178-
[] -> return True
179-
_ -> return False
177+
[] -> return True
180178

181179
isFullTBQueueDefault :: MonadSTM m => TBQueueDefault m a -> STM m Bool
182180
isFullTBQueueDefault (TBQueue queue _size) = do

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ runSTM (STM k) = k ReturnStm
191191
data StmA s a where
192192
ReturnStm :: a -> StmA s a
193193
ThrowStm :: SomeException -> StmA s a
194+
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b
194195

195196
NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
196197
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
@@ -335,6 +336,31 @@ instance MonadThrow (STM s) where
335336
instance Exceptions.MonadThrow (STM s) where
336337
throwM = MonadThrow.throwIO
337338

339+
340+
instance MonadCatch (STM s) where
341+
342+
catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k
343+
where
344+
-- Get a total handler from the given handler
345+
fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a
346+
fromHandler h e = case fromException e of
347+
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
348+
Just e' -> h e'
349+
350+
-- Masking is not required as STM actions are always run inside
351+
-- `execAtomically` and behave as if masked. Also note that the default
352+
-- implementation of `generalBracket` needs mask, and is part of `MonadThrow`.
353+
generalBracket acquire release use = do
354+
resource <- acquire
355+
b <- use resource `catch` \e -> do
356+
_ <- release resource (ExitCaseException e)
357+
throwIO e
358+
c <- release resource (ExitCaseSuccess b)
359+
return (b, c)
360+
361+
instance Exceptions.MonadCatch (STM s) where
362+
catch = MonadThrow.catch
363+
338364
instance MonadCatch (IOSim s) where
339365
catch action handler =
340366
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
@@ -870,9 +896,22 @@ data StmTxResult s a =
870896
| StmTxAborted [SomeTVar s] SomeException
871897

872898

873-
-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
874-
-- can be a NoOp
875-
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
899+
-- | A branch indicates that an alternative statement is available in the current
900+
-- context. For example, `OrElse` has two alternative statements, say "left"
901+
-- and "right". While executing the left statement, `OrElseStmA` branch indicates
902+
-- that the right branch is still available, in case the left statement fails.
903+
data BranchStmA s a =
904+
-- | `OrElse` statement with its 'right' alternative.
905+
OrElseStmA (StmA s a)
906+
-- | `CatchStm` statement with the 'catch' handler.
907+
| CatchStmA (SomeException -> StmA s a)
908+
-- | Unlike the other two branches, the no-op branch is not an explicit
909+
-- part of the STM syntax. It simply indicates that there are no
910+
-- alternative statements left to be executed. For example, when running
911+
-- right alternative of the `OrElse` statement or when running the catch
912+
-- handler of a `CatchStm` statement, there are no alternative statements
913+
-- available. This case is represented by the no-op branch.
914+
| NoOpStmA
876915

877916
data StmStack s b a where
878917
-- | Executing in the context of a top level 'atomically'.

io-sim/src/Control/Monad/IOSimPOR/Internal.hs

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,32 +1174,54 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
11741174
{-# SCC "execAtomically.go.ThrowStm" #-} do
11751175
-- Revert all the TVar writes
11761176
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1177-
k0 $ StmTxAborted (Map.elems read) (toException e)
1177+
case ctl of
1178+
AtomicallyFrame -> do
1179+
k0 $ StmTxAborted (Map.elems read) (toException e)
1180+
1181+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1182+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1183+
-- Execute the left side in a new frame with an empty written set.
1184+
-- but preserve ones that were set prior to it, as specified in the
1185+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
1186+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1187+
go ctl'' read Map.empty [] [] nextVid (h e)
1188+
1189+
BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1190+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1191+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1192+
1193+
BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1194+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1195+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1196+
1197+
CatchStm a h k ->
1198+
{-# SCC "execAtomically.go.ThrowStm" #-} do
1199+
-- Execute the left side in a new frame with an empty written set
1200+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
1201+
go ctl' read Map.empty [] [] nextVid a
11781202

11791203
Retry ->
1180-
{-# SCC "execAtomically.go.Retry" #-}
1181-
do
1182-
-- Always revert all the TVar writes for the retry
1183-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1184-
case ctl of
1185-
AtomicallyFrame -> do
1186-
-- Return vars read, so the thread can block on them
1187-
k0 $! StmTxBlocked $! Map.elems read
1188-
1189-
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1190-
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
1191-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1192-
-- Execute the orElse right hand with an empty written set
1193-
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1194-
go ctl'' read Map.empty [] [] nextVid b
1195-
1196-
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1197-
{-# SCC "execAtomically.go.BranchFrame" #-} do
1198-
-- Retry makes sense only within a OrElse context. If it is a branch other than
1199-
-- OrElse left side, then bubble up the `retry` to the frame above.
1200-
-- Skip the continuation and propagate the retry into the outer frame
1201-
-- using the written set for the outer frame
1202-
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
1204+
{-# SCC "execAtomically.go.Retry" #-} do
1205+
-- Always revert all the TVar writes for the retry
1206+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1207+
case ctl of
1208+
AtomicallyFrame -> do
1209+
-- Return vars read, so the thread can block on them
1210+
k0 $! StmTxBlocked $! Map.elems read
1211+
1212+
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1213+
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
1214+
-- Execute the orElse right hand with an empty written set
1215+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1216+
go ctl'' read Map.empty [] [] nextVid b
1217+
1218+
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1219+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1220+
-- Retry makes sense only within a OrElse context. If it is a branch other than
1221+
-- OrElse left side, then bubble up the `retry` to the frame above.
1222+
-- Skip the continuation and propagate the retry into the outer frame
1223+
-- using the written set for the outer frame
1224+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
12031225

12041226
OrElse a b k ->
12051227
{-# SCC "execAtomically.go.OrElse" #-} do

io-sim/test/Test/IOSim.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ prop_stm_referenceSim t =
12491249
-- | Compare the behaviour of the STM reference operational semantics with
12501250
-- the behaviour of any 'MonadSTM' STM implementation.
12511251
--
1252-
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
1252+
prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m)
12531253
=> SomeTerm -> m Property
12541254
prop_stm_referenceM (SomeTerm _tyrep t) = do
12551255
let (r1, _heap) = evalAtomically t

io-sim/test/Test/STM.hs

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ data Term (t :: Type) where
6868

6969
Return :: Expr t -> Term t
7070
Throw :: Expr a -> Term t
71+
Catch :: Term t -> Term t -> Term t
7172
Retry :: Term t
7273

7374
ReadTVar :: Name (TyVar t) -> Term t
@@ -297,7 +298,7 @@ deriving instance Show (NfTerm t)
297298
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
298299
--
299300
-- Compare the implementation of this against the operational semantics in
300-
-- Figure 4 in the paper. Note that @catch@ is not included.
301+
-- Figure 4 in the paper including the `Catch` semantics from the Appendix A.
301302
--
302303
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs)
303304
evalTerm !env !heap !allocs term = case term of
@@ -310,6 +311,30 @@ evalTerm !env !heap !allocs term = case term of
310311
where
311312
e' = evalExpr env e
312313

314+
-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
315+
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
316+
Catch t1 t2 ->
317+
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of
318+
319+
-- Rule XSTM1
320+
-- M; heap, {} => return P; heap', allocs'
321+
-- --------------------------------------------------------
322+
-- S[catch M N]; heap, allocs => S[return P]; heap', allocs U allocs'
323+
NfReturn v -> (NfReturn v, heap', allocs <> allocs')
324+
325+
-- Rule XSTM2
326+
-- M; heap, {} => throw P; heap', allocs'
327+
-- --------------------------------------------------------
328+
-- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
329+
NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2
330+
331+
-- Rule XSTM3
332+
-- M; heap, {} => retry; heap', allocs'
333+
-- --------------------------------------------------------
334+
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
335+
NfRetry -> (NfRetry, heap, allocs)
336+
337+
313338
Retry -> (NfRetry, heap, allocs)
314339

315340
-- Rule READ
@@ -438,7 +463,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =
438463

439464
-- | Execute an STM 'Term' in the 'STM' monad.
440465
--
441-
execTerm :: (MonadSTM m, MonadThrow (STM m))
466+
execTerm :: (MonadSTM m, MonadCatch (STM m))
442467
=> ExecEnv m
443468
-> Term t
444469
-> STM m (ExecValue m t)
@@ -452,6 +477,8 @@ execTerm env t =
452477
let e' = execExpr env e
453478
throwSTM =<< snapshotExecValue e'
454479

480+
Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2
481+
455482
Retry -> retry
456483

457484
ReadTVar n -> do
@@ -492,7 +519,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
492519
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
493520
(snapshotExecValue =<< readTVar v)
494521

495-
execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
522+
execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m)
496523
=> Term t -> m TxResult
497524
execAtomically t =
498525
toTxResult <$> try (atomically action')
@@ -658,7 +685,7 @@ genTerm env tyrep =
658685
Nothing)
659686
]
660687

661-
binTerm = frequency [ (2, bindTerm), (1, orElseTerm)]
688+
binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)]
662689

663690
bindTerm =
664691
sized $ \sz -> do
@@ -672,10 +699,15 @@ genTerm env tyrep =
672699
return (Bind t1 name t2)
673700

674701
orElseTerm =
675-
sized $ \sz -> resize (sz `div` 2) $
702+
scale (`div` 2) $
676703
OrElse <$> genTerm env tyrep
677704
<*> genTerm env tyrep
678705

706+
catchTerm =
707+
scale (`div` 2) $
708+
Catch <$> genTerm env tyrep
709+
<*> genTerm env tyrep
710+
679711
genSomeExpr :: GenEnv -> Gen SomeExpr
680712
genSomeExpr env =
681713
oneof'
@@ -714,6 +746,8 @@ shrinkTerm t =
714746
case t of
715747
Return e -> [Return e' | e' <- shrinkExpr e]
716748
Throw e -> [Throw e' | e' <- shrinkExpr e]
749+
Catch t1 t2 -> [t1, t2]
750+
++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
717751
Retry -> []
718752
ReadTVar _ -> []
719753

@@ -722,12 +756,10 @@ shrinkTerm t =
722756
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]
723757

724758
Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
725-
++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
726-
++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
759+
++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
727760

728761
OrElse t1 t2 -> [t1, t2]
729-
++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
730-
++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
762+
++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
731763

732764
shrinkExpr :: Expr t -> [Expr t]
733765
shrinkExpr ExprUnit = []
@@ -739,6 +771,10 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
739771
freeNamesTerm :: Term t -> Set NameId
740772
freeNamesTerm (Return e) = freeNamesExpr e
741773
freeNamesTerm (Throw e) = freeNamesExpr e
774+
-- The current generator of catch term ignores the argument passed to the
775+
-- handler.
776+
-- TODO: Correctly handle free names when the handler also binds a variable.
777+
freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
742778
freeNamesTerm Retry = Set.empty
743779
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
744780
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
@@ -769,6 +805,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
769805
termSize :: Term a -> Int
770806
termSize Return{} = 1
771807
termSize Throw{} = 1
808+
termSize (Catch a b) = 1 + termSize a + termSize b
772809
termSize Retry{} = 1
773810
termSize ReadTVar{} = 1
774811
termSize WriteTVar{} = 1
@@ -779,6 +816,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
779816
termDepth :: Term a -> Int
780817
termDepth Return{} = 1
781818
termDepth Throw{} = 1
819+
termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
782820
termDepth Retry{} = 1
783821
termDepth ReadTVar{} = 1
784822
termDepth WriteTVar{} = 1
@@ -791,6 +829,9 @@ showTerm p (Return e) = showParen (p > 10) $
791829
showString "return " . showExpr 11 e
792830
showTerm p (Throw e) = showParen (p > 10) $
793831
showString "throwSTM " . showExpr 11 e
832+
showTerm p (Catch t1 t2) = showParen (p > 9) $
833+
showTerm 10 t1 . showString " `catch` "
834+
. showTerm 10 t2
794835
showTerm _ Retry = showString "retry"
795836
showTerm p (ReadTVar n) = showParen (p > 10) $
796837
showString "readTVar " . showName n

0 commit comments

Comments
 (0)