Skip to content

STM instances #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Change Log

# Circa 2022.10.03 (pre release)

- Added `Semigroup` and `Monoid` instances for `STM` and `WrappedSTM` monads
- Added `MArray` instance for `WrappedSTM` monad
- Added `MonadFix` instance for `STM`

# Circa 2022.09.27 (pre release)

- Module structure of `MonadSTM` changed to follow `stm` package structure.
Expand Down
12 changes: 12 additions & 0 deletions io-classes/src/Control/Monad/Class/MonadSTM/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,18 @@ deriving instance MonadSTM m => Monad (WrappedSTM t r m)
deriving instance MonadSTM m => Alternative (WrappedSTM t r m)
deriving instance MonadSTM m => MonadPlus (WrappedSTM t r m)

instance ( Semigroup a, MonadSTM m ) => Semigroup (WrappedSTM t r m a) where
a <> b = (<>) <$> a <*> b
instance ( Monoid a, MonadSTM m ) => Monoid (WrappedSTM t r m a) where
mempty = pure mempty

instance ( MonadSTM m, MArray e a (STM m) ) => MArray e a (WrappedSTM t r m) where
getBounds = WrappedSTM . getBounds
getNumElements = WrappedSTM . getNumElements
unsafeRead arr = WrappedSTM . unsafeRead arr
unsafeWrite arr i = WrappedSTM . unsafeWrite arr i


-- note: this (and the following) instance requires 'UndecidableInstances'
-- extension because it violates 3rd Paterson condition, however `STM m` will
-- resolve to a concrete type of kind (Type -> Type), and thus no larger than
Expand Down
13 changes: 13 additions & 0 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,19 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
trace <- go ctl read written writtenSeq createdSeq nextVid k
return $ SimTrace time tid tlbl (EventLog x) trace

LiftSTStm st k ->
{-# SCC "schedule.LiftSTStm" #-} do
x <- strictToLazyST st
go ctl read written writtenSeq createdSeq nextVid (k x)

FixStm f k ->
{-# SCC "execAtomically.go.FixStm" #-} do
r <- newSTRef (throw NonTermination)
x <- unsafeInterleaveST $ readSTRef r
let k' = unSTM (f x) $ \x' ->
LiftSTStm (lazyToStrictST (writeSTRef r x')) (\() -> k x')
go ctl read written writtenSeq createdSeq nextVid k'

where
localInvariant =
Map.keysSet written
Expand Down
12 changes: 12 additions & 0 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ data SimA s a where

newtype STM s a = STM { unSTM :: forall r. (a -> StmA s r) -> StmA s r }

instance Semigroup a => Semigroup (STM s a) where
a <> b = (<>) <$> a <*> b

instance Monoid a => Monoid (STM s a) where
mempty = pure mempty

runSTM :: STM s a -> StmA s a
runSTM (STM k) = k ReturnStm

Expand All @@ -199,6 +205,9 @@ data StmA s a where
-> (Maybe a -> a -> ST s TraceValue)
-> StmA s b -> StmA s b

LiftSTStm :: StrictST.ST s a -> (a -> StmA s b) -> StmA s b
FixStm :: (x -> STM s x) -> (x -> StmA s r) -> StmA s r

-- Exported type
type STMSim = STM

Expand Down Expand Up @@ -291,6 +300,9 @@ instance Alternative (STM s) where

instance MonadPlus (STM s) where

instance MonadFix (STM s) where
mfix f = STM $ oneShot $ \k -> FixStm f k

instance MonadSay (IOSim s) where
say msg = IOSim $ oneShot $ \k -> Say msg (k ())

Expand Down
13 changes: 13 additions & 0 deletions io-sim/src/Control/Monad/IOSimPOR/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,19 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
-- TODO: step
return $ SimPORTrace time tid (-1) tlbl (EventLog x) trace

LiftSTStm st k ->
{-# SCC "schedule.LiftSTStm" #-} do
x <- strictToLazyST st
go ctl read written writtenSeq createdSeq nextVid (k x)

FixStm f k ->
{-# SCC "execAtomically.go.FixStm" #-} do
r <- newSTRef (throw NonTermination)
x <- unsafeInterleaveST $ readSTRef r
let k' = unSTM (f x) $ \x' ->
LiftSTStm (lazyToStrictST (writeSTRef r x')) (\() -> k x')
go ctl read written writtenSeq createdSeq nextVid k'

where
localInvariant =
Map.keysSet written
Expand Down
56 changes: 42 additions & 14 deletions io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,18 @@ tests =
[ testProperty "Reference vs IO" prop_stm_referenceIO
, testProperty "Reference vs Sim" prop_stm_referenceSim
]
, testGroup "MonadFix instance"
[ testProperty "purity" prop_mfix_purity
, testProperty "purity2" prop_mfix_purity_2
, testProperty "tightening" prop_mfix_left_shrinking
, testProperty "lazy" prop_mfix_lazy
, testProperty "recdata" prop_mfix_recdata
, testGroup "MonadFix instances"
[ testGroup "IOSim"
[ testProperty "purity" prop_mfix_purity_IOSim
, testProperty "purity2" prop_mfix_purity_2
, testProperty "tightening" prop_mfix_left_shrinking_IOSim
, testProperty "lazy" prop_mfix_lazy
, testProperty "recdata" prop_mfix_recdata
]
, testGroup "STM"
[ testProperty "purity" prop_mfix_purity_STM
, testProperty "tightening" prop_mfix_left_shrinking_STM
]
]
-- NOTE: Most of the tests below only work because the io-sim
-- scheduler works the way it does.
Expand Down Expand Up @@ -592,15 +598,18 @@ test_wakeup_order = do

-- | Purity demands that @mfix (return . f) = return (fix f)@.
--
prop_mfix_purity :: Positive Int -> Bool
prop_mfix_purity (Positive n) =
runSimOrThrow
(mfix (return . factorial)) n
== fix factorial n
prop_mfix_purity_m :: forall m. MonadFix m => Positive Int -> m Bool
prop_mfix_purity_m (Positive n) =
(== fix factorial n) . ($ n) <$> mfix (return . factorial)
where
factorial :: (Int -> Int) -> Int -> Int
factorial = \rec_ k -> if k <= 1 then 1 else k * rec_ (k - 1)

prop_mfix_purity_IOSim :: Positive Int -> Bool
prop_mfix_purity_IOSim a = runSimOrThrow $ prop_mfix_purity_m a

prop_mfix_purity_STM:: Positive Int -> Bool
prop_mfix_purity_STM a = runSimOrThrow $ atomically $ prop_mfix_purity_m a

prop_mfix_purity_2 :: [Positive Int] -> Bool
prop_mfix_purity_2 as =
Expand Down Expand Up @@ -634,12 +643,12 @@ prop_mfix_purity_2 as =
(realToFrac `map` as')


prop_mfix_left_shrinking
prop_mfix_left_shrinking_IOSim
:: Int
-> NonNegative Int
-> Positive Int
-> Bool
prop_mfix_left_shrinking n (NonNegative d) (Positive i) =
prop_mfix_left_shrinking_IOSim n (NonNegative d) (Positive i) =
let mn :: IOSim s Int
mn = do say ""
threadDelay (realToFrac d)
Expand All @@ -657,6 +666,25 @@ prop_mfix_left_shrinking n (NonNegative d) (Positive i) =
threadDelay (realToFrac d) $> a : rec_)))


prop_mfix_left_shrinking_STM
:: Int
-> Positive Int
-> Bool
prop_mfix_left_shrinking_STM n (Positive i) =
let mn :: STMSim s Int
mn = do say ""
return n
in
take i
(runSimOrThrow $ atomically $
mfix (\rec_ -> mn >>= \a -> return $ a : rec_))
==
take i
(runSimOrThrow $ atomically $
mn >>= \a ->
(mfix (\rec_ -> return $ a : rec_)))



-- | 'Example 8.2.1' in 'Value Recursion in Monadic Computations'
-- <https://leventerkok.github.io/papers/erkok-thesis.pdf>
Expand Down Expand Up @@ -756,7 +784,7 @@ probeOutput probe x = atomically (modifyTVar probe (x:))


--
-- Syncronous exceptions
-- Synchronous exceptions
--

unit_catch_0, unit_catch_1, unit_catch_2, unit_catch_3, unit_catch_4,
Expand Down