%
% (c) The University of Glasgow 2006
% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
%
Typechecking of arrow notation: proc expressions and the commands inside them

\begin{code}
module TcArrows ( tcProc ) where

import {-# SOURCE #-}	TcExpr( tcMonoExpr, tcInferRho, tcSyntaxOp )

import HsSyn
import TcMatches
import TcType
import TcMType
import TcBinds
import TcPat
import TcUnify
import TcRnMonad
import Coercion
import Inst
import Name
import TysWiredIn
import VarSet 
import TysPrim

import SrcLoc
import Outputable
import FastString
import Util

import Control.Monad
\end{code}

%************************************************************************
%*									*
		Proc	
%*									*
%************************************************************************

\begin{code}
-- | Typecheck @proc pat -> cmd@ against the expected type of the whole
-- proc expression.
--
-- The expected type is decomposed as @arr arg res@: the pattern is
-- checked at @arg@, and the command body at @res@ with an empty command
-- stack.  The returned coercion relates the expected type to
-- @(arr arg) res@.
tcProc :: InPat Name -> LHsCmdTop Name          -- proc pat -> expr
       -> TcRhoType                             -- Expected type of whole proc expression
       -> TcM (OutPat TcId, LHsCmdTop TcId, CoercionI)
tcProc pat cmd exp_ty
  = newArrowScope $ do
      -- exp_ty ~ arr_app_ty res_ty,  then  arr_app_ty ~ arr_ty arg_ty
      (outer_coi, (arr_app_ty, res_ty)) <- matchExpectedAppTy exp_ty
      (inner_coi, (arr_ty, arg_ty))     <- matchExpectedAppTy arr_app_ty
      let cmd_env = CmdEnv { cmd_arr = arr_ty }
      (pat', cmd') <- tcPat ProcExpr pat arg_ty
                            (tcCmdTop cmd_env cmd [] res_ty)
      -- Compose the two decompositions:  exp_ty ~ (arr_ty arg_ty) res_ty
      let res_coi = mkTransCoI outer_coi (mkAppTyCoI inner_coi (IdCo res_ty))
      return (pat', cmd', res_coi)
\end{code}


%************************************************************************
%*									*
		Commands
%*									*
%************************************************************************

\begin{code}
-- | Types of the pending command arguments.  The head of the list is the
-- most recently pushed argument: command application conses onto the
-- front, and a command lambda binds its patterns from the front.
type CmdStack = [TcTauType]

-- | Environment threaded through command typechecking.
data CmdEnv
  = CmdEnv
      { cmd_arr :: TcType   -- arrow type constructor, of kind *->*->*
      }

-- | Apply the environment's arrow type constructor to an input and a
-- result type, giving @arr t1 t2@.
mkCmdArrTy :: CmdEnv -> TcTauType -> TcTauType -> TcTauType
mkCmdArrTy env arg_ty res_ty = mkAppTys arr_tc [arg_ty, res_ty]
  where
    arr_tc = cmd_arr env

---------------------------------------
-- | Typecheck a top-level command (a proc body or a command-form operand)
-- against a known command stack and result type.  The syntax names stored
-- in the 'HsCmdTop' are checked at the environment's arrow constructor.
tcCmdTop :: CmdEnv
         -> LHsCmdTop Name
         -> CmdStack
         -> TcTauType    -- Expected result type; always a monotype
                         -- We know exactly how many cmd args are expected,
                         -- albeit perhaps not their types; so we can pass
                         -- in a CmdStack
         -> TcM (LHsCmdTop TcId)
tcCmdTop env (L loc (HsCmdTop cmd _ _ names)) cmd_stk res_ty
  = setSrcSpan loc $ do
      cmd'   <- tcGuardedCmd env cmd cmd_stk res_ty
      names' <- mapM check_name names
      return (L loc (HsCmdTop cmd' cmd_stk res_ty names'))
  where
    check_name nm = tcSyntaxName ProcOrigin (cmd_arr env) nm


----------------------------------------
-- | Typecheck a command against the given command stack and result type.
--
-- A thin adapter from the curried (stack, result) arguments used by the
-- callers to the tupled form taken by 'tcCmd'; it performs no type
-- refinement of its own.
tcGuardedCmd :: CmdEnv -> LHsExpr Name -> CmdStack
             -> TcTauType -> TcM (LHsExpr TcId)
tcGuardedCmd env expr stk res_ty = tcCmd env expr (stk, res_ty)

-- | The main recursive function: typecheck a located command under its
-- own source span and re-attach that location to the result.
tcCmd :: CmdEnv -> LHsExpr Name -> (CmdStack, TcTauType) -> TcM (LHsExpr TcId)
tcCmd env (L loc expr) res_ty
  = setSrcSpan loc (liftM (L loc) (tc_cmd env expr res_ty))

-- | Typecheck one unlocated command form against (command stack, result type).
tc_cmd :: CmdEnv -> HsExpr Name -> (CmdStack, TcTauType) -> TcM (HsExpr TcId)
tc_cmd env (HsPar cmd) res_ty
  = liftM HsPar (tcCmd env cmd res_ty)

-- let: the local bindings scope over the body command, which keeps its
-- own source location.
tc_cmd env (HsLet binds (L body_loc body)) res_ty
  = do let tc_body = setSrcSpan body_loc (tc_cmd env body res_ty)
       (binds', body') <- tcLocalBinds binds tc_body
       return (HsLet binds' (L body_loc body'))

-- case: infer the scrutinee's type, then check every alternative's body
-- as a command with the same command stack.
tc_cmd env in_cmd@(HsCase scrut matches) (stk, res_ty)
  = addErrCtxt (cmdCtxt in_cmd) $
    do { (scrut', scrut_ty) <- tcInferRho scrut
       ; matches' <- tcMatchesCase alt_ctxt scrut_ty matches res_ty
       ; return (HsCase scrut' matches') }
  where
    alt_ctxt = MC { mc_what = CaseAlt, mc_body = check_alt }
    check_alt alt_body alt_res_ty = tcGuardedCmd env alt_body stk alt_res_ty

-- Ordinary conditional: the predicate is a Bool expression, and both
-- branches are commands with the same command stack and result type as
-- the conditional itself.
tc_cmd env (HsIf Nothing pred b1 b2) res_ty
  = do  { pred' <- tcMonoExpr pred boolTy
        ; b1'   <- tcCmd env b1 res_ty
        ; b2'   <- tcCmd env b2 res_ty
        ; return (HsIf Nothing pred' b1' b2') }

-- Conditional with a rebindable 'ifThenElse' function.
-- NOTE(review): unlike the ordinary case above, the branch result type
-- b_ty is not unified with res_ty here, and the syntax function is checked
-- at  pred_ty -> b_ty -> b_ty -> res_ty.  Left as it was pending a proper
-- treatment of rebindable 'if' in commands -- confirm against the desugarer.
tc_cmd env (HsIf (Just fun) pred b1 b2) (stack_ty, res_ty)
  = do  { pred_ty <- newFlexiTyVarTy openTypeKind
        ; b_ty    <- newFlexiTyVarTy openTypeKind
        ; let if_ty = mkFunTys [pred_ty, b_ty, b_ty] res_ty
        ; fun'  <- tcSyntaxOp IfOrigin fun if_ty
        ; pred' <- tcMonoExpr pred pred_ty
        ; b1'   <- tcCmd env b1 (stack_ty, b_ty)
        ; b2'   <- tcCmd env b2 (stack_ty, b_ty)
        ; return (HsIf (Just fun') pred' b1' b2') }

-------------------------------------------
--              Arrow application
--              (f -< a)   or   (f -<< a)

-- The arrow f receives the argument paired with every stacked value:
--      f :: a ((arg, s1) .. sn) res
tc_cmd env cmd@(HsArrApp fun arg _ ho_app lr) (cmd_stk, res_ty)
  = addErrCtxt (cmdCtxt cmd) $ do
      arg_ty <- newFlexiTyVarTy openTypeKind
      let stk_ty = foldl mkPairTy arg_ty cmd_stk
          fun_ty = mkCmdArrTy env stk_ty res_ty
      fun' <- in_arrow_scope (tcMonoExpr fun fun_ty)
      arg' <- tcMonoExpr arg arg_ty
      return (HsArrApp fun' arg' fun_ty ho_app lr)
  where
    -- For (-<), f is checked in the environment of the enclosing proc, so
    -- bindings local to that proc are not in scope inside f.  For (-<<),
    -- they are.
    in_arrow_scope tc = case ho_app of
                          HsFirstOrderApp  -> escapeArrowScope tc
                          HsHigherOrderApp -> tc

-------------------------------------------
--              Command application

-- The argument's type is pushed onto the front of the command stack
-- seen by the function command.
tc_cmd env cmd@(HsApp fun arg) (cmd_stk, res_ty)
  = addErrCtxt (cmdCtxt cmd) $ do
      arg_ty <- newFlexiTyVarTy openTypeKind
      fun'   <- tcCmd env fun (arg_ty : cmd_stk, res_ty)
      arg'   <- tcMonoExpr arg arg_ty
      return (HsApp fun' arg')

-------------------------------------------
--              Lambda

-- A command lambda binds its n_pats patterns to the first n_pats entries
-- of the command stack; its guarded right-hand sides are then checked
-- against the remaining stack.
tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig grhss))] _))
       (cmd_stk, res_ty)
  = addErrCtxt (pprMatchInCtxt match_ctxt match) $

    do  {       -- Check the cmd stack is big enough
        ; checkTc (lengthAtLeast cmd_stk n_pats)
                  (kappaUnderflow cmd)

                -- Check the patterns against exactly the stack entries
                -- they consume, and the GRHSs inside against the rest
        ; (pats', grhss') <- setSrcSpan mtch_loc             $
                             tcPats LambdaExpr pats pat_tys  $
                             tc_grhss grhss res_ty

        ; let match' = L mtch_loc (Match pats' Nothing grhss')
        ; return (HsLam (MatchGroup [match'] res_ty))
        }

  where
    n_pats          = length pats
    -- pat_tys: the stack entries bound by the patterns;
    -- stk':    what is left for the body
    (pat_tys, stk') = splitAt n_pats cmd_stk
    match_ctxt      = (LambdaExpr :: HsMatchContext Name)   -- Maybe KappaExpr?
    pg_ctxt         = PatGuard match_ctxt

    tc_grhss (GRHSs rhss local_binds) body_ty
      = do { (binds', rhss') <- tcLocalBinds local_binds $
                                mapM (wrapLocM (tc_grhs body_ty)) rhss
           ; return (GRHSs rhss' binds') }

    tc_grhs body_ty (GRHS guards body)
      = do { (guards', rhs') <- tcStmts pg_ctxt tcGuardStmt guards body_ty $
                                tcGuardedCmd env body stk'
           ; return (GRHS guards' rhs') }

-------------------------------------------
--              Do notation

-- A do-command must be reached with an empty command stack.  tcMDoStmt
-- is given tc_rhs, which checks a right-hand side as a command with an
-- empty stack against a fresh lifted type; the final body is also
-- checked with an empty stack.
tc_cmd env cmd@(HsDo do_or_lc stmts body _ty) (cmd_stk, res_ty)
  = do  { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd)
        ; (stmts', body') <- tcStmts do_or_lc (tcMDoStmt tc_rhs) stmts res_ty
                                     (tcGuardedCmd env body [])
        ; return (HsDo do_or_lc stmts' body' res_ty) }
  where
    tc_rhs rhs = do
      rhs_ty <- newFlexiTyVarTy liftedTypeKind
      rhs'   <- tcCmd env rhs ([], rhs_ty)
      return (rhs', rhs_ty)


-----------------------------------------------------------------
--      Arrow ``forms''       (| e c1 .. cn |)
--
--      G      |-b  c : [s1 .. sm] s
--      pop(G) |-   e : forall w. b ((w,s1) .. sm) s
--                              -> a ((w,t1) .. tn) t
--      e \not\in (s, s1..sm, t, t1..tn)
--      ----------------------------------------------
--      G |-a  (| e c |)  :  [t1 .. tn] t

tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)
  = addErrCtxt (cmdCtxt cmd) $
    do  { operands <- zipWithM new_cmd_ty cmd_args [1..]
        ; [w_tv]   <- tcInstSkolTyVars ArrowSkol [alphaTyVar]
        ; let w_ty = mkTyVarTy w_tv     -- Just a convenient starting point

                --  a ((w,t1) .. tn) t
        ; let e_res_ty = mkCmdArrTy env (foldl mkPairTy w_ty cmd_stk) res_ty

                --   b ((w,s1) .. sm) s
                --   -> a ((w,t1) .. tn) t
        ; let e_ty = mkFunTys [mkAppTys b [tup, s] | (_, _, b, tup, s) <- operands]
                              e_res_ty

                -- Check expr, with w skolemised, outside the arrow scope
        ; (inst_binds, expr') <- checkConstraints ArrowSkol [w_tv] [] $
                                 escapeArrowScope (tcMonoExpr expr e_ty)

                -- OK, now we are in a position to unscramble
                -- the s1..sm and check each cmd
        ; cmds' <- mapM (tc_operand w_tv) operands

        ; let wrap = WpTyLam w_tv <.> mkWpLet inst_binds
        ; return (HsArrForm (mkLHsWrap wrap expr') fixity cmds') }
  where
        -- Make the types for one operand, tagged with its position i:
        --      b, ((e,s1) .. sm), s
    new_cmd_ty :: LHsCmdTop Name -> Int
               -> TcM (LHsCmdTop Name, Int, TcType, TcType, TcType)
    new_cmd_ty operand i
      = do  { b_ty   <- newFlexiTyVarTy arrowTyConKind
            ; tup_ty <- newFlexiTyVarTy liftedTypeKind
                -- We actually make a type variable for the tuple
                -- because we don't know how deeply nested it is yet
            ; s_ty   <- newFlexiTyVarTy liftedTypeKind
            ; return (operand, i, b_ty, tup_ty, s_ty)
            }

        -- Check one operand command.  (Named tc_operand rather than
        -- tc_cmd so it does not shadow the top-level tc_cmd.)
    tc_operand w_tv (operand, i, b, tup_ty, s)
      = do { tup_ty' <- zonkTcType tup_ty
           ; let (corner_ty, arg_tys) = unscramble tup_ty'

                -- Check that it has the right shape:
                --      ((w,s1) .. sn)
                -- where the si do not mention w
           ; checkTc (corner_ty `tcEqType` mkTyVarTy w_tv &&
                      not (w_tv `elemVarSet` tyVarsOfTypes arg_tys))
                     (badFormFun i tup_ty')

           ; tcCmdTop (env { cmd_arr = b }) operand arg_tys s }

    unscramble :: TcType -> (TcType, [TcType])
    -- unscramble ((w,s1) .. sn)        =  (w, [s1..sn])
    unscramble ty = unscramble' ty []

    -- Peel nested pairs off the left, accumulating the right components.
    unscramble' ty ss
       = case tcSplitTyConApp_maybe ty of
            Just (tc, [t, s]) | tc == pairTyCon
               -> unscramble' t (s : ss)
            _  -> (ty, ss)

-----------------------------------------------------------------
--              Base case for illegal commands
-- This is where expressions that aren't commands get rejected

tc_cmd _ non_cmd _
  = failWithTc msg
  where
    msg = vcat [ ptext (sLit "The expression")
               , nest 2 (ppr non_cmd)
               , ptext (sLit "was found where an arrow command was expected") ]
\end{code}


%************************************************************************
%*									*
		Helpers
%*									*
%************************************************************************


\begin{code}
-- | The pair type @(t1, t2)@, built by applying 'pairTyCon'.
mkPairTy :: Type -> Type -> Type
mkPairTy fst_ty snd_ty = pairTyCon `mkTyConApp` [fst_ty, snd_ty]

-- | The kind @*->*->*@ given to arrow type constructors.
arrowTyConKind :: Kind
arrowTyConKind = mkArrowKinds (replicate 2 liftedTypeKind) liftedTypeKind
\end{code}


%************************************************************************
%*									*
		Errors
%*									*
%************************************************************************

\begin{code}
-- | Error-context line naming the command being checked.
cmdCtxt :: HsExpr Name -> SDoc
cmdCtxt cmd = header <+> ppr cmd
  where
    header = ptext (sLit "In the command:")

-- | Error reported when a command that requires an empty command stack
-- (a do-command) is reached with stack arguments still pending.
nonEmptyCmdStkErr :: HsExpr Name -> SDoc
nonEmptyCmdStkErr cmd = hang header 2 (ppr cmd)
  where
    header = ptext (sLit "Non-empty command stack at command:")

-- | Error reported when a command lambda binds more patterns than there
-- are entries on the command stack.
kappaUnderflow :: HsExpr Name -> SDoc
kappaUnderflow cmd = hang header 2 (ppr cmd)
  where
    header = ptext (sLit "Command stack underflow at command:")

-- | Error for a command-form operand whose stack tuple type does not have
-- the shape ((w,s1) .. sn).  @i@ is the operand's position, spoken as an
-- ordinal; the form case numbers operands from 1.
badFormFun :: Int -> TcType -> SDoc
badFormFun i tup_ty'
  = hang headline 2 detail
  where
    headline = ptext (sLit "The type of the") <+> speakNth i
               <+> ptext (sLit "argument of a command form has the wrong shape")
    detail   = ptext (sLit "Argument type:") <+> ppr tup_ty'
\end{code}