module Vectorise ( vectorise )
where

import Vectorise.Type.Env
import Vectorise.Type.Type
import Vectorise.Convert
import Vectorise.Utils.Hoisting
import Vectorise.Exp
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad

import HscTypes hiding      ( MonadThings(..) )
import CoreUnfold           ( mkInlineUnfolding )
import CoreFVs
import PprCore
import CoreSyn
import CoreMonad            ( CoreM, getHscEnv )
import Type
import Id
import OccName
import DynFlags
import BasicTypes           ( isStrongLoopBreaker )
import Outputable
import Util                 ( zipLazy )
import MonadUtils

import Control.Monad


-- | Entry point of the vectorisation pass for one module: obtain the 'HscEnv' from the
--   'CoreM' monad and hand the module over to the 'IO'-level driver.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \env -> liftIO (vectoriseIO env guts)

-- | Vectorise a single module, given the 'HscEnv' of the current session.
--
--   The vectorisation info of the home package and of loaded external packages is
--   combined and passed to the 'VM' computation.  If 'initV' yields no result, an
--   'IOError' (a 'userError') with a descriptive message is raised instead of an opaque
--   pattern-match failure.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
 = do {   -- Get information about currently loaded external packages.
      ; eps <- hscEPS hsc_env

          -- Combine vectorisation info from the current module, and external ones.
      ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

          -- Run the main VM computation.  'initV' yields a 'Maybe'; report a 'Nothing'
          -- explicitly (same exception type as the previous implicit 'fail', clearer message).
      ; result <- initV hsc_env guts info (vectModule guts)
      ; case result of
          Just (info', guts') -> return (guts' { mg_vect_info = info' })
          Nothing             -> ioError (userError "vectoriseIO: initV returned no result")
      }

-- | Vectorise a whole module inside the 'VM' monad.
--
--   The type environment is vectorised first.  The resulting bindings are placed in one
--   recursive group at the front of the module's bindings, and new family instances are
--   appended to the existing ones.  Every top-level binding is then vectorised in turn.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_types     = tyEnv
                         , mg_binds     = topBinds
                         , mg_fam_insts = famInsts
                         })
 = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" (pprCoreBindings topBinds)

          -- Vectorising the type environment may introduce new TyCons and DataCons,
          -- together with new family instances and accompanying bindings.
      ; (tyEnv', newFamInsts, tcBinds) <- vectTypeEnv tyEnv

          -- Read the global family-instance environment only after the type environment
          -- has been vectorised.
      ; (_, famInstEnv) <- readGEnv global_fam_inst_env

          -- Vectorise every top-level binding of the module.
      ; topBinds' <- mapM vectTopBind topBinds

      ; let tcGroup = Rec tcBinds
      ; return guts { mg_types        = tyEnv'
                    , mg_binds        = tcGroup : topBinds'
                    , mg_fam_inst_env = famInstEnv
                    , mg_fam_insts    = famInsts ++ newFamInsts
                    }
      }

-- |Try to vectorise a top-level binding.  If it doesn't vectorise then return it unharmed.
--
-- For example, for the binding 
--
-- @  
--    foo :: Int -> Int
--    foo = \x -> x + x
-- @
--
-- we get
-- @
--    foo  :: Int -> Int
--    foo  = \x -> vfoo $: x                  
--
--    v_foo :: Closure void vfoo lfoo
--    v_foo = closure vfoo lfoo void        
--
--    vfoo :: Void -> Int -> Int
--    vfoo = ...
--
--    lfoo :: PData Void -> PData Int -> PData Int
--    lfoo = ...
-- @ 
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
-- function foo, but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the
-- environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version
-- present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma.  If this
-- pragma is used in a group of mutually recursive bindings, either all or none of the bindings
-- must have the pragma.  If only some bindings are annotated, a fatal error is raised.
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
--   we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
-- Non-recursive binding.  If any step below fails, 'orElseV' falls back to returning the
-- original binding 'b' unchanged.
vectTopBind b@(NonRec var expr)
  = unlessNoVectDecl $
      do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it
             -- to the vectorisation map.
         ; (inline, isScalar, expr') <- vectTopRhs [] var expr
         ; var' <- vectTopBinder var inline expr'
             -- A scalar right-hand side registers 'var' as a global scalar; 'tryConvert' below
             -- returns the original right-hand side for global scalars.
         ; when isScalar $ 
             addGlobalScalar var
 
             -- We replace the original top-level binding by a value projected from the vectorised
             -- closure and add any newly created hoisted top-level bindings.
         ; cexpr <- tryConvert var var' expr
         ; hs <- takeHoisted
             -- Result: one recursive group holding the rewritten original binding, its
             -- vectorised counterpart, and the hoisted bindings.
         ; return . Rec $ (var, cexpr) : (var', expr') : hs
         }
     `orElseV`
       return b
  where
    -- Run 'vectorise' unless 'var' has a NOVECTORISE declaration; in that case the decision
    -- is traced and the binding is returned unchanged.
    unlessNoVectDecl vectorise
      = do { hasNoVectDecl <- noVectDecl var
           ; when hasNoVectDecl $
               traceVt "NOVECTORISE" $ ppr var
           ; if hasNoVectDecl then return b else vectorise
           }
-- Recursive group.  Either the whole group is vectorised or, if anything fails, 'orElseV'
-- returns the original group 'b' unchanged.
vectTopBind b@(Rec bs)
  = unlessSomeNoVectDecl $
             -- 'fixV' ties a knot: the vectorised binders (vars') are built from the inline
             -- specifications and vectorised right-hand sides that this very computation
             -- returns.  The lazy patterns ('~') and 'zipLazy' avoid forcing those results
             -- before they exist; 'vectTopBinder' must be lazy in them for the same reason.
      do { (vars', _, exprs', hs) <- fixV $ 
             \ ~(_, inlines, rhss, _) ->
               do {   -- Vectorise the right-hand sides, create appropriate top-level bindings
                      -- and add them to the vectorisation map.
                  ; vars' <- sequence [vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
                      -- First pass: the group's own binders are passed as 'recFs'.
                  ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
                  ; hs <- takeHoisted
                  ; if and areScalars
                    then      -- (1) Entire recursive group is scalar
                              --      => add all variables to the global set of scalars
                         do { mapM_ addGlobalScalar vars
                            ; return (vars', inlines, exprs', hs)
                            }
                    else      -- (2) At least one binding is not scalar
                              --     => vectorise again with empty set of local scalars
                              --     NOTE(review): the hoisted bindings taken after the first
                              --     pass are discarded here; only the second pass's are kept.
                         do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                            ; hs <- takeHoisted
                            ; return (vars', inlines, exprs', hs)
                            }
                  }
                       
             -- Replace the original top-level bindings by values projected from the vectorised
             -- closures and add any newly created hoisted top-level bindings to the group.
         ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
         ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
         }
     `orElseV`
       return b    
  where
    (vars, exprs) = unzip bs

    -- Run 'vectorise' only if no binding of the group has a NOVECTORISE declaration.  If all
    -- of them have one, the group is returned unchanged; a mix of both is a fatal error.
    unlessSomeNoVectDecl vectorise
      = do { hasNoVectDecls <- mapM noVectDecl vars
           ; when (and hasNoVectDecls) $
               traceVt "NOVECTORISE" $ ppr vars
           ; if and hasNoVectDecls 
             then return b                              -- all bindings have 'NOVECTORISE'
             else if or hasNoVectDecls 
             then cantVectorise noVectoriseErr (ppr b)  -- some (but not all) have 'NOVECTORISE'
             else vectorise                             -- no binding has a 'NOVECTORISE' decl
           }
    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
     
-- | Create the vectorised counterpart of a top-level binder and record the mapping from the
--   original binder to it in the global state.  For a binder @foo@ the vectorised binder is
--   named @$v_foo@.
--
--   If a vectorisation declaration exists for the binder, the type of the declared code must
--   coincide with the vectorised type of the binder; otherwise 'cantVectorise' is invoked.
--
--   NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
--   used inside of 'fixV' in 'vectTopBind'.  Both are only mentioned inside the
--   'unfolding' thunk, which is attached via 'setIdUnfoldingLazily'.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
 = do {   -- The new binder carries the vectorised version of the original type.
      ; vty <- vectType (idType var)

          -- Any vectorisation declaration for the binder must agree with that type.
      ; lookupVectDecl var >>= checkDeclaredType vty

          -- Clone the binder under its vectorised name; the unfolding is attached lazily.
      ; var' <- liftM (`setIdUnfoldingLazily` unfolding) (cloneId mkVectOcc var vty)

          -- Record the mapping from the original binder to the vectorised one.
      ; defGlobalVar var var'
      ; return var'
      }
  where
    -- Fail with a descriptive message if the declared type differs from the computed one.
    checkDeclaredType _   Nothing = return ()
    checkDeclaredType vty (Just (vdty, _))
      | eqType vty vdty = return ()
      | otherwise
      = cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
          (text "Expected type" <+> ppr vty)
          $$
          (text "Inferred type" <+> ppr vdty)

    -- Only an 'Inline' specification produces an inlining unfolding (built from 'expr').
    unfolding = case inline of
                  Inline arity -> mkInlineUnfolding (Just arity) expr
                  DontInline   -> noUnfolding

-- | Vectorise the right-hand side of a top-level binding inside a closed (empty) local
--   environment.
--
-- Three situations are distinguished:
--
-- (1) A (non-scalar) vectorisation declaration provides user-written vectorised code
--     => that code is used as is; no automatic vectorisation takes place.
--
-- (2) There is a scalar vectorisation declaration for the variable (it is a global scalar)
--     => vectorised code is generated that lifts the computation with a scalar 'map'/'zipWith'.
--
-- (3) There is no vectorisation declaration for the variable
--     => the right-hand side is vectorised automatically.
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $
      do { traceVt ("vectTopRhs of " ++ show var) $ ppr expr

         ; isScalarVar <- isGlobalScalar var
         ; mbDecl      <- lookupVectDecl var
         ; case mbDecl of
             Just (_, userExpr)                 -- Case (1)
               -> return (inlineMe, False, userExpr)
             Nothing
               | isScalarVar                    -- Case (2)
               -> do { scalarExpr <- vectScalarFun True recFs expr
                     ; return (inlineMe, True, vectorised scalarExpr)
                     }
               | otherwise                      -- Case (3)
               -> do { let loopBreaker = isStrongLoopBreaker (idOccInfo var)
                     ; (inline, isScalar, vexpr) <- inBind var $
                                                    vectPolyExpr loopBreaker recFs (freeVars expr)
                     ; return (inline, isScalar, vectorised vexpr)
                     }
         }

-- | Rewrite the original right-hand side of a top-level binding so that it obtains its value
--   from the vectorised binding, converting back with 'fromVect'.
--
--   The original right-hand side is kept when the binding is a global scalar, or when the
--   conversion fails.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = do { scalar <- isGlobalScalar var
       ; if scalar then keepOriginal else fromVectorised `orElseV` keepOriginal
       }
  where
    keepOriginal   = return rhs
    fromVectorised = fromVect (idType var) (Var vect_var)