{-# LANGUAGE PatternGuards #-}
{-|
  Description : GDSA abstract syntax.
  Copyright   : (c) Paul Govereau and Jean-Baptiste Tristan 2010
  License     : All Rights Reserved
  Maintainer  : Paul Govereau <govereau@cs.harvard.edu>

  Abstract syntax for GDSA programs and symbolic values.
-}
module MD.Syntax.GDSA
  ( module MD.Syntax.Enum
  , module MD.Syntax.GDSA
  , Ident(..), Atom(..), Type(..)
  --, fvs, fivs
  )
where
import Control.Monad
import Data.List
import Data.Maybe
import MD.Syntax.Enum
import MD.Syntax.LLVM (Ident(..), Atom(..), Type(..))

-------------------------------------------------------------------------------
-- | GDSA blocks.
--
-- A block is rendered by its 'Show' instance as a header line
-- @label:pre'suc' (mem_in -> mem_out)@, one line per phi definition,
-- then the rules and the 'control' value.
data GBlock
   = GBlock { label   :: Label
              -- ^ the label of this block
            , phis    :: [(Ident, Type, [(Label,Term)])]
              -- ^ phi definitions: the defined identifier, its type, and a
              -- list of (label, term) pairs — presumably one pair per
              -- incoming edge (TODO confirm)
            , rules   :: [RR]
              -- ^ the block's rules, each pairing an identifier with a term
            , pre'    :: [Label]
              -- ^ presumably the labels of predecessor blocks (TODO confirm)
            , suc'    :: [Label]
              -- ^ presumably the labels of successor blocks (TODO confirm)
            , mem_in  :: Int
              -- ^ integer shown on the left of @(in -> out)@; presumably
              -- the memory state on entry (TODO confirm)
            , mem_out :: Int
              -- ^ integer shown on the right of @(in -> out)@; presumably
              -- the memory state on exit (TODO confirm)
            , control :: Control
              -- ^ how control leaves the block (see 'Control')
            }

-- | Block labels are plain integers.
type Label = Int
-- | A rule pairing an identifier with a term; shown as @i -> t@.
data RR    = RR Ident Term deriving Eq

-- | The control transfer stored in the 'control' field of a 'GBlock'.
data Control
   = Ret Term
     -- ^ return, carrying a term
   | Seq Label
     -- ^ presumably an unconditional continuation at the given label
     -- (TODO confirm)
   | MBr Term Type Label [(Term,Label)]
     -- ^ multi-way branch: a term and its type, a label, and a list of
     -- (term, label) pairs — presumably scrutinee, default target and
     -- case targets (TODO confirm)
     deriving Show

-- | GDSA / symbolic term language.
--
-- Several constructors share their names with LLVM instructions (this
-- module takes 'Atom' and 'Type' from "MD.Syntax.LLVM"); presumably they
-- model those instructions symbolically — TODO confirm.
--
-- NOTE(review): 'Omega' uses record syntax inside a sum type, so the
-- selectors 'variable', 'vartype', 'condition' and 'sequences' are
-- partial: applying them to any other constructor is a runtime error.
data Term
   = Atom Atom
     -- ^ an atomic operand
   | Proj PI Term
     -- ^ selects the 'Val' or 'Mem' component of a term

   | GetElemPtr Type Term [Term]
     -- ^ type, base term and list of index terms
   | BinOp Opr String Type Term Term
     -- ^ operator, a string tag, operand type and two operands
   | Conv ConvOp Type Term Type
     -- ^ conversion operator, source type, operand and target type
   | Select Term Term Term
     -- ^ three operand terms

   | Alloc Type Integer Term
     -- ^ type, an integer count and one operand term
   | Load  Type Term Term
     -- ^ type and two operand terms
   | Store Type Term Term Term
     -- ^ type and three operand terms
   | Call Term [Term] Term
     -- ^ a term, an argument list and a final term

   | Phi Type Term Term Term
     -- ^ type and three operand terms
   | Mu Type Ident Ident
     -- ^ type and two identifiers
   | Omega { variable  :: Ident
             -- ^ identifier of the Omega node
           , vartype   :: Type
             -- ^ its type
           , condition :: Term
             -- ^ condition term
           , sequences :: [Sequence]
             -- ^ named sequences (see 'Sequence')
           }
     deriving (Eq,Show)

-- | Selector used by 'Proj': the 'Val' or the 'Mem' component.
data PI = Val | Mem deriving (Eq,Show)
-- | A named pair of terms, shown as @{name:first, second}@.  The name is
-- the sort key when lists of sequences are compared or printed.
data Sequence = SeqRules String Term Term deriving Eq

-------------------------------------------------------------------------------

-- | Free variables of a term, left to right, duplicates included.
--
-- Only 'Var' atoms contribute.  'Mu' and 'Omega' nodes contribute nothing,
-- and the first field of 'Call' is not traversed: only its last operand
-- (listed first) and its argument list are.
-- NOTE(review): confirm that skipping the first 'Call' field is intended
-- if that field can ever contain variables.
fvs :: Term -> [Ident]
fvs term = case term of
  Atom (Var v)          -> [v]
  Atom _                -> []
  Proj _ a              -> fvs a
  GetElemPtr _ base xs  -> fvs base ++ concatMap fvs xs
  BinOp _ _ _ lhs rhs   -> fvs lhs ++ fvs rhs
  Conv _ _ a _          -> fvs a
  Select a b c          -> concatMap fvs [a, b, c]
  Alloc _ _ a           -> fvs a
  Load _ a b            -> concatMap fvs [a, b]
  Store _ a b c         -> concatMap fvs [a, b, c]
  Phi _ a b c           -> concatMap fvs [a, b, c]
  Mu {}                 -> []
  Omega {}              -> []
  Call _ args lastOp    -> fvs lastOp ++ concatMap fvs args

-- | Free inductive variables: for every free variable (in 'fvs' order)
-- built with 'Z', 'N' or 'P', the identifier it wraps.  Free variables
-- built with 'Ident' are skipped.
fivs :: Term -> [Ident]
fivs term = mapMaybe unwrapInductive (fvs term)
 where
   unwrapInductive :: Ident -> Maybe Ident
   unwrapInductive v = case v of
     Ident _ -> Nothing
     Z inner -> Just inner
     N inner -> Just inner
     P inner -> Just inner

-------------------------------------------------------------------------------

-- | Compare two terms with 'chkEq' and report the outcome on stdout.
--
-- Prints @OK@ when the terms are equal.  If the first difference is a pair
-- of 'Omega' nodes, their sequences are sorted by name, both name lists
-- are printed between @"###"@ markers, and every differing component of
-- each sequence pair is reported recursively.  Non-'Omega' components are
-- shown truncated to 50 characters.  Any other difference prints @ALARM@
-- followed by both sub-terms in full.
checkAndPrint :: Term -> Term -> IO ()
checkAndPrint t1 t2 = case chkEq t1 t2 of
  Nothing    -> putStrLn "OK"
  Just (a@(Omega {}),b@(Omega {})) -> report a b
  Just (a,b) ->
      do { putStrLn "ALARM"
         ; print a
         ; print b
         }
 where
   -- Report the difference between two sub-terms: Omega nodes are broken
   -- down sequence by sequence, anything else is printed truncated.
   report (Omega {sequences = s1}) (Omega {sequences = s2}) =
       do { let s1' = sortBy byName s1
          ; let s2' = sortBy byName s2
          ; print "###"
          ; print (map seqName s1')
          ; print (map seqName s2')
            -- sequences are paired positionally after sorting; 'zip'
            -- silently drops surplus entries when the counts differ
          ; mapM_ reportPair (zip s1' s2')
          ; print "###"
          }
   report a b = short a >> short b

   seqName (SeqRules name _ _) = name
   byName (SeqRules x _ _) (SeqRules y _ _) = compare x y

   -- print the sequence name, then recurse into each differing component
   reportPair (SeqRules name a b, SeqRules _ a' b') =
       do { putStrLn name
          ; unless (a == a') $ report a a'
          ; unless (b == b') $ report b b'
          }

   short x = putStrLn (take 50 $ show x)

-- | Compare two terms structurally and localise the first difference.
--
-- Returns 'Nothing' when the terms are equal.  Otherwise returns a pair of
-- corresponding sub-terms at which they differ:
--
--   * if both nodes have the same constructor and agree on all non-term
--     fields (types, operators, identifiers, ...), the comparison descends
--     into the first pair of corresponding operands that are not equal,
--     left to right;
--
--   * otherwise the two nodes themselves are returned.
--
-- 'Omega' nodes are descended into only through their condition; their
-- sequences are compared with '(==)'.  Operand lists of different lengths
-- ('GetElemPtr' indices, 'Call' arguments) are reported as a mismatch of
-- the enclosing nodes.
chkEq :: Term -> Term -> Maybe (Term,Term)
chkEq t1 t2 = case (t1,t2) of
  (Atom a, Atom b)
    | a == b    -> Nothing
    | otherwise -> mismatch
  (Proj i a, Proj i' b)
    | i == i'   -> chkEq a b
    | otherwise -> mismatch
  (GetElemPtr ty a xs, GetElemPtr ty' b ys)
    | ty /= ty'              -> mismatch
    | a /= b                 -> chkEq a b
    | length xs /= length ys -> mismatch
    | otherwise              -> chkL xs ys
  (BinOp o s ty a b, BinOp o' s' ty' a' b')
    | o == o', s == s', ty == ty' -> firstDiff [(a,a'), (b,b')]
    | otherwise                   -> mismatch
  (Conv o ty a rty, Conv o' ty' a' rty')
    | o == o', ty == ty', rty == rty' -> chkEq a a'
    | otherwise                       -> mismatch
  (Select a b c, Select a' b' c') -> firstDiff [(a,a'), (b,b'), (c,c')]
  (Alloc ty n a, Alloc ty' n' a')
    | ty == ty', n == n' -> chkEq a a'
    | otherwise          -> mismatch
  (Load ty a b, Load ty' a' b')
    -- corresponding operands are compared (the first with the first,
    -- then the second with the second)
    | ty == ty'  -> firstDiff [(a,a'), (b,b')]
    | otherwise  -> mismatch
  (Store ty a b c, Store ty' a' b' c')
    | ty == ty'  -> firstDiff [(a,a'), (b,b'), (c,c')]
    | otherwise  -> mismatch
  (Phi ty a b c, Phi ty' a' b' c')
    | ty == ty'  -> firstDiff [(a,a'), (b,b'), (c,c')]
    | otherwise  -> mismatch
  (Mu ty x y, Mu ty' x' y')
    | ty == ty', x == x', y == y' -> Nothing
    | otherwise                   -> mismatch
  (Omega x ty c l, Omega x' ty' c' l')
    | x == x', ty == ty', l == l' -> chkEq c c'
    | otherwise                   -> mismatch
  (Call f xs m, Call f' ys m')
    | f /= f'                -> chkEq f f'
    | xs == ys               -> chkEq m m'
    -- 'chkL' fails on lists of different lengths, so a different
    -- argument count is reported here as a mismatch of the calls
    | length xs /= length ys -> mismatch
    | otherwise              -> chkL xs ys
  _ -> mismatch
 where
   -- the two whole terms, reported when they cannot be descended into
   mismatch = Just (t1,t2)

   -- Descend into the first pair of operands that differ; if all earlier
   -- pairs are equal the last pair is compared, which yields 'Nothing'
   -- when it is equal as well.
   firstDiff []             = Nothing
   firstDiff [(x,y)]        = chkEq x y
   firstDiff ((x,y) : rest)
     | x == y    = firstDiff rest
     | otherwise = chkEq x y

-- | Compare two lists of sequences and localise the first difference.
--
-- Both lists are sorted by sequence name and paired positionally.  For
-- each pair, the first components are compared with each other, then the
-- second components.  Sequence names themselves are not compared.  Like
-- 'chkL', this calls 'error' if the lists have different lengths and no
-- difference is found before one of them runs out.
chkS :: [Sequence] -> [Sequence] -> Maybe (Term,Term)
chkS l1 l2 = chkL (concatMap parts l1') (concatMap parts l2')
 where
   l1' = sortBy byName l1
   l2' = sortBy byName l2
   byName (SeqRules x _ _) (SeqRules y _ _) = compare x y
   -- Flatten each sequence into its two terms so that corresponding
   -- components of the two lists line up.  Previously the first
   -- components of one list were compared against the second components
   -- of the other.
   parts (SeqRules _ z n) = [z, n]

-- | Compare two term lists pairwise with 'chkEq' and return the first
-- difference found, or 'Nothing' if every pair is equal.  The lists are
-- expected to have equal length: if one runs out first and no difference
-- was found, this calls 'error'.
chkL :: [Term] -> [Term] -> Maybe (Term,Term)
chkL (x:xs) (y:ys) = chkEq x y `mplus` chkL xs ys
chkL []     []     = Nothing
chkL _      _      = error "chkL"

-------------------------------------------------------------------------------

-- Multi-line rendering of blocks for debugging output.
instance Show GBlock where
    -- header "<label>:<pre'><suc'> (<mem_in> -> <mem_out>)", then one line
    -- per phi definition, then the rules and the control value
    show (GBlock l p rr pr sc i o rv) =
        show l ++ ":" ++ show pr ++ show sc ++
        " (" ++ show i ++ " -> " ++ show o ++ ")\n" ++
        unlines (map shphi p) ++ show rr ++ show rv

    -- Blocks in a list are separated by a blank line.  The continuation
    -- string is appended rather than discarded, so a shown list of blocks
    -- composes correctly with surrounding output (e.g. inside a tuple).
    showList bs rest = intercalate "\n\n" (map show bs) ++ rest

-- | Render one phi definition of a block as
-- @ PHI \<ident\>::\<type\> <- \<[(label,term)]\>@.
shphi :: (Ident, Type, [(Label,Term)]) -> String
shphi (v, ty, incoming) =
  concat [" PHI ", show v, "::", show ty, " <- ", show incoming]

-- Rules are shown as "i -> t"; a list of rules is shown one per line,
-- each indented by a single space.
instance Show RR where
    show (RR i t) = show i ++ " -> " ++ show t
    -- The continuation string is appended rather than discarded, so a shown
    -- list of rules composes with surrounding output (e.g. inside a tuple).
    showList l rest = unlines (map ((' ':) . show) l) ++ rest

-- A sequence is rendered as "{name:first, second}".
instance Show Sequence where
    show (SeqRules name z n) =
        concat ["{", name, ":", show z, ", ", show n, "}"]