@@ -5,18 +5,17 @@ module Streamly.Coreutils.Tsort
55 ) where
66
77import qualified Streamly.Prelude as S
8- -- import qualified Streamly.Internal.Data.Fold as FL
98
10- -- import Data.Char (isSpace, toLower)
9+ import Streamly
1110import System.IO.Unsafe (unsafePerformIO )
11+ import Streamly.Internal.Data.Stream.StreamK (adapt )
1212
13- import Streamly
1413
1514vertices
16- :: Eq a
17- => SerialT IO (a , a )
15+ :: ( IsStream t , Eq a )
16+ => t IO (a , a )
1817 -- ^ Edges
19- -> SerialT IO (Int , a )
18+ -> t IO (Int , a )
2019 -- ^ Map each vertex to a unique integer
2120vertices strm =
2221 S. indexed
@@ -25,86 +24,95 @@ vertices strm =
2524
2625
2726buildAdjList
28- :: Eq a
29- => SerialT IO (a , a )
27+ :: ( IsStream t , Eq a )
28+ => t IO (a , a )
3029 -- ^ stream of edges
31- -> SerialT IO (Int , a )
32- -- ^ Map from @a@ to @Int@
33- -> SerialT IO (SerialT IO Int )
30+ -> t IO (Int , a )
31+ -- ^ map from @a@ to @Int@
32+ -> t IO (t IO Int )
3433 -- ^ initial adj list
35- -> SerialT IO (SerialT IO Int )
34+ -> t IO (t IO Int )
3635 -- ^ adj list now
3736buildAdjList edges vtx adj = do
38- let maybeStrm = unsafePerformIO $ S. last $ S. scanl' (insertPair vtx) adj edges
37+ let maybeStrm =
38+ ( unsafePerformIO
39+ $ S. last
40+ $ adapt
41+ $ S. scanl' (insertPair vtx) adj edges
42+ )
3943 case maybeStrm of
40- Just strm -> strm
41- Nothing -> S. nil
44+ Just strm -> strm
45+ Nothing -> S. nil
4246
4347 where
4448
4549 insertPair
46- :: Eq a
47- => SerialT IO (Int , a )
48- -> SerialT IO (SerialT IO Int )
50+ :: ( IsStream t , Eq a )
51+ => t IO (Int , a )
52+ -> t IO (t IO Int )
4953 -> (a , a )
50- -> SerialT IO (SerialT IO Int )
54+ -> t IO (t IO Int )
5155 insertPair vtxMap adjl (xa, xb) = do
5256 let indexA = unsafePerformIO $ getInt vtxMap xa
5357 let indexB = unsafePerformIO $ getInt vtxMap xb
54- let maybeStrm = unsafePerformIO $ (S. !!) adjl indexA
58+ let maybeStrm = unsafePerformIO $ (S. !!) (adapt adjl) indexA
5559 case maybeStrm of
56- Just strm -> modifyNeighbours indexA (S. cons indexB $ strm) adjl
60+ Just strm -> modifyNeighbours indexA (S. cons indexB strm) adjl
5761 Nothing -> modifyNeighbours indexA (S. yield indexB) adjl
5862
5963
64+
6065 getInt
61- :: Eq a
62- => SerialT IO (Int , a )
66+ :: ( IsStream t , Eq a )
67+ => t IO (Int , a )
6368 -> a
6469 -> IO Int
6570 getInt vtxMap ele = do
66- maybeIndex <- S. findIndex (\ (_, v) -> v == ele) vtxMap
71+ maybeIndex <- S. findIndex (\ (_, v) -> v == ele) $ adapt vtxMap
6772 case maybeIndex of
6873 Just idx -> return idx
6974 Nothing -> return (- 1 ) -- won't ever equal any other index in a stream
7075
7176 modifyNeighbours
72- :: Int
73- -> SerialT IO Int
74- -> SerialT IO (SerialT IO Int )
75- -> SerialT IO (SerialT IO Int )
77+ :: IsStream t
78+ => Int
79+ -> t IO Int
80+ -> t IO (t IO Int )
81+ -> t IO (t IO Int )
7682 modifyNeighbours idx newNbd adjStrm =
7783 S. map (\ (_, v) -> v)
7884 $ S. map (\ (i, v) -> do
79- if i == idx
80- then (i, newNbd)
81- else (i, v))
85+ if i == idx
86+ then (i, newNbd)
87+ else (i, v))
8288 $ S. indexed adjStrm
8389
8490
8591-- | dfs
8692dfs
87- :: IsStream t
93+ :: ( IsStream t , Monad m )
8894 => Int
8995 -- ^ root node to start dfs
90- -> SerialT IO ( SerialT IO Int )
96+ -> t m ( t m Int )
9197 -- ^ adj list
92- -> SerialT IO Bool
98+ -> t m Bool
9399 -- ^ visited
94- -> t IO Int
100+ -> t m Int
95101 -- ^ parent
96- -> t IO Int
102+ -> t m Int
97103 -- ^ the stack
98- -> IO ( SerialT IO Bool , t IO Int , t IO Int )
104+ -> m ( t m Bool , t m Int , t m Int )
99105 -- ^ (visited, parent, stack)
100106dfs root adj vis par stck = do
101- strm <- (S. !!) adj root
107+ strm <- (S. !!) (adapt adj) root
102108 case strm of
103109 Just nbd -> do
104110 maybeTuple <- S. last
111+ $ adapt
105112 $ S. scanlM'
106- (\ (visi, parent, stack) v -> dfs v adj
107- (markVisited visi v) (setParent parent v root) (S. cons v stack))
113+ (\ (visi, parent, stack) v ->
114+ dfs v adj (markVisited visi v)
115+ (setParent parent v root) (S. cons v stack))
108116 (vis, par, stck)
109117 $ S. filterM (unVisited vis) nbd
110118 case maybeTuple of
@@ -115,12 +123,12 @@ dfs root adj vis par stck = do
115123 where
116124
117125 unVisited
118- :: Monad m
119- => SerialT m Bool
126+ :: ( IsStream t , Monad m )
127+ => t m Bool
120128 -> Int
121129 -> m Bool
122130 unVisited visStream n = do
123- ele <- (S. !!) visStream n
131+ ele <- (S. !!) (adapt visStream) n
124132 case ele of
125133 Just v -> return $ not v
126134 Nothing -> return False
0 commit comments