{-# LANGUAGE CPP, TupleSections #-}
module UnariseStg (unarise) where
#include "HsVersions.h"
import GhcPrelude
import BasicTypes
import CoreSyn
import DataCon
import FastString (FastString, mkFastString)
import Id
import Literal (Literal (..))
import MkCore (aBSENT_SUM_FIELD_ERROR_ID)
import MkId (voidPrimId, voidArgId)
import MonadUtils (mapAccumLM)
import Outputable
import RepType
import StgSyn
import Type
import TysPrim (intPrimTy)
import TysWiredIn
import UniqSupply
import Util
import VarEnv
import Data.Bifunctor (second)
import Data.Maybe (mapMaybe)
import qualified Data.IntMap as IM
-- | The environment threaded through the pass: a 'VarEnv' that stores one
-- 'UnariseVal' for each variable that has been given a replacement.
type UnariseEnv = VarEnv UnariseVal
-- | What a variable recorded in the 'UnariseEnv' stands for.
data UnariseVal
  = MultiVal [OutStgArg] -- ^ A list of arguments (other code in this module also builds @MultiVal []@).
  | UnaryVal OutStgArg -- ^ A single replacement argument.
-- | Pretty-printing: the constructor name followed by the printed payload.
instance Outputable UnariseVal where
  ppr val = case val of
    MultiVal args -> text "MultiVal" <+> ppr args
    UnaryVal arg  -> text "UnaryVal" <+> ppr arg
-- | Record in @rho@ that @x@ stands for the given 'UnariseVal'.
--
-- In builds where assertions are enabled, every argument carried by the value
-- is checked with 'isNvUnaryType' (applied to its 'stgArgType') before the
-- environment is extended.
extendRho :: UnariseEnv -> Id -> UnariseVal -> UnariseEnv
extendRho rho x val
  = ASSERT(all (isNvUnaryType . stgArgType) payload)
    extendVarEnv rho x val
  where
    -- The arguments carried by the value, as a list, so that both
    -- constructors share one assertion.
    payload = case val of
      MultiVal args -> args
      UnaryVal arg  -> [arg]
-- | Entry point of the pass: run 'unariseTopBinding' over every top-level
-- binding, starting each one from an empty environment, and run the resulting
-- 'UniqSM' action with the supplied unique supply.
unarise :: UniqSupply -> [StgTopBinding] -> [StgTopBinding]
unarise us = initUs_ us . mapM unariseTop
  where
    unariseTop = unariseTopBinding emptyVarEnv
-- | Unarise a top-level binding.  Lifted bindings are processed with
-- 'unariseBinding'; top-level string literals are returned unchanged.
unariseTopBinding :: UnariseEnv -> StgTopBinding -> UniqSM StgTopBinding
unariseTopBinding rho top = case top of
    StgTopLifted bind -> fmap StgTopLifted (unariseBinding rho bind)
    StgTopStringLit{} -> return top
-- | Unarise the right-hand side(s) of a binding.  The bound identifiers
-- themselves are kept as they are; only the right-hand sides are rewritten.
unariseBinding :: UnariseEnv -> StgBinding -> UniqSM StgBinding
unariseBinding rho binding = case binding of
    StgNonRec x rhs -> StgNonRec x <$> unariseRhs rho rhs
    StgRec pairs    -> StgRec <$> mapM unarisePair pairs
  where
    -- Rewrite one member of a recursive group, keeping its binder.
    unarisePair (x, rhs) = do
      rhs' <- unariseRhs rho rhs
      return (x, rhs')
-- | Unarise a right-hand side.
--
-- * For a closure, the argument binders are expanded with
--   'unariseFunArgBinders', the body is rewritten under the extended
--   environment, and the free-variable list is rewritten under the /original/
--   environment.
--
-- * For a constructor application, the arguments are flattened with
--   'unariseConArgs'.  An assertion checks that the constructor is neither an
--   unboxed tuple nor an unboxed sum.
unariseRhs :: UnariseEnv -> StgRhs -> UniqSM StgRhs
unariseRhs rho rhs = case rhs of
    StgRhsClosure ccs b_info fvs update_flag args expr -> do
      (rho_body, args') <- unariseFunArgBinders rho args
      body' <- unariseExpr rho_body expr
      return (StgRhsClosure ccs b_info (unariseFreeVars rho fvs) update_flag args' body')

    StgRhsCon ccs con args ->
      ASSERT(not (isUnboxedTupleCon con || isUnboxedSumCon con))
      return (StgRhsCon ccs con (unariseConArgs rho args))
-- | Unarise one STG expression under the environment @rho@.
--
-- Variables are replaced by what the environment maps them to, function and
-- primop arguments are flattened with 'unariseFunArgs', constructor
-- applications are flattened with 'unariseConArgs' (or turned into a tuple via
-- 'unariseMulti_maybe'), and case expressions whose scrutinee is already known
-- to be a list of values are handed to 'elimCase'.  Encountering a lambda is a
-- panic.
unariseExpr :: UnariseEnv -> StgExpr -> UniqSM StgExpr
unariseExpr rho expr = case expr of
    -- A bare variable: substitute whatever the environment says it stands for.
    StgApp f [] ->
      case lookupVarEnv rho f of
        Nothing                        -> return expr
        Just (MultiVal vals)           -> return (mkTuple vals)
        Just (UnaryVal (StgVarArg f')) -> return (StgApp f' [])
        Just (UnaryVal (StgLitArg l))  -> return (StgLit l)

    -- An application with arguments: the head may only be renamed to another
    -- variable; any other mapping for it is a panic.
    StgApp f args ->
      let head_id = case lookupVarEnv rho f of
            Nothing                       -> f
            Just (UnaryVal (StgVarArg g)) -> g
            other -> pprPanic "unariseExpr - app2" (ppr expr $$ ppr other)
      in return (StgApp head_id (unariseFunArgs rho args))

    StgLit l ->
      return (StgLit l)

    -- Constructor application: either it denotes a multi-value (rebuilt as a
    -- tuple) or it is an ordinary constructor whose arguments are flattened.
    StgConApp dc args ty_args ->
      case unariseMulti_maybe rho dc args ty_args of
        Just multi_args -> return (mkTuple multi_args)
        Nothing ->
          let con_args = unariseConArgs rho args
          in return (StgConApp dc con_args (map stgArgType con_args))

    StgOpApp op args ty ->
      return (StgOpApp op (unariseFunArgs rho args) ty)

    StgLam{} ->
      pprPanic "unariseExpr: found lambda" (ppr expr)

    -- Case: when the scrutinee is statically a list of values the case is
    -- eliminated; otherwise scrutinee and alternatives are rewritten in place.
    StgCase scrut bndr alt_ty alts ->
      case knownMultiVal scrut of
        Just vals -> elimCase rho vals bndr alt_ty alts
        Nothing   -> do
          scrut' <- unariseExpr rho scrut
          alts'  <- unariseAlts rho alt_ty bndr alts
          return (StgCase scrut' bndr alt_ty alts')

    StgLet bind body -> do
      bind' <- unariseBinding rho bind
      body' <- unariseExpr rho body
      return (StgLet bind' body')

    StgLetNoEscape bind body -> do
      bind' <- unariseBinding rho bind
      body' <- unariseExpr rho body
      return (StgLetNoEscape bind' body')

    StgTick tick body ->
      fmap (StgTick tick) (unariseExpr rho body)
  where
    -- The values a scrutinee is already known to consist of, if any: either a
    -- variable mapped to a 'MultiVal', or a constructor application accepted
    -- by 'unariseMulti_maybe'.
    knownMultiVal (StgApp v [])
      | Just (MultiVal xs) <- lookupVarEnv rho v
      = Just xs
    knownMultiVal (StgConApp dc args ty_args)
      = unariseMulti_maybe rho dc args ty_args
    knownMultiVal _
      = Nothing
-- | If the data constructor is an unboxed tuple or unboxed sum constructor,
-- return the flattened list of arguments that the application stands for;
-- otherwise return 'Nothing'.
--
-- For a tuple this is just the flattened constructor arguments.  For a sum
-- the flattened arguments are laid out by 'mkUbxSum'; an assertion checks that
-- the sum constructor was applied to a single argument.
unariseMulti_maybe :: UnariseEnv -> DataCon -> [InStgArg] -> [Type] -> Maybe [OutStgArg]
unariseMulti_maybe rho dc args ty_args
  | isUnboxedTupleCon dc
  = Just flat_args
  | isUnboxedSumCon dc
  = Just (mkUbxSum dc ty_args (ASSERT(isSingleton args) flat_args))
  | otherwise
  = Nothing
  where
    flat_args = unariseConArgs rho args
-- | Eliminate a case whose scrutinee is already known to be the list of
-- values @args@.
--
-- * With a single alternative (under a 'MultiValAlt'), no case is needed: the
--   case binder and the alternative's binders are bound in the environment
--   and the right-hand side is rewritten directly.
--
-- * With several alternatives and an unboxed-sum case binder, the first
--   element of @args@ is used as the scrutinee of a new case with alternative
--   type 'tagAltTy', whose alternatives come from 'unariseSumAlts'.
--
-- * Anything else is a panic.
elimCase :: UnariseEnv
         -> [OutStgArg]
         -> InId -> AltType -> [InStgAlt] -> UniqSM OutStgExpr
elimCase rho args bndr (MultiValAlt _) alts
  | [(_, alt_bndrs, rhs)] <- alts
  = unariseExpr (bindAltBndrs alt_bndrs) rhs

  | isUnboxedSumBndr bndr
  = do let (tag_arg : payload_args) = args
       -- A fresh binder for the new case on the tag.
       tag_bndr <- mkId (mkFastString "tag") tagTy
       let tag_scrut = case tag_arg of
             StgLitArg lit -> StgLit lit
             StgVarArg var -> StgApp var []
       sum_alts <- unariseSumAlts rho_bndr payload_args alts
       return (StgCase tag_scrut tag_bndr tagAltTy sum_alts)
  where
    -- The environment with the case binder mapped to the scrutinee's values.
    rho_bndr = extendRho rho bndr (MultiVal args)

    -- Additionally bind the binders of the single alternative.
    bindAltBndrs alt_bndrs
      | isUnboxedTupleBndr bndr
      = mapTupleIdBinders alt_bndrs args rho_bndr
      | otherwise
      = ASSERT(isUnboxedSumBndr bndr)
        (if null alt_bndrs
           then rho_bndr
           else mapSumIdBinders alt_bndrs args rho_bndr)
elimCase _ args bndr alt_ty alts
  = pprPanic "elimCase - unhandled case"
             (ppr args <+> ppr bndr <+> ppr alt_ty $$ ppr alts)
-- | Unarise the alternatives of a case that could not be eliminated.
--
-- Under a 'MultiValAlt':
--
-- * Unboxed-tuple case binder: a lone @DEFAULT@ alternative without binders,
--   or a lone constructor alternative, becomes a single unboxed-tuple
--   alternative of arity @n@ over the flattened binders.  Any other shape is
--   a panic.
--
-- * Unboxed-sum case binder: the case binder is flattened into fresh binders
--   and a single unboxed-tuple alternative over them is produced.  A lone
--   @DEFAULT@ alternative keeps its right-hand side directly; otherwise the
--   right-hand side is an inner case on the first flattened binder, with
--   alternatives built by 'unariseSumAlts'.
--
-- In every other situation each alternative is rewritten with 'unariseAlt'.
unariseAlts :: UnariseEnv -> AltType -> InId -> [StgAlt] -> UniqSM [StgAlt]
unariseAlts rho (MultiValAlt n) bndr alts
  | isUnboxedTupleBndr bndr
  = case alts of
      [(DEFAULT, [], rhs)] -> do
        (rho_alt, flat_bndrs) <- unariseConArgBinder rho bndr
        rhs' <- unariseExpr rho_alt rhs
        return [(tupleAlt n, flat_bndrs, rhs')]

      [(DataAlt _, field_bndrs, rhs)] -> do
        (rho_fields, flat_bndrs) <- unariseConArgBinders rho field_bndrs
        MASSERT(flat_bndrs `lengthIs` n)
        -- The case binder now stands for all the flattened field binders.
        let rho_alt = extendRho rho_fields bndr (MultiVal (map StgVarArg flat_bndrs))
        rhs' <- unariseExpr rho_alt rhs
        return [(tupleAlt n, flat_bndrs, rhs')]

      _ -> pprPanic "unariseExpr: strange multi val alts" (ppr alts)

  | isUnboxedSumBndr bndr
  = case alts of
      [(DEFAULT, _, rhs)] -> do
        (rho_alt, flat_bndrs) <- unariseConArgBinder rho bndr
        rhs' <- unariseExpr rho_alt rhs
        return [(tupleAlt (length flat_bndrs), flat_bndrs, rhs')]

      _ -> do
        (rho_alt, flat_bndrs@(tag_bndr : payload_bndrs)) <- unariseConArgBinder rho bndr
        sum_alts <- unariseSumAlts rho_alt (map StgVarArg payload_bndrs) alts
        let tag_case = StgCase (StgApp tag_bndr []) tag_bndr tagAltTy sum_alts
        return [(tupleAlt (length flat_bndrs), flat_bndrs, tag_case)]
  where
    -- The alternative constructor for an unboxed tuple of the given arity.
    tupleAlt arity = DataAlt (tupleDataCon Unboxed arity)
unariseAlts rho _ _ alts
  = mapM (unariseAlt rho) alts
-- | Unarise one ordinary alternative: flatten its binders with
-- 'unariseConArgBinders' and rewrite its right-hand side under the resulting
-- environment.  The alternative's constructor is kept.
unariseAlt :: UnariseEnv -> StgAlt -> UniqSM StgAlt
unariseAlt rho (con, xs, e) = do
  (rho_alt, xs') <- unariseConArgBinders rho xs
  e' <- unariseExpr rho_alt e
  return (con, xs', e')
-- | Rewrite every alternative with 'unariseSumAlt', then pass the resulting
-- list through 'mkDefaultLitAlt'.
unariseSumAlts :: UnariseEnv
               -> [StgArg]
               -> [StgAlt]
               -> UniqSM [StgAlt]
unariseSumAlts env args alts
  = mkDefaultLitAlt <$> mapM (unariseSumAlt env args) alts
-- | Rewrite one alternative of a case on an unboxed sum.
--
-- * A @DEFAULT@ alternative stays @DEFAULT@, with no binders.
--
-- * A constructor alternative becomes a literal alternative on the
--   constructor's tag ('dataConTag'), with no binders; the original binders
--   are bound in the environment via 'mapSumIdBinders'.
--
-- * Any other alternative is a panic.
unariseSumAlt :: UnariseEnv
              -> [StgArg]
              -> StgAlt
              -> UniqSM StgAlt
unariseSumAlt rho args alt = case alt of
    (DEFAULT, _, e) -> do
      e' <- unariseExpr rho e
      return (DEFAULT, [], e')

    (DataAlt sumCon, bs, e) -> do
      let tag_lit = MachInt (fromIntegral (dataConTag sumCon))
      e' <- unariseExpr (mapSumIdBinders bs args rho) e
      return (LitAlt tag_lit, [], e')

    _ -> pprPanic "unariseSumAlt" (ppr args $$ ppr alt)
-- | Bind the binders of an unboxed-tuple alternative to the flattened
-- arguments of the scrutinee.
--
-- The binders are processed left to right.  Each one takes as many arguments
-- from the front of the remaining list as its type has primitive
-- representations ('typePrimRep'): exactly one gives a 'UnaryVal', any other
-- count gives a 'MultiVal'.  Assertions check that no argument has a void
-- type and that enough arguments remain for each binder.
mapTupleIdBinders
  :: [InId]
  -> [OutStgArg]
  -> UnariseEnv
  -> UnariseEnv
mapTupleIdBinders ids args0 rho0
  = ASSERT(not (any (isVoidTy . stgArgType) args0))
    go rho0 ids args0
  where
    go :: UnariseEnv -> [InId] -> [StgArg] -> UnariseEnv
    go rho [] _ = rho
    go rho (x : xs) args = go rho' xs rest
      where
        -- Number of flattened arguments this binder consumes.
        n_reps = length (typePrimRep (idType x))

        (x_args, rest) =
          ASSERT(args `lengthAtLeast` n_reps)
          splitAt n_reps args

        rho'
          | n_reps == 1
          = ASSERT(x_args `lengthIs` 1)
            extendRho rho x (UnaryVal (head x_args))
          | otherwise
          = extendRho rho x (MultiVal x_args)
-- | Bind the single binder of an unboxed-sum alternative to the arguments of
-- the scrutinee that hold its value.
--
-- 'layoutUbxSum' is given the slots of the arguments and the slots of the
-- binder's type, and returns indices into @args@.  A multi-value binder
-- ('isMultiValBndr') is bound to all the selected arguments; otherwise the
-- layout is asserted to contain one index and the binder is bound to that
-- argument.  Calling this with anything other than exactly one binder is a
-- panic.
mapSumIdBinders
  :: [InId]
  -> [OutStgArg]
  -> UnariseEnv
  -> UnariseEnv
mapSumIdBinders [id] args rho0
  = ASSERT(not (any (isVoidTy . stgArgType) args))
    bound
  where
    -- The slots occupied by a value of the given type.
    slotsOf = map primRepSlot . typePrimRep

    layout = layoutUbxSum (concatMap (slotsOf . stgArgType) args)
                          (slotsOf (idType id))

    bound
      | isMultiValBndr id
      = extendRho rho0 id (MultiVal (map (args !!) layout))
      | otherwise
      = ASSERT(layout `lengthIs` 1)
        extendRho rho0 id (UnaryVal (args !! head layout))
mapSumIdBinders ids sum_args _
  = pprPanic "mapSumIdBinders" (ppr ids $$ ppr sum_args)
-- | Build the flattened argument list for an unboxed-sum constructor
-- application.
--
-- The result starts with the constructor's tag ('dataConTag') as an integer
-- literal, followed by one argument per slot of the sum's representation
-- (the first slot reported by 'ubxSumRepType' is dropped here).  Slots that
-- 'layoutUbxSum' assigns to one of the supplied arguments hold that argument;
-- every other slot holds a filler value chosen by the slot's kind.
mkUbxSum
  :: DataCon
  -> [Type]
  -> [OutStgArg]
  -> [OutStgArg]
mkUbxSum dc ty_args args0
  = tag_arg : zipWith fillSlot [0 ..] sum_slots
  where
    (_ : sum_slots) = ubxSumRepType (map typePrimRep ty_args)

    tag_arg = StgLitArg (MachInt (fromIntegral (dataConTag dc)))

    -- For each supplied argument, the index of the slot it goes into.
    arg_slot_idxs = layoutUbxSum sum_slots (mapMaybe (typeSlotTy . stgArgType) args0)

    args_by_slot :: IM.IntMap StgArg
    args_by_slot = IM.fromList (zipEqual "mkUbxSum" arg_slot_idxs args0)

    -- The argument for the slot at the given index: the real argument if one
    -- was assigned to it, a filler otherwise.
    fillSlot :: Int -> SlotTy -> StgArg
    fillSlot idx slot = case IM.lookup idx args_by_slot of
      Just stg_arg -> stg_arg
      Nothing      -> rubbishFor slot

    rubbishFor :: SlotTy -> StgArg
    rubbishFor slot = case slot of
      PtrSlot    -> StgVarArg aBSENT_SUM_FIELD_ERROR_ID
      WordSlot   -> StgLitArg (MachWord 0)
      Word64Slot -> StgLitArg (MachWord64 0)
      FloatSlot  -> StgLitArg (MachFloat 0)
      DoubleSlot -> StgLitArg (MachDouble 0)
-- | Flatten one function argument.
--
-- A variable found in the environment is replaced by what it stands for; a
-- variable mapped to an empty 'MultiVal' becomes the single argument
-- 'voidArg', so the result is never empty.  Unmapped variables and
-- non-variable arguments are returned as they are.
unariseFunArg :: UnariseEnv -> StgArg -> [StgArg]
unariseFunArg rho arg@(StgVarArg x)
  = maybe [arg] expand (lookupVarEnv rho x)
  where
    expand (UnaryVal a)    = [a]
    expand (MultiVal [])   = [voidArg]
    expand (MultiVal vals) = vals
unariseFunArg _ arg
  = [arg]
-- | Flatten a list of function arguments with 'unariseFunArg' and concatenate
-- the results.
unariseFunArgs :: UnariseEnv -> [StgArg] -> [StgArg]
unariseFunArgs rho args = concatMap (unariseFunArg rho) args
-- | Expand a list of function argument binders left to right with
-- 'unariseFunArgBinder', threading the environment through, and concatenate
-- the resulting binder lists.
unariseFunArgBinders :: UnariseEnv -> [Id] -> UniqSM (UnariseEnv, [Id])
unariseFunArgBinders rho xs = do
  (rho', expanded) <- mapAccumLM unariseFunArgBinder rho xs
  return (rho', concat expanded)
-- | Expand one function argument binder according to the primitive
-- representations of its type ('typePrimRep'):
--
-- * none: the binder is mapped to an empty 'MultiVal' and replaced by
--   'voidArgId';
--
-- * exactly one: the binder is kept and the environment is unchanged;
--
-- * several: one fresh binder is created per representation and the original
--   binder is mapped to the list of them.
unariseFunArgBinder :: UnariseEnv -> Id -> UniqSM (UnariseEnv, [Id])
unariseFunArgBinder rho x
  | null reps
  = return (extendRho rho x (MultiVal []), [voidArgId])
  | [_] <- reps
  = return (rho, [x])
  | otherwise
  = do fresh <- mkIds (mkFastString "us") (map primRepToType reps)
       return (extendRho rho x (MultiVal (map StgVarArg fresh)), fresh)
  where
    reps = typePrimRep (idType x)
-- | Flatten one constructor argument.
--
-- A variable found in the environment is replaced by what it stands for
-- (possibly nothing at all).  A variable that is not in the environment is
-- dropped when its type is void ('isVoidTy') and kept otherwise.
-- Non-variable arguments are returned as they are.
unariseConArg :: UnariseEnv -> InStgArg -> [OutStgArg]
unariseConArg rho arg@(StgVarArg x)
  = maybe unmapped expand (lookupVarEnv rho x)
  where
    expand (UnaryVal a)    = [a]
    expand (MultiVal vals) = vals

    unmapped
      | isVoidTy (idType x) = []
      | otherwise           = [arg]
unariseConArg _ arg
  = [arg]
-- | Flatten a list of constructor arguments with 'unariseConArg' and
-- concatenate the results.
unariseConArgs :: UnariseEnv -> [InStgArg] -> [OutStgArg]
unariseConArgs rho args = concatMap (unariseConArg rho) args
-- | Expand a list of constructor field binders left to right with
-- 'unariseConArgBinder', threading the environment through, and concatenate
-- the resulting binder lists.
unariseConArgBinders :: UnariseEnv -> [Id] -> UniqSM (UnariseEnv, [Id])
unariseConArgBinders rho xs = do
  (rho', expanded) <- mapAccumLM unariseConArgBinder rho xs
  return (rho', concat expanded)
-- | Expand one constructor field binder.  A binder whose type has exactly one
-- primitive representation is kept and the environment is unchanged.  For
-- any other count (including zero), one fresh binder is created per
-- representation and the original binder is mapped to the list of them.
unariseConArgBinder :: UnariseEnv -> Id -> UniqSM (UnariseEnv, [Id])
unariseConArgBinder rho x
  | [_] <- reps
  = return (rho, [x])
  | otherwise
  = do fresh <- mkIds (mkFastString "us") (map primRepToType reps)
       return (extendRho rho x (MultiVal (map StgVarArg fresh)), fresh)
  where
    reps = typePrimRep (idType x)
-- | Rewrite a free-variable list: each variable is expanded with
-- 'unariseFreeVar', and only the results that are variables are kept
-- (literal arguments are dropped).
unariseFreeVars :: UnariseEnv -> [InId] -> [OutId]
unariseFreeVars rho fvs = concatMap expandedVars fvs
  where
    expandedVars fv = mapMaybe argVar (unariseFreeVar rho fv)

    argVar (StgVarArg v) = Just v
    argVar _             = Nothing
-- | The arguments a free variable stands for: its mapping in the environment
-- if it has one, otherwise the variable itself.
unariseFreeVar :: UnariseEnv -> Id -> [StgArg]
unariseFreeVar rho x = maybe [StgVarArg x] expand (lookupVarEnv rho x)
  where
    expand (MultiVal args) = args
    expand (UnaryVal arg)  = [arg]
-- | Create one fresh identifier per type, in order, all with the same name
-- string.
mkIds :: FastString -> [UnaryType] -> UniqSM [Id]
mkIds fs = mapM (mkId fs)
-- | Create a fresh identifier with the given name string and type, by
-- delegating to 'mkSysLocalOrCoVarM'.
mkId :: FastString -> UnaryType -> UniqSM Id
mkId fs ty = mkSysLocalOrCoVarM fs ty
-- | 'False' exactly when the binder's type has a single primitive
-- representation ('typePrimRep'); 'True' for any other count, including none.
isMultiValBndr :: Id -> Bool
isMultiValBndr id = case typePrimRep (idType id) of
    [_] -> False
    _   -> True
-- | Does the binder's type satisfy 'isUnboxedSumType'?
isUnboxedSumBndr :: Id -> Bool
isUnboxedSumBndr bndr = isUnboxedSumType (idType bndr)
-- | Does the binder's type satisfy 'isUnboxedTupleType'?
isUnboxedTupleBndr :: Id -> Bool
isUnboxedTupleBndr bndr = isUnboxedTupleType (idType bndr)
-- | Apply the unboxed tuple constructor of matching arity to the given
-- arguments, with the arguments' types ('stgArgType') as the type arguments.
mkTuple :: [StgArg] -> StgExpr
mkTuple args = StgConApp tuple_con args arg_tys
  where
    tuple_con = tupleDataCon Unboxed (length args)
    arg_tys   = map stgArgType args
-- | Alternative type used for the cases on a sum's tag that this module
-- builds: a primitive alternative with 'IntRep' representation.
tagAltTy :: AltType
tagAltTy = PrimAlt IntRep
-- | Type given to the fresh tag binder created in 'elimCase': 'intPrimTy'.
tagTy :: Type
tagTy = intPrimTy
-- | The argument that 'unariseFunArg' substitutes for a variable mapped to an
-- empty 'MultiVal': the variable 'voidPrimId'.
voidArg :: StgArg
voidArg = StgVarArg voidPrimId
-- | Make sure a list of alternatives starts with a @DEFAULT@ alternative.
--
-- * If the first alternative is already @DEFAULT@, the list is returned
--   unchanged.
--
-- * If the first alternative is a literal alternative without binders, it is
--   turned into @DEFAULT@ (keeping its right-hand side); the remaining
--   alternatives are kept.
--
-- * An empty list, or any other first alternative, is a panic.
mkDefaultLitAlt :: [StgAlt] -> [StgAlt]
mkDefaultLitAlt alts = case alts of
    []                         -> pprPanic "elimUbxSumExpr.mkDefaultAlt" (text "Empty alts")
    (DEFAULT, _, _) : _        -> alts
    (LitAlt{}, [], rhs) : rest -> (DEFAULT, [], rhs) : rest
    _                          -> pprPanic "mkDefaultLitAlt" (text "Not a lit alt:" <+> ppr alts)