Skip to content

Commit 1cf44fb

Browse files
authored
Execute functions in ST where possible (#543)
Idea is to save work for GHC. Instead of having >>= as type class method for some polymorphic `m` we work in ST which GHC knows how to compile and optimize very well. No changes to exposed API Reduces allocation during compilation of Data.Vector.Generic.Mutable by ~20%. No measurable changes in other modules
1 parent 7b39212 commit 1cf44fb

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

vector/src/Data/Vector/Generic/Mutable.hs

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ module Data.Vector.Generic.Mutable (
7979
) where
8080

8181
import Control.Monad ((<=<))
82+
import Control.Monad.ST
8283
import Data.Vector.Generic.Mutable.Base
8384
import qualified Data.Vector.Generic.Base as V
8485

@@ -91,7 +92,7 @@ import Data.Vector.Fusion.Bundle.Size
9192
import Data.Vector.Fusion.Util ( delay_inline )
9293
import Data.Vector.Internal.Check
9394

94-
import Control.Monad.Primitive ( PrimMonad(..), RealWorld, stToPrim )
95+
import Control.Monad.Primitive ( PrimMonad(..), stToPrim )
9596

9697
import Prelude
9798
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..), Ordering(..)
@@ -106,8 +107,7 @@ import Data.Bits ( Bits(shiftR) )
106107
-- Internal functions
107108
-- ------------------
108109

109-
unsafeAppend1 :: (PrimMonad m, MVector v a)
110-
=> v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
110+
unsafeAppend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a)
111111
{-# INLINE_INNER unsafeAppend1 #-}
112112
-- NOTE: The case distinction has to be on the outside because
113113
-- GHC creates a join point for the unsafeWrite even when everything
@@ -122,8 +122,7 @@ unsafeAppend1 v i x
122122
checkIndex Internal i (length v') $ unsafeWrite v' i x
123123
return v'
124124

125-
unsafePrepend1 :: (PrimMonad m, MVector v a)
126-
=> v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
125+
unsafePrepend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a, Int)
127126
{-# INLINE_INNER unsafePrepend1 #-}
128127
unsafePrepend1 v i x
129128
| i /= 0 = do
@@ -207,7 +206,7 @@ unstream :: (PrimMonad m, MVector v a)
207206
=> Bundle u a -> m (v (PrimState m) a)
208207
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
209208
{-# INLINE_FUSED unstream #-}
210-
unstream s = munstream (Bundle.lift s)
209+
unstream s = stToPrim $ munstream (Bundle.lift s)
211210

212211
-- | Create a new mutable vector and fill it with elements from the monadic
213212
-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -243,9 +242,8 @@ munstreamUnknown s
243242
$ unsafeSlice 0 n v'
244243
where
245244
{-# INLINE_INNER put #-}
246-
put (v,i) x = do
247-
v' <- unsafeAppend1 v i x
248-
return (v',i+1)
245+
put (v,i) x = stToPrim $ do v' <- unsafeAppend1 v i x
246+
return (v',i+1)
249247

250248

251249
-- | Create a new mutable vector and fill it with elements from the 'Bundle'.
@@ -255,7 +253,7 @@ vunstream :: (PrimMonad m, V.Vector v a)
255253
=> Bundle v a -> m (V.Mutable v (PrimState m) a)
256254
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
257255
{-# INLINE_FUSED vunstream #-}
258-
vunstream s = vmunstream (Bundle.lift s)
256+
vunstream s = stToPrim $ vmunstream (Bundle.lift s)
259257

260258
-- | Create a new mutable vector and fill it with elements from the monadic
261259
-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -311,7 +309,7 @@ unstreamR :: (PrimMonad m, MVector v a)
311309
=> Bundle u a -> m (v (PrimState m) a)
312310
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstream)
313311
{-# INLINE_FUSED unstreamR #-}
314-
unstreamR s = munstreamR (Bundle.lift s)
312+
unstreamR s = stToPrim $ munstreamR (Bundle.lift s)
315313

316314
-- | Create a new mutable vector and fill it with elements from the monadic
317315
-- stream from right to left. The vector will grow exponentially if the maximum
@@ -350,7 +348,7 @@ munstreamRUnknown s
350348
$ unsafeSlice i (n-i) v'
351349
where
352350
{-# INLINE_INNER put #-}
353-
put (v,i) x = unsafePrepend1 v i x
351+
put (v,i) x = stToPrim $ unsafePrepend1 v i x
354352

355353
-- Length
356354
-- ------
@@ -563,10 +561,9 @@ enlarge_delta :: MVector v a => v s a -> Int
563561
enlarge_delta v = max (length v) 1
564562

565563
-- | Grow a vector logarithmically.
566-
enlarge :: (PrimMonad m, MVector v a)
567-
=> v (PrimState m) a -> m (v (PrimState m) a)
564+
enlarge :: (MVector v a) => v s a -> ST s (v s a)
568565
{-# INLINE enlarge #-}
569-
enlarge v = stToPrim $ do
566+
enlarge v = do
570567
vnew <- unsafeGrow v by
571568
basicInitialize $ basicUnsafeSlice (length v) by vnew
572569
return vnew
@@ -996,10 +993,10 @@ unsafeMove dst src = check Unsafe "length mismatch" (length dst == length src)
996993
accum :: forall m v a b u. (HasCallStack, PrimMonad m, MVector v a)
997994
=> (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
998995
{-# INLINE accum #-}
999-
accum f !v s = Bundle.mapM_ upd s
996+
accum f !v s = stToPrim $ Bundle.mapM_ upd s
1000997
where
1001998
{-# INLINE_INNER upd #-}
1002-
upd :: HasCallStack => (Int, b) -> m ()
999+
upd :: HasCallStack => (Int, b) -> ST (PrimState m) ()
10031000
upd (i,b) = do
10041001
a <- checkIndex Bounds i n $ unsafeRead v i
10051002
unsafeWrite v i (f a b)
@@ -1008,18 +1005,18 @@ accum f !v s = Bundle.mapM_ upd s
10081005
update :: forall m v a u. (HasCallStack, PrimMonad m, MVector v a)
10091006
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
10101007
{-# INLINE update #-}
1011-
update !v s = Bundle.mapM_ upd s
1008+
update !v s = stToPrim $ Bundle.mapM_ upd s
10121009
where
10131010
{-# INLINE_INNER upd #-}
1014-
upd :: HasCallStack => (Int, a) -> m ()
1011+
upd :: HasCallStack => (Int, a) -> ST (PrimState m) ()
10151012
upd (i,b) = checkIndex Bounds i n $ unsafeWrite v i b
10161013

10171014
!n = length v
10181015

10191016
unsafeAccum :: (PrimMonad m, MVector v a)
10201017
=> (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
10211018
{-# INLINE unsafeAccum #-}
1022-
unsafeAccum f !v s = Bundle.mapM_ upd s
1019+
unsafeAccum f !v s = stToPrim $ Bundle.mapM_ upd s
10231020
where
10241021
{-# INLINE_INNER upd #-}
10251022
upd (i,b) = do
@@ -1028,17 +1025,17 @@ unsafeAccum f !v s = Bundle.mapM_ upd s
10281025
!n = length v
10291026

10301027
unsafeUpdate :: (PrimMonad m, MVector v a)
1031-
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
1028+
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
10321029
{-# INLINE unsafeUpdate #-}
1033-
unsafeUpdate !v s = Bundle.mapM_ upd s
1030+
unsafeUpdate !v s = stToPrim $ Bundle.mapM_ upd s
10341031
where
10351032
{-# INLINE_INNER upd #-}
10361033
upd (i,b) = checkIndex Unsafe i n $ unsafeWrite v i b
10371034
!n = length v
10381035

10391036
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
10401037
{-# INLINE reverse #-}
1041-
reverse !v = reverse_loop 0 (length v - 1)
1038+
reverse !v = stToPrim $ reverse_loop 0 (length v - 1)
10421039
where
10431040
reverse_loop i j | i < j = do
10441041
unsafeSwap v i j
@@ -1048,11 +1045,11 @@ reverse !v = reverse_loop 0 (length v - 1)
10481045
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
10491046
=> (a -> Bool) -> v (PrimState m) a -> m Int
10501047
{-# INLINE unstablePartition #-}
1051-
unstablePartition f !v = from_left 0 (length v)
1048+
unstablePartition f !v = stToPrim $ from_left 0 (length v)
10521049
where
10531050
-- NOTE: GHC 6.10.4 panics without the signatures on from_left and
10541051
-- from_right
1055-
from_left :: Int -> Int -> m Int
1052+
from_left :: Int -> Int -> ST (PrimState m) Int
10561053
from_left i j
10571054
| i == j = return i
10581055
| otherwise = do
@@ -1061,7 +1058,7 @@ unstablePartition f !v = from_left 0 (length v)
10611058
then from_left (i+1) j
10621059
else from_right i (j-1)
10631060

1064-
from_right :: Int -> Int -> m Int
1061+
from_right :: Int -> Int -> ST (PrimState m) Int
10651062
from_right i j
10661063
| i == j = return i
10671064
| otherwise = do
@@ -1078,7 +1075,8 @@ unstablePartitionBundle :: (PrimMonad m, MVector v a)
10781075
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
10791076
{-# INLINE unstablePartitionBundle #-}
10801077
unstablePartitionBundle f s
1081-
= case upperBound (Bundle.size s) of
1078+
= stToPrim
1079+
$ case upperBound (Bundle.size s) of
10821080
Just n -> unstablePartitionMax f s n
10831081
Nothing -> partitionUnknown f s
10841082

@@ -1087,7 +1085,7 @@ unstablePartitionMax :: (PrimMonad m, MVector v a)
10871085
-> m (v (PrimState m) a, v (PrimState m) a)
10881086
{-# INLINE unstablePartitionMax #-}
10891087
unstablePartitionMax f s n
1090-
= do
1088+
= stToPrim $ do
10911089
v <- checkLength Internal n $ unsafeNew n
10921090
let {-# INLINE_INNER put #-}
10931091
put (i, j) x
@@ -1105,15 +1103,15 @@ partitionBundle :: (PrimMonad m, MVector v a)
11051103
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
11061104
{-# INLINE partitionBundle #-}
11071105
partitionBundle f s
1108-
= case upperBound (Bundle.size s) of
1106+
= stToPrim
1107+
$ case upperBound (Bundle.size s) of
11091108
Just n -> partitionMax f s n
11101109
Nothing -> partitionUnknown f s
11111110

11121111
partitionMax :: (PrimMonad m, MVector v a)
11131112
=> (a -> Bool) -> Bundle u a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
11141113
{-# INLINE partitionMax #-}
1115-
partitionMax f s n
1116-
= do
1114+
partitionMax f s n = stToPrim $ do
11171115
v <- checkLength Internal n $ unsafeNew n
11181116

11191117
let {-# INLINE_INNER put #-}
@@ -1138,8 +1136,7 @@ partitionMax f s n
11381136
partitionUnknown :: (PrimMonad m, MVector v a)
11391137
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
11401138
{-# INLINE partitionUnknown #-}
1141-
partitionUnknown f s
1142-
= do
1139+
partitionUnknown f s = stToPrim $ do
11431140
v1 <- unsafeNew 0
11441141
v2 <- unsafeNew 0
11451142
(v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s
@@ -1165,15 +1162,16 @@ partitionWithBundle :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
11651162
=> (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
11661163
{-# INLINE partitionWithBundle #-}
11671164
partitionWithBundle f s
1168-
= case upperBound (Bundle.size s) of
1165+
= stToPrim
1166+
$ case upperBound (Bundle.size s) of
11691167
Just n -> partitionWithMax f s n
11701168
Nothing -> partitionWithUnknown f s
11711169

11721170
partitionWithMax :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
11731171
=> (a -> Either b c) -> Bundle u a -> Int -> m (v (PrimState m) b, v (PrimState m) c)
11741172
{-# INLINE partitionWithMax #-}
11751173
partitionWithMax f s n
1176-
= do
1174+
= stToPrim $ do
11771175
v1 <- unsafeNew n
11781176
v2 <- unsafeNew n
11791177
let {-# INLINE_INNER put #-}
@@ -1194,7 +1192,7 @@ partitionWithUnknown :: forall m v u a b c.
11941192
=> (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
11951193
{-# INLINE partitionWithUnknown #-}
11961194
partitionWithUnknown f s
1197-
= do
1195+
= stToPrim $ do
11981196
v1 <- unsafeNew 0
11991197
v2 <- unsafeNew 0
12001198
(v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s
@@ -1204,14 +1202,14 @@ partitionWithUnknown f s
12041202
where
12051203
put :: (v (PrimState m) b, Int, v (PrimState m) c, Int)
12061204
-> a
1207-
-> m (v (PrimState m) b, Int, v (PrimState m) c, Int)
1205+
-> ST (PrimState m) (v (PrimState m) b, Int, v (PrimState m) c, Int)
12081206
{-# INLINE_INNER put #-}
12091207
put (v1, i1, v2, i2) x = case f x of
12101208
Left b -> do
1211-
v1' <- unsafeAppend1 v1 i1 b
1209+
v1' <- stToPrim $ unsafeAppend1 v1 i1 b
12121210
return (v1', i1+1, v2, i2)
12131211
Right c -> do
1214-
v2' <- unsafeAppend1 v2 i2 c
1212+
v2' <- stToPrim $ unsafeAppend1 v2 i2 c
12151213
return (v1, i1, v2', i2+1)
12161214

12171215
-- Modifying vectors

0 commit comments

Comments
 (0)