{-|
  Description : Convert basic blocks to GDSA form.
  Copyright   : (c) Paul Govereau and Jean-Baptiste Tristan 2010
  License     : All Rights Reserved
  Maintainer  : Paul Govereau <govereau@cs.harvard.edu>

  This module builds a CFG of GDSA rules from an SSA program, by converting
  each basic block. This module also generates the abstract state variables
  for side-effects.
-}
module MD.Convert ( convertToGDSA ) where
import Data.List
import Data.STRef
import Control.Monad.ST

import qualified MD.Syntax.LLVM as LL
import MD.Syntax.LLVM hiding (Label(..), MemoryInst(..), Call, Expr(..), Value(..), Phi)
import MD.Syntax.GDSA

-- | Convert a list of LLVM 'Block's into GDSA 'GBlock's.
--
-- All blocks are converted by 'gblock' inside one 'ST' computation, so the
-- two name counters (one for memory-state variables, one for fresh
-- temporaries) are shared by every block and names never collide.  Symbolic
-- branch labels are resolved to block indices via the blocks' names; an
-- unknown name is a fatal error.  The converted blocks are then handed to
-- 'addPrePhi', which fills in predecessors and memory bindings.
convertToGDSA :: [Block] -> [GBlock]
convertToGDSA blocks =
    addPrePhi $ runST $ do
      memCounter   <- newSTRef 0
      freshCounter <- newSTRef 0
      mapM (gblock resolve memCounter freshCounter) blocks
 where
   -- Symbolic block name -> numeric block index.
   nameTable = [ (blockName blk, blockIndex blk) | blk <- blocks ]

   -- Numeric labels are already indices; symbolic ones go through nameTable.
   resolve lbl = case lbl of
     LL.LI idx  -> idx
     LL.LS name -> maybe (error "invalid label name") id (lookup name nameTable)

-- | Attach the inter-block meta-data that can only be computed once every
-- block has been converted:
--
--   * each block's predecessor list (@pre'@): every block, in input order,
--     whose successor list (@suc'@) contains this block's label;
--
--   * the binding of the block's incoming memory variable (@$mem_in@):
--     with exactly one predecessor, a rule binding it to that predecessor's
--     outgoing memory variable is prepended to the block's rules; with
--     several predecessors, a memory phi (type component @ST@) over the
--     predecessors' outgoing memory variables is prepended to the block's
--     phis; with no predecessors it is left unbound.
addPrePhi :: [GBlock] -> [GBlock]
addPrePhi bs = map (proc . addpre) bs
 where
   -- Predecessors are computed directly from the successor lists.  Because
   -- membership is tested with 'elem', a block that reaches this one along
   -- several edges is listed only once.
   addpre b = b { pre' = preds (label b) }
   preds n  = [ label p | p <- bs, n `elem` suc' p ]

   -- Bind the block's incoming memory variable from its predecessors.
   proc b = case pre' b of
              []  -> b
              [n] -> b { rules = sgl b n : rules b }
              l   -> b { phis  = phi b l : phis  b }
   mid n   = Ident ('$':show n)
   sgl b n = RR (mid $ mem_in b) (snd $ node n)
   phi b l = (mid $ mem_in b, ST, map node l)
   -- Pair a predecessor label with a reference to its outgoing memory variable.
   node n  = (n, Atom $ Var $ mid $ mem_out $ findB n)
   findB n = case find (\x -> n == label x) bs of
               Just b -> b
               Nothing -> error "no block with label"

-- | Convert one LLVM block into a GDSA block.  This runs in the 'ST' monad so
-- that fresh variable names can be generated.
--
-- Arguments:
--
--   * @trans@ resolves an LLVM label to a GDSA block label.
--
--   * @ref@ is the counter for abstract memory-state variables, named @$n@.
--
--   * @fref@ is the counter for fresh temporaries, named @#n@.
--
--   * @b@ is the block to convert.
--
-- The leading phi instructions of the block become GDSA phis.  Every other
-- instruction becomes one or more rules.  A memory variable is threaded
-- through the rules:
--
--   * stores, allocas and calls define a new memory variable from the
--     current one;
--
--   * loads only read the current one.
--
-- The resulting block records its first (@mem_in@) and last (@mem_out@)
-- memory variable.  Its predecessor list is left empty for 'addPrePhi' to
-- fill in.  Encountering a phi after a non-phi instruction, or an unsupported
-- instruction, is a fatal error.
gblock :: (LL.Label -> Label) -> STRef s Int -> STRef s Int -> Block -> ST s GBlock
gblock trans ref fref b =
    do { x <- readSTRef ref; writeSTRef ref (x+1)
       -- Every block starts from its own fresh memory variable; addPrePhi
       -- later binds it to the predecessors' outgoing memory.
       ; let m_in = x+1
       ; let (px,ix) = span isPhi (blockMiddle b)
       ; rx <- mapM rr ix
       -- Whatever memory variable is current after the last instruction is
       -- the block's outgoing memory (equal to m_in if nothing touched memory).
       ; m_out <- readSTRef ref
       ; return $
         GBlock { label   = blockIndex b
                , phis    = map phi px
                , rules   = concat rx
                , pre'    = [] -- filled in later
                , suc'    = map trans $ blockTargets b
                , mem_in  = m_in
                , mem_out = m_out
                , control = cont (blockEnd b)
                }
       }
 where
   isPhi (Instruction _ (LL.Phi _ _)) = True
   isPhi _ = False
   -- Only applied to the leading instructions selected by 'isPhi'; the lazy
   -- pattern defers a failure for an unnamed phi until the name is demanded.
   phi ~(Instruction (Just x) (LL.Phi t l)) = (x,t,map tr l)
   tr (x,l) = (trans l, val x)

   -- Memory-state variable names: "$" followed by the counter value.
   it x = Ident ('$': show x)
   -- Current memory variable, without advancing the counter.
   curm = fmap it (readSTRef ref)
   -- Advance the memory counter; returns (current, next) variable names.
   newm = do { old <- readSTRef ref ; writeSTRef ref (old+1)
             ; return (it old, it (old+1))
             }
   -- Fresh temporary names: "#" followed by the counter value.
   fresh = do { n <- readSTRef fref; writeSTRef fref (n+1)
              ; return (Ident $ '#':show n)
              }
   -- Unnamed instructions are treated as pure memory effects; named ones
   -- define a value (and possibly a new memory state).
   rr (Instruction Nothing  i) = fmap (\x -> mem x i) newm
   rr (Instruction (Just x) i) = rhs x i

   var i = Atom (Var i)

   -- Unnamed instructions: stores and result-less calls, each producing the
   -- next memory variable from the previous one.
   mem :: (Ident,Ident) -> RHS -> [RR]
   mem (old,new) r = case r of
     LL.MemOp (LL.Store _ (TV t v) p) -> [RR new (Store t (val v) (val p) (var old))]
     LL.Call s ax                     -> [RR new (Proj Mem $ Call (val s) (map tval ax) (var old))]
     _ -> error ("Bad mem instruction:"++ show r)

   -- Named instructions.  Alloca and call produce a pair-valued temporary
   -- from which both the new memory state (Proj Mem) and the named value
   -- (Proj Val) are projected; load reads the current memory state only.
   --rhs :: Ident -> RHS -> ST s [RR]
   rhs x r = case r of
     LL.MemOp (LL.Alloca t sz _al) -> do { fv <- fresh
                                         ; (old,new) <- newm
                                         ; return [ RR fv  (Alloc t sz (var old))
                                                  , RR new (Proj Mem (var fv))
                                                  , RR x   (Proj Val (var fv)) ]
                                         }
     LL.MemOp (LL.Load _ (TV t v)) -> do { old <- curm
                                         ; return [RR x (Load t (val v) (var old))]
                                         }
     LL.Expr e     -> return [ RR x (expr e) ]
     LL.Phi {}     -> error "phi in rhs node"
     LL.Call s ax  ->  do { fv <- fresh
                          ; (old,new) <- newm
                          ; let ax' = map tval ax
                          ; return [ RR fv  (Call (val s) ax' (var old))
                                   , RR new (Proj Mem (var fv))
                                   , RR x   (Proj Val (var fv)) ]
                          }
     -- A programmer error like the cases above: report it with 'error'
     -- rather than relying on a MonadFail instance for ST.
     _ -> error ("Bad rhs instruction:"++ show r)

   tval (TV _ v)        = val v
   val (LL.Atom a)      = Atom a
   val (LL.ConstExpr e) = expr e

   expr :: LL.Expr -> Term
   expr (LL.GetElemPtr (TV ty p) ndxs) = GetElemPtr ty (val p) (map tval ndxs)
   expr (LL.BinOp o s ty v1 v2)        = BinOp o s ty (val v1) (val v2)
   expr (LL.Conv o (TV ty v) ty')      = Conv o ty (val v) ty'
   expr (LL.Select c x y)              = Select (tval c) (tval x) (tval y)

   -- Block terminators.  A conditional branch becomes a multi-way branch on
   -- the condition: value 1 goes to the first label, anything else to the
   -- second (the default).
   cont :: ControlInst -> Control
   cont ci = case ci of
     Return (Just (TV _ v)) -> Ret $ val v
     Return Nothing         -> Ret $ Atom Undef
     Br l                   -> Seq $ trans l
     CBr v l1 l2            -> MBr (val v) (I 1) (trans l2) [(Atom $ Int 1, trans l1)]
     Switch (TV ty v) l ls  -> let f (TV _ x,l') = (val x, trans l')
                               in MBr (val v) ty (trans l) (map f ls)