module Vectorise.Utils.PADict (
  paDictArgType,
  paDictOfType,
  paMethod,
  prDictOfReprType,
  prDictOfPReprInstTyCon
) where

import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils.Base

import CoreSyn
import CoreUtils
import Coercion
import Type
import TypeRep
import TyCon
import Var
import Outputable
import FastString
import Control.Monad


-- |Construct the PA argument type for the tyvar. For the tyvar (v :: *) it's
-- just PA v. For (v :: (* -> *) -> *) it's
--
-- > forall (a :: * -> *). (forall (b :: *). PA b -> PA (a b)) -> PA (v a)
--
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    go ty (FunTy k1 k2)
      = do
          tv   <- newTyVar (fsLit "a") k1
          mty1 <- go (TyVarTy tv) k1
          case mty1 of
            Just ty1 -> do
                          mty2 <- go (AppTy ty (TyVarTy tv)) k2
                          return $ fmap (ForAllTy tv . FunTy ty1) mty2
            Nothing  -> go ty k2

    go ty k
      | isLiftedTypeKind k
      = do
          pa_cls <- builtin paClass
          return $ Just $ mkClassPred pa_cls [ty]

    go _ _ = return Nothing


-- |Get the PA dictionary for some type
--
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty 
  = paDictOfTyApp ty_fn ty_args
  where
    (ty_fn, ty_args) = splitAppTys ty

    paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
    paDictOfTyApp ty_fn ty_args
        | Just ty_fn' <- coreView ty_fn 
        = paDictOfTyApp ty_fn' ty_args

    -- for type variables, look up the dfun and apply to the PA dictionaries
    -- of the type arguments
    paDictOfTyApp (TyVarTy tv) ty_args
     = do dfun <- maybeCantVectoriseM "No PA dictionary for type variable"
                                      (ppr tv <+> text "in" <+> ppr ty)
                $ lookupTyVarPA tv
          dicts <- mapM paDictOfType ty_args
          return $ dfun `mkTyApps` ty_args `mkApps` dicts

    -- for tycons, we also need to apply the dfun to the PR dictionary of
    -- the representation type if the tycon is polymorphic
    paDictOfTyApp (TyConApp tc []) ty_args
     = do
         dfun <- maybeCantVectoriseM noPADictErr (ppr tc <+> text "in" <+> ppr ty)
                $ lookupTyConPA tc
         dicts <- mapM paDictOfType ty_args
         return $ Var dfun `mkTyApps` ty_args `mkApps` dicts
     where
       noPADictErr = "No PA dictionary for type constructor (did you import 'Data.Array.Parallel'?)"

    paDictOfTyApp _ _ = failure

    failure = cantVectorise "Can't construct PA dictionary for type" (ppr ty)

-- |Produce code that refers to a method of the 'PA' class.
--
paMethod :: (Builtins -> Var) -> (TyCon -> Builtins -> Var) -> Type -> VM CoreExpr
paMethod _ query ty
  | Just tycon <- splitPrimTyCon ty             -- Is 'ty' from 'GHC.Prim' (e.g., 'Int#')?
  = liftM Var $ builtin (query tycon)
paMethod method _ ty
  = do
      fn   <- builtin method
      dict <- paDictOfType ty
      return $ mkApps (Var fn) [Type ty, dict]

-- | Given a type @ty@, its PRepr synonym tycon and its type arguments,
-- return the PR @PRepr ty@. Suppose we have:
--
-- > type instance PRepr (T a1 ... an) = t
--
-- which is internally translated into
--
-- > type :R:PRepr a1 ... an = t
--
-- and the corresponding coercion. Then,
--
-- > prDictOfPReprInstTyCon (T a1 ... an) :R:PRepr u1 ... un = PR (T u1 ... un)
--
-- Note that @ty@ is only used for error messages
--
prDictOfPReprInstTyCon :: Type -> TyCon -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon ty prepr_tc prepr_args
  | Just rhs <- coreView (mkTyConApp prepr_tc prepr_args)
  = do
      dict <- prDictOfReprType' rhs
      pr_co <- mkBuiltinCo prTyCon
      let Just arg_co = tyConFamilyCoercion_maybe prepr_tc
      let co = mkAppCo pr_co
             $ mkSymCo
             $ mkAxInstCo arg_co prepr_args
      return $ mkCast dict co

  | otherwise = cantVectorise "Invalid PRepr type instance" (ppr ty)

-- |Get the PR dictionary for a type. The argument must be a representation
-- type.
--
prDictOfReprType :: Type -> VM CoreExpr
prDictOfReprType ty
  | Just (tycon, tyargs) <- splitTyConApp_maybe ty
    = do
        prepr <- builtin preprTyCon
        if tycon == prepr
          then do
                 let [ty'] = tyargs
                 pa <- paDictOfType ty'
                 sel <- builtin paPRSel
                 return $ Var sel `App` Type ty' `App` pa
          else do 
                 -- a representation tycon must have a PR instance
                 dfun <- maybeV (text "look up PR dictionary for" <+> ppr tycon) $ 
                           lookupTyConPR tycon
                 prDFunApply dfun tyargs

  | otherwise
    = do
        -- it is a tyvar or an application of a tyvar
        -- determine the PR dictionary from its PA dictionary
        --
        -- NOTE: This assumes that PRepr t ~ t is for all representation types
        -- t
        --
        -- FIXME: This doesn't work for kinds other than * at the moment. We'd
        -- have to simply abstract the term over the missing type arguments.
        pa    <- paDictOfType ty
        prsel <- builtin paPRSel
        return $ Var prsel `mkApps` [Type ty, pa]

prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty = prDictOfReprType ty `orElseV`
                       cantVectorise "No PR dictionary for representation type"
                                     (ppr ty)

-- | Apply a tycon's PR dfun to dictionary arguments (PR or PA) corresponding
-- to the argument types.
prDFunApply :: Var -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  | Just [] <- ctxs    -- PR (a :-> b) doesn't have a context
  = return $ Var dfun `mkTyApps` tys

  | Just tycons <- ctxs
  , length tycons == length tys
  = do
      pa <- builtin paTyCon
      pr <- builtin prTyCon 
      args <- zipWithM (dictionary pa pr) tys tycons
      return $ Var dfun `mkTyApps` tys `mkApps` args

  | otherwise = invalid
  where
    -- the dfun's contexts - if its type is (PA a, PR b) => PR (C a b) then
    -- ctxs is Just [PA, PR]
    ctxs = fmap (map fst)
         $ sequence
         $ map splitTyConApp_maybe
         $ fst
         $ splitFunTys
         $ snd
         $ splitForAllTys
         $ varType dfun

    dictionary pa pr ty tycon
      | tycon == pa = paDictOfType ty
      | tycon == pr = prDictOfReprType ty
      | otherwise   = invalid

    invalid = cantVectorise "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)