{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS_HADDOCK hide #-}


module Algorithm.SCC where

import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import qualified Data.Set as S
import Control.Monad.State.Strict
import Control.Monad
import Data.List
import Debug.Trace


{--
 - Generating strongly connected components.
 - https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm

 - The only modification is that this algorithm considers SCCs with respect to some *frontier*.
 - A frontier is a set of nodes that the search does not step into, i.e., we consider the
 - subgraph bounded by (up to, but not including) the frontier.
 -
 - The graph is accessed through the IntGraph class below:
 - intgraph_post :: g -> Int -> IS.IntSet returns, given the graph and the current node,
 - the set of next nodes, and intgraph_V :: g -> IS.IntSet returns all vertices.
 -
 -
 - I could get none of the Data.Graph functions to work properly, hence this reimplementation.
--}


-- | Graphs whose vertices are 'Int's; this is the only interface the SCC
-- routines in this module use to inspect a graph.
class IntGraph g where
  -- | All vertices of the graph.
  intgraph_V    :: g -> IS.IntSet
  -- | @intgraph_post g v@: the set of next nodes (successors) of @v@.
  intgraph_post :: g -> Int -> IS.IntSet


-- | Bookkeeping threaded through Tarjan's algorithm in the 'State' monad.
--
-- Fields are strict so that their values (in particular the running
-- 'scc_index' counter) are evaluated whenever the record is, instead of
-- piling up as chains of thunks across many 'modify' calls.
data SCC_state = SCC_State
  { scc_indices  :: !(IM.IntMap Int)  -- ^ discovery index of each visited vertex
  , scc_lowlinks :: !(IM.IntMap Int)  -- ^ lowlink of each visited vertex
  , scc_index    :: !Int              -- ^ next discovery index to hand out
  , scc_stack    :: ![Int]            -- ^ Tarjan stack, most recently pushed first
  , scc_return   :: ![IS.IntSet]      -- ^ SCCs completed so far, most recent first
  }

-- | @set_index_of v x s@ records @x s@ (computed from the current state) as
-- the discovery index of vertex @v@.
--
-- The value is forced before insertion. Once the returned state is
-- evaluated, the map therefore holds a plain 'Int', not a thunk that keeps
-- the previous state @s@ alive.
set_index_of :: Int -> (SCC_state -> Int) -> SCC_state -> SCC_state
set_index_of v x s =
  let i = x s
  in i `seq` s { scc_indices = IM.insert v i (scc_indices s) }

-- | @set_lowlink_of v x s@ records @x s@ (computed from the current state)
-- as the lowlink of vertex @v@.
--
-- The value is forced before insertion, so the stored lowlink does not
-- retain the previous state @s@ through an unevaluated thunk.
set_lowlink_of :: Int -> (SCC_state -> Int) -> SCC_state -> SCC_state
set_lowlink_of v x s =
  let l = x s
  in l `seq` s { scc_lowlinks = IM.insert v l (scc_lowlinks s) }

-- | Replace the running discovery counter with @x s@, computed from the
-- current state. The new value is forced so that successive increments do
-- not build a chain of @(+1)@ thunks.
set_index :: (SCC_state -> Int) -> SCC_state -> SCC_state
set_index x s =
  let i = x s
  in i `seq` s { scc_index = i }

-- | Push vertex @v@ onto the Tarjan stack (the head of 'scc_stack').
push :: Int -> SCC_state -> SCC_state
push v s = s { scc_stack = v : scc_stack s }

-- | Close the SCC rooted at @v@. Pop the stack down to and including @v@,
-- and prepend the popped vertices to 'scc_return' as one 'IS.IntSet'.
--
-- Precondition: @v@ is on the stack. A violation is reported with an
-- explicit error message rather than as @tail []@.
pop_and_return :: Int -> SCC_state -> SCC_state
pop_and_return v s =
  case break (== v) (scc_stack s) of
    (above, _ : below) ->
      s { scc_stack  = below
        , scc_return = IS.fromList (v : above) : scc_return s
        }
    (_, []) ->
      error "Algorithm.SCC.pop_and_return: vertex not on the SCC stack"


-- | Tarjan's @strongconnect@ for vertex @v@, restricted by @frontier@:
-- successors that lie in @frontier@ are skipped and never explored.
-- Every SCC completed during the call is prepended to 'scc_return'.
--
-- NOTE(review): @v@ itself is not checked against @frontier@, so starting
-- at a frontier vertex still explores its successors. Confirm callers
-- rely on this.
-- NOTE(review): on-stack membership uses 'elem' on the list stack, which
-- is linear in its size. An on-stack 'IS.IntSet' field would avoid that,
-- but would change the exported 'SCC_State' constructor.
strongconnect :: IntGraph g => g -> Int -> IS.IntSet -> State SCC_state ()
strongconnect g v frontier = do
  -- Give v the next discovery index; its lowlink starts equal to it.
  modify' $ set_index_of   v scc_index
  modify' $ set_lowlink_of v scc_index
  modify' $ set_index      ((+ 1) . scc_index)
  modify' $ push v

  forM_ (IS.toList (intgraph_post g v)) $ \w ->
    unless (w `IS.member` frontier) $ do
      seen <- gets (IM.lookup w . scc_indices)
      case seen of
        Nothing -> do
          -- w is unvisited: explore it, then inherit its lowlink.
          strongconnect g w frontier
          modify' $ set_lowlink_of v $ \s ->
            min (scc_lowlinks s IM.! v) (scc_lowlinks s IM.! w)
        Just w_index -> do
          -- w was visited earlier. It only bounds v's lowlink if it is
          -- still on the stack, i.e. not yet assigned to a completed SCC.
          onStack <- gets (elem w . scc_stack)
          when onStack $
            modify' $ set_lowlink_of v $ \s ->
              min (scc_lowlinks s IM.! v) w_index

  -- v roots an SCC exactly when its lowlink never dropped below its index.
  s <- get
  when (scc_lowlinks s IM.! v == scc_indices s IM.! v) $
    modify' $ pop_and_return v


-- | Run 'strongconnect' from every vertex of @g@ (per 'intgraph_V') that
-- has not been assigned a discovery index yet.
--
-- NOTE(review): vertices in @frontier@ are not excluded as start points
-- here; only successors are filtered. Frontier vertices listed by
-- 'intgraph_V' therefore appear to end up in the result as their own
-- SCCs. Confirm whether they should be skipped instead.
compute_all_sccs :: IntGraph g => g -> IS.IntSet -> State SCC_state ()
compute_all_sccs g frontier =
  forM_ (IS.toList (intgraph_V g)) $ \v -> do
    visited <- gets (IM.member v . scc_indices)
    unless visited $
      strongconnect g v frontier

-- | Starting state: nothing visited, discovery indices begin at 0, empty
-- stack, and no SCCs found yet.
init_scc_state :: SCC_state
init_scc_state = SCC_State
  { scc_indices  = IM.empty
  , scc_lowlinks = IM.empty
  , scc_index    = 0
  , scc_stack    = []
  , scc_return   = []
  }

-- | Start SCC generation at a given root vertex.
-- Reaches only those vertices reachable from the root without stepping
-- into @frontier@. SCCs are listed most recently completed first.
scc_of :: IntGraph g => g -> Int -> IS.IntSet -> [IS.IntSet]
scc_of g v frontier =
  scc_return (execState (strongconnect g v frontier) init_scc_state)

-- | SCC generation over all vertices of @g@ (per 'intgraph_V'), never
-- stepping into @frontier@ from another vertex. SCCs are listed most
-- recently completed first.
all_sccs :: IntGraph g => g -> IS.IntSet -> [IS.IntSet]
all_sccs g frontier =
  scc_return (execState (compute_all_sccs g frontier) init_scc_state)