{-|
  Description : Generate guarded phi-nodes.
  Copyright   : (c) Paul Govereau and Jean-Baptiste Tristan 2010
  License     : All Rights Reserved
  Maintainer  : Paul Govereau <govereau@cs.harvard.edu>

  This module computes the guards for phi-nodes that do not appear in loop
  headers.
-}
module MD.Phi
  ( computePhi )
where
import Data.List
import Data.Tree
import MD.Graph
import MD.Syntax.GDSA
import MD.Structural

-------------------------------------------------------------------------------
-- | Compute guarded phi nodes for a whole structure.
--
-- Phi-nodes outside loop headers are rewritten into ordinary rules whose
-- right-hand sides select among the incoming values using guard
-- conditions. Processing starts with no inherited guard information.
-- 'scan' yields exactly one result per input structure, so taking the head
-- of its result for a singleton input is safe.
computePhi :: Structure GBlock -> Structure GBlock
computePhi = head . scan [] . (: [])

-- | Debugging sanity check; it is not exported and not called in this module.
--
-- Every variable bound by a phi-node of a block must also be the target of
-- one of that block's rules. If this holds everywhere, the structure is
-- returned unchanged. Otherwise the check aborts, printing the structure
-- and one line for each offending block. Each line holds the block's label
-- and, for every phi variable, the label of the block defining each free
-- variable of its operands (-1 if no block defines it).
checkPhis :: Structure GBlock -> Structure GBlock
checkPhis s
  | all covered blocks = s
  | otherwise          = error ("OOPS WE MISSED SOME PHI NODES:\n"++show s++ "\n" ++ report)
  where
    blocks = sflatten s

    phiVars  blk = map (\(v,_,_) -> v)  (phis blk)
    ruleVars blk = map (\(RR x _) -> x) (rules blk)

    -- True when each phi-bound variable of the block is also a rule target.
    covered blk = let defined = ruleVars blk
                  in  all (`elem` defined) (phiVars blk)

    -- Consistent blocks are tagged -1 and dropped by the label filter.
    report = unlines [ show entry | entry@(n,_) <- map describe blocks, n >= 0 ]

    describe blk
      | covered blk = (-1, [])
      | otherwise   = ( name blk
                      , [ (v, map definedIn (concatMap (fvs . snd) ops))
                        | (v, _, ops) <- phis blk ] )

    -- Locate the block whose rules define v; -1 when there is none.
    definedIn v = case find (\blk -> v `elem` ruleVars blk) blocks of
                    Just blk -> (v, name blk)
                    Nothing  -> (v, -1)

-- | Guard information passed from a branching structure to the structure
-- that follows it. Each entry pairs a condition term with the block labels
-- it covers. 'phi' looks up the label of each phi operand here to find the
-- guard that selects that operand.
type Info = [(Term,[Label])]

-- | Process a list of sibling structures from left to right.
--
-- The first structure receives the given 'Info'. After that:
--
-- * a two-way 'If' (via 'convertIf') passes its guard info to the next sibling;
-- * an 'IfN' (via 'ifnfo') passes its guard info to the next sibling;
-- * any other structure leaves the next sibling with no info.
scan :: Info -> [Structure GBlock] -> [Structure GBlock]
scan _    []              = []
scan held (current : rest) =
    case current of
      If cs x y ->
        -- Lazy binding: convertIf's result is only forced when demanded.
        let (converted, passOn) = convertIf cs x y
        in  phi held converted : scan passOn rest
      IfN c xs -> phi held current : scan (ifnfo c xs) rest
      _        -> phi held current : scan [] rest

-- | Rewrite the phi-nodes of a structure's entry block into guarded rules.
--
-- @info@ is the guard information inherited from the preceding sibling
-- (see 'scan'). For 'Single', 'If' and 'IfN', the entry block's phis are
-- prepended to its rules as 'RR' definitions.
--
-- * The arms of a two-way 'If' are not visited here; 'convertIf' has
--   already scanned them.
-- * 'IfN' arms and loop bodies are processed recursively with no
--   inherited info.
-- * A 'WhileLoop' header is left untouched: this module does not handle
--   loop-header phis.
phi :: Info -> Structure GBlock -> Structure GBlock
phi info struct = case struct of
  Empty           -> Empty
  Single x        -> Single (convert x)
  Sequence ss     -> Sequence (scan info ss)
  If (c:cs) x y   -> If (convert c : cs) x y        -- arms handled by convertIf
  If [] _ _       -> error "phi: If structure with no condition blocks"
  IfN c xs        -> IfN (convert c) (map computePhi xs)
  WhileLoop c b   -> WhileLoop c (computePhi b)
 where
   -- Each phi becomes an ordinary rule defining the same variable.
   -- The phi list itself is deliberately kept on the block.
   convert b = b { rules = map cvt (phis b) ++ rules b }

   -- The first incoming value is the fall-back. Every later (label, value)
   -- pair wraps the accumulated term in a conditional guarded by the
   -- condition recorded for that label. A single operand therefore becomes
   -- a plain copy.
   cvt (i, ty, (_, v0) : rest) =
       RR i $ foldr (\(l, v) acc -> mkPhi ty (findCond l) v acc) v0 rest
   cvt (i, _, []) =
       error $ show ("phi: phi-node without incoming values", i)

   findCond l = case find ((l `elem`) . snd) info of
                  Nothing    -> error $ show ("no info",l,info)
                  Just (t,_) -> t

-- | Handle a two-way 'If' whose branch decision is spread over the
-- condition blocks @bs@, with true arm @t@ and false arm @f@.
--
-- Each condition block must end in an 'MBr' with exactly one case, where
-- either the default or the case target is the decisive arm (@tgt@).
-- Otherwise this aborts.
--
-- Returns two things:
--
-- * the rebuilt 'If', whose arms have already been run through 'scan'
--   using the per-block conditions;
-- * the 'Info' for the following structure: the combined condition for
--   the true arm's children, its negation for the false arm's children,
--   then the per-block conditions.
convertIf :: [GBlock] -> Structure GBlock -> Structure GBlock
          -> (Structure GBlock, Info)
convertIf bs t f =
    case bs of
      [] -> error "empty bat"
      _  -> (If bs t' f', einfo)
 where
   t' = head $ scan info [t]
   f' = head $ scan info [f]

   -- Decide whether the chain is an OR or an AND (the "arm pit"). It is
   -- treated as an OR when there is a single block, or when more than one
   -- of the blocks' targets is the true arm.
   tlbl   = name t
   flbl   = name f
   labels = concatMap targets bs
   isOr   = length bs == 1 || length (filter (== tlbl) labels) > 1

   -- External conditions, seen by the structure that follows the If.
   einfo = [ (tcond, children t), (fcond, children f) ] ++ info
   tcond | isOr      = foldl1 lor  conds
         | otherwise = foldl1 land conds
   -- The negation is written as (tcond == 0) so that 'mkPhi' can
   -- recognise it and swap its operands instead.
   fcond = eq (I 1) tcond (Atom $ Int 0)

   -- Internal conditions, one per condition block.
   -- In an OR chain a block's condition means "branches to the true arm".
   -- In an AND chain it means "does not branch to the false arm".
   tgt  | isOr      = tlbl
        | otherwise = flbl
   to   | isOr      = eq
        | otherwise = ne
   away | isOr      = ne
        | otherwise = eq
   info  = zip conds lbls
   lbls  = map children bs
   conds = map cnd bs
   cnd b = case control b of
             MBr scrut ty dflt [(k, lbl)]
                 | dflt == tgt -> away ty scrut k
                 | lbl  == tgt -> to   ty scrut k
             _ -> error $ show ("bat cond",tlbl,flbl,control b)


-- | Build the 'Info' for the structure that follows an 'IfN'.
-- @c@ is the switch block and @ss@ are the arms.
--
-- The first entry guards the 'MBr' default target: the scrutinee differs
-- from every case value. It covers only the default label itself. Each
-- case then contributes (scrutinee == value), paired with its label and
-- the 'children' of the arm structure carrying that label.
--
-- Aborts if @c@ does not end in an 'MBr', or if a case label names none of
-- the arms.
-- NOTE(review): 'foldl1' fails when the 'MBr' has no cases; confirm such
-- switches never reach here.
ifnfo :: GBlock -> [Structure GBlock] -> Info
ifnfo c ss =
    case control c of
      MBr scrut ty dflt cases ->
          let defaultGuard = (foldl1 land (map (ne ty scrut . fst) cases), [dflt])
              caseGuard (k, lbl) = (eq ty scrut k, armLabels lbl)
          in  defaultGuard : map caseGuard cases
      _ -> error "ifnfo: bad condition"
  where
    -- The case label followed by the children of the arm with that name.
    armLabels lbl = maybe (error "ifnfo: no struct")
                          (\arm -> lbl : children arm)
                          (find ((== lbl) . name) ss)


-- | Build a guarded selection between @t1@ and @t2@ at type @ty@.
--
-- @Phi ty c t1 t2@ is read as "if c then t1 else t2"; the operand swaps
-- below depend on that reading. When @c@ merely compares a 1-bit term
-- @c'@ with 0 or 1, the comparison is stripped and the operands are
-- swapped as needed, recursively.
mkPhi :: Type -> Term -> Term -> Term -> Term
mkPhi ty c t1 t2 =
    case testOf c of
      Just (inner, True)  -> mkPhi ty inner t1 t2
      Just (inner, False) -> mkPhi ty inner t2 t1
      Nothing             -> Phi ty c t1 t2
  where
    -- Just (c', positive) when the condition is c' tested against 0 or 1.
    -- @positive@ is True when the condition holds exactly when c' does.
    testOf (BinOp (Cmp MD.Syntax.GDSA.EQ) _ (I 1) c' (Atom (Int 0))) = Just (c', False)
    testOf (BinOp (Cmp MD.Syntax.GDSA.EQ) _ (I 1) c' (Atom (Int 1))) = Just (c', True)
    testOf (BinOp (Cmp MD.Syntax.GDSA.NE) _ (I 1) c' (Atom (Int 0))) = Just (c', True)
    testOf (BinOp (Cmp MD.Syntax.GDSA.NE) _ (I 1) c' (Atom (Int 1))) = Just (c', False)
    testOf _                                                       = Nothing

-- | Equality / inequality comparison of two terms at operand type @ty@.
-- The second 'BinOp' field is always left empty (@[]@); its meaning is
-- defined in MD.Syntax.GDSA.
eq, ne :: Type -> Term -> Term -> Term
eq ty a b = BinOp (Cmp MD.Syntax.GDSA.EQ) [] ty a b
ne ty a b = BinOp (Cmp MD.Syntax.GDSA.NE) [] ty a b

-- | Logical OR / AND of two condition terms, built at the fixed type 'ST'.
-- The second 'BinOp' field is left empty (@[]@), as in 'eq' and 'ne'.
lor, land :: Term -> Term -> Term
lor  a b = BinOp (Bop MD.Syntax.GDSA.Or)  [] ST a b
land a b = BinOp (Bop MD.Syntax.GDSA.And) [] ST a b

-- | All variables bound anywhere in a structure: the 'bvars' of every
-- block, in 'sflatten' order.
svars :: Structure GBlock -> [Ident]
svars s = [ v | blk <- sflatten s, v <- bvars blk ]

-- | Variables bound by a single block: rule targets first, then phi targets.
bvars :: GBlock -> [Ident]
bvars b = ruleTargets ++ phiTargets
  where
    ruleTargets = map (\(RR x _) -> x) (rules b)
    phiTargets  = [ x | (x, _, _) <- phis b ]