{-# LANGUAGE NondecreasingIndentation #-}

-- | A simple mutable union-find data structure.
--
-- It is used in a unification algorithm for backpack mix-in linking.
--
-- This implementation is based on the one in \"The Essence of ML Type
-- Inference\". (N.B. the union-find package is also based on it.)
module Distribution.Utils.UnionFind
  ( Point
  , fresh
  , find
  , union
  , equivalent
  ) where

import Control.Monad
import Control.Monad.ST
import Data.STRef

-- | A unifiable variable. Equivalently, a handle on an equivalence class
-- that has a distinguished representative.
--
-- Equality is inherited from the wrapped 'STRef', so two 'Point's are equal
-- only when they are the very same mutable cell. That is NOT the same as
-- belonging to the same class; use 'equivalent' for that question.
newtype Point s a = Point (STRef s (Link s a))
  deriving Eq

-- | Overwrite the 'Link' stored inside a 'Point'.
writePoint :: Point s a -> Link s a -> ST s ()
writePoint (Point ref) newLink = writeSTRef ref newLink

-- | Fetch the 'Link' currently stored inside a 'Point'.
readPoint :: Point s a -> ST s (Link s a)
readPoint (Point ref) = readSTRef ref

-- | The contents of a 'Point'. The representative of a class stores
-- 'Info'. Every other member stores a 'Link' to another 'Point' of the
-- same class, and following those links ends at the representative.
data Link s a
  = -- | Representative data: a reference to the class weight and a
    -- reference to the class descriptor. (NB: we cannot write
    -- @STRef Int#@, so the weight stays boxed.)
    Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
  | -- | Pointer one step closer to the representative.
    Link {-# UNPACK #-} !(Point s a)

-- | Create a new singleton equivalence class whose descriptor is the
-- given value. The new point is its own representative, with weight 1.
fresh :: a -> ST s (Point s a)
fresh desc = do
  sizeRef <- newSTRef 1
  descRef <- newSTRef desc
  linkRef <- newSTRef (Info sizeRef descRef)
  return (Point linkRef)

-- | Find the representative of the class containing a 'Point'.
--
-- Performs path compression on the way: every 'Point' visited along a
-- chain of 'Link's is rewritten to link directly to the representative.
-- The returned 'Point' always stores an 'Info'.
repr :: Point s a -> ST s (Point s a)
repr point = do
  link <- readPoint point
  case link of
    Info _ _ -> return point
    Link parent -> do
      root <- repr parent
      -- If @parent@ is not itself the root, the recursive call has already
      -- rewritten @parent@ to @Link root@. So we can point straight at the
      -- root without re-reading @parent@'s cell.
      when (root /= parent) $
        writePoint point (Link root)
      return root

-- | Look up the descriptor of the class a 'Point' belongs to.
--
-- Two cases skip the general path: the point is the representative
-- (chain length 0), or it links straight to it (chain length 1). Longer
-- chains go through 'repr', which also compresses the path, and then
-- retry on the representative.
find :: Point s a -> ST s a
find point = do
  link <- readPoint point
  case link of
    Info _ descRef -> readSTRef descRef
    Link parent -> do
      parentLink <- readPoint parent
      case parentLink of
        Info _ descRef -> readSTRef descRef
        Link _ -> do
          root <- repr point
          find root

-- | Merge the equivalence classes of two points so they share one
-- representative. The merged class keeps the descriptor of the second
-- argument's class. Nothing happens if both points already share a
-- representative.
--
-- The weight stored with a representative counts the members of its
-- class: 'fresh' starts it at 1 and a merge stores the sum of both
-- weights. The lighter root is linked under the heavier one. On a tie,
-- the first argument's root stays the representative.
union :: Point s a -> Point s a -> ST s ()
union refpoint1 refpoint2 = do
  root1 <- repr refpoint1
  root2 <- repr refpoint2
  unless (root1 == root2) $ do
    link1 <- readPoint root1
    link2 <- readPoint root2
    case (link1, link2) of
      (Info sizeRef1 descRef1, Info sizeRef2 descRef2) -> do
        size1 <- readSTRef sizeRef1
        size2 <- readSTRef sizeRef2
        let total = size1 + size2
        if size1 < size2
          then do
            writePoint root1 (Link root2)
            writeSTRef sizeRef2 total
          else do
            writePoint root2 (Link root1)
            writeSTRef sizeRef1 total
            -- root1 survives as representative, so copy root2's
            -- descriptor into it to honour the "keep the second
            -- descriptor" contract.
            descriptor2 <- readSTRef descRef2
            writeSTRef descRef1 descriptor2
      _ -> error "UnionFind.union: repr invariant broken"

-- | Check whether two points currently belong to the same equivalence
-- class, i.e. whether they have the same representative. As a side
-- effect, 'repr' compresses both points' paths.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent point1 point2 = do
  root1 <- repr point1
  root2 <- repr point2
  return (root1 == root2)