module GHC.Stg.Unarise (unarise) where
#include "HsVersions.h"
import GHC.Prelude
import GHC.Types.Basic
import GHC.Core
import GHC.Core.DataCon
import GHC.Data.FastString (FastString, mkFastString)
import GHC.Types.Id
import GHC.Types.Literal
import GHC.Core.Make (aBSENT_SUM_FIELD_ERROR_ID)
import GHC.Types.Id.Make (voidPrimId, voidArgId)
import GHC.Utils.Monad (mapAccumLM)
import GHC.Utils.Outputable
import GHC.Types.RepType
import GHC.Stg.Syntax
import GHC.Core.Type
import GHC.Builtin.Types.Prim (intPrimTy)
import GHC.Builtin.Types
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Types.Var.Env
import Data.Bifunctor (second)
import Data.Maybe (mapMaybe)
import qualified Data.IntMap as IM
-- | Maps each variable rewritten by unarisation to what replaces it.
-- Variables absent from the map are left unchanged by the lookups below.
type UnariseEnv = VarEnv UnariseVal

-- | What a variable in an 'UnariseEnv' stands for after unarisation.
data UnariseVal
  = MultiVal [OutStgArg] -- zero or more arguments; 'extendRho' asserts each has a non-void unary type
  | UnaryVal OutStgArg   -- exactly one replacement argument
-- | Debug rendering: the constructor name followed by its payload.
instance Outputable UnariseVal where
  ppr val = case val of
    MultiVal args -> text "MultiVal" <+> ppr args
    UnaryVal arg  -> text "UnaryVal" <+> ppr arg
-- | Record what variable @x@ unarises to. In debug builds this asserts that
-- every replacement argument has a non-void unary type.
extendRho :: UnariseEnv -> Id -> UnariseVal -> UnariseEnv
extendRho rho x val
  = ASSERT(all_nv_unary val)
    extendVarEnv rho x val
  where
    all_nv_unary (MultiVal args) = all (isNvUnaryType . stgArgType) args
    all_nv_unary (UnaryVal arg)  = isNvUnaryType (stgArgType arg)
-- | Unarise every top-level binding. Each binding starts from an empty
-- environment, and fresh names are drawn from the given supply.
unarise :: UniqSupply -> [StgTopBinding] -> [StgTopBinding]
unarise us = initUs_ us . traverse (unariseTopBinding emptyVarEnv)
-- | Unarise a top-level binding. String literals pass through untouched.
unariseTopBinding :: UnariseEnv -> StgTopBinding -> UniqSM StgTopBinding
unariseTopBinding rho top_bind =
  case top_bind of
    StgTopLifted bind -> StgTopLifted <$> unariseBinding rho bind
    StgTopStringLit{} -> pure top_bind
-- | Unarise the right-hand side(s) of a binding. The bound variables
-- themselves are kept as they are.
unariseBinding :: UnariseEnv -> StgBinding -> UniqSM StgBinding
unariseBinding rho binding =
  case binding of
    StgNonRec x rhs -> StgNonRec x <$> unariseRhs rho rhs
    StgRec pairs    -> StgRec <$> traverse unarise_pair pairs
  where
    unarise_pair (x, rhs) = do
      rhs' <- unariseRhs rho rhs
      pure (x, rhs')
-- | Unarise a right-hand side.
--
-- For a closure, its parameters are unarised as function arguments and the
-- body is unarised under the extended environment. For a constructor
-- application, the fields are flattened. Such a constructor must not be an
-- unboxed tuple or sum; this is asserted in debug builds.
unariseRhs :: UnariseEnv -> StgRhs -> UniqSM StgRhs
unariseRhs rho rhs =
  case rhs of
    StgRhsClosure ext ccs update_flag params body -> do
      (rho', params') <- unariseFunArgBinders rho params
      body' <- unariseExpr rho' body
      pure (StgRhsClosure ext ccs update_flag params' body')
    StgRhsCon ccs con fields ->
      ASSERT(not (isUnboxedTupleCon con || isUnboxedSumCon con))
      pure (StgRhsCon ccs con (unariseConArgs rho fields))
-- | Unarise an expression under the environment @rho@.
--
-- Variables recorded in @rho@ are replaced by what they unarise to.
-- Unboxed tuple and sum constructor applications are flattened into
-- 'mkTuple'. A @case@ whose scrutinee is already a known tuple or sum is
-- removed via 'elimCase'. Lambdas must not occur here and cause a panic.
unariseExpr :: UnariseEnv -> StgExpr -> UniqSM StgExpr
unariseExpr rho expr =
  case expr of
    StgApp f [] ->
      -- A bare variable occurrence: substitute whatever it unarises to.
      case lookupVarEnv rho f of
        Nothing                       -> pure expr
        Just (MultiVal args)          -> pure (mkTuple args)
        Just (UnaryVal (StgVarArg v)) -> pure (StgApp v [])
        Just (UnaryVal (StgLitArg l)) -> pure (StgLit l)

    StgApp f args ->
      -- The applied function may only be renamed, never be multi-valued.
      let fun = case lookupVarEnv rho f of
                  Nothing                       -> f
                  Just (UnaryVal (StgVarArg v)) -> v
                  other -> pprPanic "unariseExpr - app2"
                             (pprStgExpr panicStgPprOpts expr $$ ppr other)
      in pure (StgApp fun (unariseFunArgs rho args))

    StgLit{} ->
      pure expr

    StgConApp dc args ty_args
      | Just flat <- unariseMulti_maybe rho dc args ty_args
      -> pure (mkTuple flat)
      | otherwise
      -> let flat = unariseConArgs rho args
         in pure (StgConApp dc flat (map stgArgType flat))

    StgOpApp op args ty ->
      pure (StgOpApp op (unariseFunArgs rho args) ty)

    StgLam{} ->
      pprPanic "unariseExpr: found lambda" (pprStgExpr panicStgPprOpts expr)

    StgCase scrut bndr alt_ty alts
      -- Scrutinee is a variable bound to several values: drop the case.
      | StgApp v [] <- scrut
      , Just (MultiVal xs) <- lookupVarEnv rho v
      -> elimCase rho xs bndr alt_ty alts
      -- Scrutinee directly builds an unboxed tuple/sum: drop the case too.
      | StgConApp dc args ty_args <- scrut
      , Just flat <- unariseMulti_maybe rho dc args ty_args
      -> elimCase rho flat bndr alt_ty alts
      | otherwise
      -> (\scrut' alts' -> StgCase scrut' bndr alt_ty alts')
           <$> unariseExpr rho scrut
           <*> unariseAlts rho alt_ty bndr alts

    StgLet ext bind body ->
      StgLet ext <$> unariseBinding rho bind <*> unariseExpr rho body

    StgLetNoEscape ext bind body ->
      StgLetNoEscape ext <$> unariseBinding rho bind <*> unariseExpr rho body

    StgTick tick body ->
      StgTick tick <$> unariseExpr rho body
-- | If @dc@ is an unboxed tuple or unboxed sum constructor, return the
-- unarised components of @dc args@; otherwise return 'Nothing'. For a sum,
-- the components are laid out by 'mkUbxSum', with the tag first.
unariseMulti_maybe :: UnariseEnv -> DataCon -> [InStgArg] -> [Type] -> Maybe [OutStgArg]
unariseMulti_maybe rho dc args ty_args
  | isUnboxedTupleCon dc = Just flat_args
  | isUnboxedSumCon dc   = Just (mkUbxSum dc ty_args sum_field)
  | otherwise            = Nothing
  where
    flat_args = unariseConArgs rho args
    -- A sum constructor application carries exactly one field.
    sum_field = ASSERT(isSingleton args) flat_args
-- | Unarise a @case@ whose scrutinee is statically known to be an unboxed
-- tuple or unboxed sum. The scrutinee is supplied as its already-unarised
-- components @args@, so no tuple or sum is built at runtime.
--
-- * Single alternative: bind the case binder to all components, and bind
--   the pattern binders (tuple fields, or the one sum field) to their
--   components. Then unarise the right-hand side in that environment.
-- * Several alternatives on an unboxed sum: the first component is the tag.
--   Emit a primitive @case@ on the tag whose literal alternatives are
--   built by 'unariseSumAlts' from the remaining components.
-- * Any other shape is unexpected and panics.
elimCase :: UnariseEnv
         -> [OutStgArg]
         -> InId -> AltType -> [InStgAlt] -> UniqSM OutStgExpr
elimCase rho args bndr (MultiValAlt _) [(_, bndrs, rhs)]
  = do let rho1 = extendRho rho bndr (MultiVal args)
           rho2
             | isUnboxedTupleBndr bndr
             = mapTupleIdBinders bndrs args rho1
             | otherwise
             = ASSERT(isUnboxedSumBndr bndr)
               if null bndrs then rho1
                             else mapSumIdBinders bndrs args rho1
       unariseExpr rho2 rhs

elimCase rho args bndr (MultiValAlt _) alts
  | isUnboxedSumBndr bndr
  -- Match the tag explicitly instead of using a lazy refutable pattern, so
  -- a malformed component list gives a descriptive panic rather than an
  -- anonymous irrefutable-pattern failure.
  = case args of
      [] ->
        pprPanic "elimCase - unboxed sum without a tag"
          (ppr bndr $$ pprPanicAlts alts)
      tag_arg : real_args -> do
        -- A fresh binder for the tag case. It is only used as the case
        -- binder, never referenced by the alternatives.
        tag_bndr <- mkId (mkFastString "tag") tagTy
        let rho1   = extendRho rho bndr (MultiVal args)
            scrut' = case tag_arg of
                       StgVarArg v -> StgApp v []
                       StgLitArg l -> StgLit l
        alts' <- unariseSumAlts rho1 real_args alts
        return (StgCase scrut' tag_bndr tagAltTy alts')

elimCase _ args bndr alt_ty alts
  = pprPanic "elimCase - unhandled case"
      (ppr args <+> ppr bndr <+> ppr alt_ty $$ pprPanicAlts alts)
-- | Unarise the alternatives of a @case@ that 'elimCase' could not remove.
--
-- Consider a 'MultiValAlt' whose case binder is an unboxed tuple or sum.
-- The result is a single alternative matching an unboxed tuple of the
-- unarised binders. For a sum with several alternatives, that
-- alternative's right-hand side is a primitive @case@ on the tag component.
-- All other alternatives are unarised one at a time with 'unariseAlt'.
--
-- The equations are order-dependent: when a guard fails, matching falls
-- through to the next equation.
unariseAlts :: UnariseEnv -> AltType -> InId -> [StgAlt] -> UniqSM [StgAlt]
-- Unboxed tuple with only a binder-free DEFAULT: split the case binder itself.
unariseAlts rho (MultiValAlt n) bndr [(DEFAULT, [], e)]
  | isUnboxedTupleBndr bndr
  = do (rho', ys) <- unariseConArgBinder rho bndr
       e' <- unariseExpr rho' e
       return [(DataAlt (tupleDataCon Unboxed n), ys, e')]

-- Unboxed tuple with one constructor alternative: unarise the field binders
-- and make the case binder stand for all of them.
unariseAlts rho (MultiValAlt n) bndr [(DataAlt _, ys, e)]
  | isUnboxedTupleBndr bndr
  = do (rho', ys1) <- unariseConArgBinders rho ys
       MASSERT(ys1 `lengthIs` n)
       let rho'' = extendRho rho' bndr (MultiVal (map StgVarArg ys1))
       e' <- unariseExpr rho'' e
       return [(DataAlt (tupleDataCon Unboxed n), ys1, e')]

-- Any other alternative list for an unboxed tuple binder is unexpected.
unariseAlts _ (MultiValAlt _) bndr alts
  | isUnboxedTupleBndr bndr
  = pprPanic "unariseExpr: strange multi val alts" (pprPanicAlts alts)

-- Unboxed sum with only a DEFAULT: bind the components without
-- inspecting the tag.
unariseAlts rho (MultiValAlt _) bndr [(DEFAULT, _, rhs)]
  | isUnboxedSumBndr bndr
  = do (rho_sum_bndrs, sum_bndrs) <- unariseConArgBinder rho bndr
       rhs' <- unariseExpr rho_sum_bndrs rhs
       return [(DataAlt (tupleDataCon Unboxed (length sum_bndrs)), sum_bndrs, rhs')]

-- Unboxed sum with several alternatives: bind all components (tag first),
-- then branch on the tag inside the single tuple alternative. The tag
-- binder serves as both the inner scrutinee and the inner case binder.
-- NOTE(review): the cons pattern in the bind below is refutable, so it
-- relies on UniqSM's MonadFail instance. Presumably the list is never empty
-- for a sum because of the tag slot; confirm.
unariseAlts rho (MultiValAlt _) bndr alts
  | isUnboxedSumBndr bndr
  = do (rho_sum_bndrs, scrt_bndrs@(tag_bndr : real_bndrs)) <- unariseConArgBinder rho bndr
       alts' <- unariseSumAlts rho_sum_bndrs (map StgVarArg real_bndrs) alts
       let inner_case = StgCase (StgApp tag_bndr []) tag_bndr tagAltTy alts'
       return [ (DataAlt (tupleDataCon Unboxed (length scrt_bndrs)),
                 scrt_bndrs,
                 inner_case) ]

-- Every other alternative type: unarise each alternative independently.
unariseAlts rho _ _ alts
  = mapM (\alt -> unariseAlt rho alt) alts
-- | Unarise one ordinary alternative. Its pattern binders are unarised as
-- constructor fields, and its right-hand side under the extended environment.
unariseAlt :: UnariseEnv -> StgAlt -> UniqSM StgAlt
unariseAlt rho (con, xs, e) = do
  (rho', xs') <- unariseConArgBinders rho xs
  e' <- unariseExpr rho' e
  pure (con, xs', e')
-- | Turn the alternatives of a case on an unboxed sum into literal
-- alternatives on its tag. @args@ are the sum's non-tag components. The
-- resulting list is normalised by 'mkDefaultLitAlt'.
unariseSumAlts :: UnariseEnv
               -> [StgArg]
               -> [StgAlt]
               -> UniqSM [StgAlt]
unariseSumAlts env args alts
  = mkDefaultLitAlt <$> traverse (unariseSumAlt env args) alts
-- | Unarise one alternative of a case on an unboxed sum.
--
-- A DEFAULT stays DEFAULT, with no binders. A constructor alternative
-- becomes a binder-free literal alternative on the constructor's tag, with
-- its field binder mapped onto the relevant components of @args@. Anything
-- else panics.
unariseSumAlt :: UnariseEnv
              -> [StgArg]
              -> StgAlt
              -> UniqSM StgAlt
unariseSumAlt rho args alt =
  case alt of
    (DEFAULT, _, rhs) ->
      (\rhs' -> (DEFAULT, [], rhs')) <$> unariseExpr rho rhs
    (DataAlt sum_con, bndrs, rhs) -> do
      rhs' <- unariseExpr (mapSumIdBinders bndrs args rho) rhs
      let tag = LitNumber LitNumInt (fromIntegral (dataConTag sum_con))
      pure (LitAlt tag, [], rhs')
    _ ->
      pprPanic "unariseSumAlt" (ppr args $$ pprPanicAlt alt)
-- | Bind each field binder of an unboxed tuple pattern to its consecutive
-- slice of the flattened components @args0@, in order.
--
-- A binder whose type has exactly one 'PrimRep' gets a 'UnaryVal'. Any
-- other binder gets a 'MultiVal' holding as many components as it has
-- reps, possibly none.
mapTupleIdBinders
  :: [InId]
  -> [OutStgArg]
  -> UnariseEnv
  -> UnariseEnv
mapTupleIdBinders ids args0 rho0
  = ASSERT(not (any (isVoidTy . stgArgType) args0))
    go rho0 ids args0
  where
    go :: UnariseEnv -> [InId] -> [OutStgArg] -> UnariseEnv
    go rho [] _ = rho
    go rho (x : rest) args =
      let arity = length (typePrimRep (idType x))
          (mine, remaining) =
            ASSERT(args `lengthAtLeast` arity)
            splitAt arity args
          rho'
            | arity == 1
            = ASSERT(mine `lengthIs` 1)
              extendRho rho x (UnaryVal (head mine))
            | otherwise
            = extendRho rho x (MultiVal mine)
      in go rho' rest remaining
-- | Bind the single field binder of an unboxed-sum alternative to the
-- components of @args@ it occupies, as positioned by 'layoutUbxSum'.
-- Panics unless exactly one binder is given.
mapSumIdBinders
  :: [InId]
  -> [OutStgArg]
  -> UnariseEnv
  -> UnariseEnv
mapSumIdBinders [id] args rho0
  = ASSERT(not (any (isVoidTy . stgArgType) args))
    if isMultiValBndr id
      then extendRho rho0 id (MultiVal (map (args !!) positions))
      else ASSERT(positions `lengthIs` 1)
           extendRho rho0 id (UnaryVal (args !! head positions))
  where
    to_slots  = map primRepSlot
    arg_slots = to_slots (concatMap (typePrimRep . stgArgType) args)
    id_slots  = to_slots (typePrimRep (idType id))
    positions = layoutUbxSum arg_slots id_slots
mapSumIdBinders ids sum_args _
  = pprPanic "mapSumIdBinders" (ppr ids $$ ppr sum_args)
-- | Lay out an application of sum constructor @dc@ as a flat argument list.
--
-- The tag literal (@dataConTag dc@) comes first. Then there is one argument
-- per remaining slot of the sum's representation, computed from @ty_args@.
-- Slots not filled by this constructor's fields get 'ubxSumRubbishArg'.
mkUbxSum
  :: DataCon
  -> [Type]
  -> [OutStgArg]
  -> [OutStgArg]
mkUbxSum dc ty_args args0
  = tag_arg : zipWith slot_arg [0 ..] sum_slots
  where
    -- The first slot from ubxSumRepType is dropped; tag_arg stands in for it.
    (_ : sum_slots) = ubxSumRepType (map typePrimRep ty_args)
    tag_arg = StgLitArg (LitNumber LitNumInt (fromIntegral (dataConTag dc)))
    field_positions =
      layoutUbxSum sum_slots (mapMaybe (typeSlotTy . stgArgType) args0)
    args_by_position =
      IM.fromList (zipEqual "mkUbxSum" field_positions args0)
    slot_arg :: Int -> SlotTy -> StgArg
    slot_arg i slot = IM.findWithDefault (ubxSumRubbishArg slot) i args_by_position
-- | Filler argument for a sum slot the current constructor does not use.
-- Pointer slots get 'aBSENT_SUM_FIELD_ERROR_ID'; numeric slots get zero.
ubxSumRubbishArg :: SlotTy -> StgArg
ubxSumRubbishArg slot = case slot of
  PtrSlot    -> StgVarArg aBSENT_SUM_FIELD_ERROR_ID
  WordSlot   -> StgLitArg (LitNumber LitNumWord 0)
  Word64Slot -> StgLitArg (LitNumber LitNumWord64 0)
  FloatSlot  -> StgLitArg (LitFloat 0)
  DoubleSlot -> StgLitArg (LitDouble 0)
-- | Unarise a binder, returning the extended environment and the binders
-- that replace it.
--
-- * Void type: maps to no values. A constructor field (@is_con_arg@)
--   vanishes; a function argument is kept as a single 'voidArgId'.
-- * One rep, but an unboxed tuple/sum type: replaced by one fresh binder.
-- * One rep otherwise: kept as is, with the environment unchanged.
-- * Several reps: replaced by one fresh binder per rep.
unariseArgBinder
  :: Bool
  -> UnariseEnv -> Id -> UniqSM (UnariseEnv, [Id])
unariseArgBinder is_con_arg rho x
  = case typePrimRep x_ty of
      [] ->
        return (extendRho rho x (MultiVal []), [voidArgId | not is_con_arg])
      [rep]
        | isUnboxedSumType x_ty || isUnboxedTupleType x_ty
        -> do
             x' <- mkId (mkFastString "us") (primRepToType rep)
             return (extendRho rho x (MultiVal [StgVarArg x']), [x'])
        | otherwise
        -> return (rho, [x])
      reps -> do
        xs <- mkIds (mkFastString "us") (map primRepToType reps)
        return (extendRho rho x (MultiVal (map StgVarArg xs)), xs)
  where
    x_ty = idType x
-- | Unarise a function argument. A variable that unarises to no values is
-- replaced by 'voidArg' rather than dropped. Literals pass through unchanged.
unariseFunArg :: UnariseEnv -> StgArg -> [StgArg]
unariseFunArg rho arg = case arg of
    StgVarArg x -> case lookupVarEnv rho x of
      Nothing               -> [arg]
      Just (UnaryVal a)     -> [a]
      Just (MultiVal [])    -> [voidArg]
      Just (MultiVal parts) -> parts
    _ -> [arg]
-- | Unarise a list of function arguments and concatenate the results.
unariseFunArgs :: UnariseEnv -> [StgArg] -> [StgArg]
unariseFunArgs rho args = concatMap (unariseFunArg rho) args
-- | Unarise function-argument binders left to right, threading the
-- environment through and concatenating the replacement binders.
unariseFunArgBinders :: UnariseEnv -> [Id] -> UniqSM (UnariseEnv, [Id])
unariseFunArgBinders rho xs = do
  (rho', xss) <- mapAccumLM unariseFunArgBinder rho xs
  return (rho', concat xss)
-- | Unarise one function-argument binder. Void binders become 'voidArgId'.
unariseFunArgBinder :: UnariseEnv -> Id -> UniqSM (UnariseEnv, [Id])
unariseFunArgBinder rho x = unariseArgBinder False rho x
-- | Unarise a constructor argument.
--
-- A variable found in the environment is replaced by its value(s), possibly
-- none. An unmapped variable of void type is dropped. Anything else is kept.
-- Literals are asserted (in debug builds) never to be void.
unariseConArg :: UnariseEnv -> InStgArg -> [OutStgArg]
unariseConArg rho arg = case arg of
    StgVarArg x
      | Just val <- lookupVarEnv rho x
      -> case val of
           UnaryVal a     -> [a]
           MultiVal parts -> parts
      | isVoidTy (idType x)
      -> []
      | otherwise
      -> [arg]
    StgLitArg lit
      -> ASSERT(not (isVoidTy (literalType lit)))
         [arg]
-- | Unarise a list of constructor arguments and concatenate the results.
unariseConArgs :: UnariseEnv -> [InStgArg] -> [OutStgArg]
unariseConArgs rho args = concatMap (unariseConArg rho) args
-- | Unarise constructor-field binders left to right, threading the
-- environment through and concatenating the replacement binders.
unariseConArgBinders :: UnariseEnv -> [Id] -> UniqSM (UnariseEnv, [Id])
unariseConArgBinders rho xs =
  fmap (\(rho', xss) -> (rho', concat xss)) (mapAccumLM unariseConArgBinder rho xs)
-- | Unarise one constructor-field binder. Void binders vanish entirely.
unariseConArgBinder :: UnariseEnv -> Id -> UniqSM (UnariseEnv, [Id])
unariseConArgBinder rho x = unariseArgBinder True rho x
-- | Make one fresh local binder per given type, all sharing the same name.
mkIds :: FastString -> [UnaryType] -> UniqSM [Id]
mkIds fs = traverse (mkId fs)
-- | Make a fresh system-local binder with the given name, type, and
-- multiplicity 'Many'.
mkId :: FastString -> UnaryType -> UniqSM Id
mkId name = mkSysLocalM name Many
-- | True unless the binder's type has exactly one 'PrimRep'.
isMultiValBndr :: Id -> Bool
isMultiValBndr = not . isSingleton . typePrimRep . idType
-- | Does the binder have an unboxed sum type?
isUnboxedSumBndr :: Id -> Bool
isUnboxedSumBndr bndr = isUnboxedSumType (idType bndr)

-- | Does the binder have an unboxed tuple type?
isUnboxedTupleBndr :: Id -> Bool
isUnboxedTupleBndr bndr = isUnboxedTupleType (idType bndr)
-- | Build an unboxed tuple of the given arguments. The type arguments are
-- taken from the arguments' own types.
mkTuple :: [StgArg] -> StgExpr
mkTuple args = StgConApp con args tys
  where
    con = tupleDataCon Unboxed (length args)
    tys = map stgArgType args
-- | Alternative type of the primitive @case@ generated on an unboxed sum's tag.
tagAltTy :: AltType
tagAltTy = PrimAlt IntRep

-- | Type of the fresh tag binder created in 'elimCase'.
tagTy :: Type
tagTy = intPrimTy

-- | Argument used by 'unariseFunArg' when a function argument unarises to
-- no values, so that the argument is not dropped.
voidArg :: StgArg
voidArg = StgVarArg voidPrimId
-- | Normalise the literal alternatives of a generated tag @case@.
--
-- If the first alternative is already DEFAULT, the list is returned
-- unchanged. Otherwise the first alternative must be a binder-free
-- 'LitAlt'; it becomes DEFAULT, so no final tag comparison is needed.
--
-- NOTE(review): turning a literal alternative into DEFAULT is only sound
-- when the literal alternatives are exhaustive. This function assumes that
-- rather than checking it.
--
-- Panics on an empty list, or when the first alternative is neither DEFAULT
-- nor a binder-free literal alternative.
mkDefaultLitAlt :: [StgAlt] -> [StgAlt]
-- Panic label previously read "elimUbxSumExpr.mkDefaultAlt", a stale name
-- that did not match this function or its other panic.
mkDefaultLitAlt [] = pprPanic "mkDefaultLitAlt" (text "Empty alts")
mkDefaultLitAlt alts@((DEFAULT, _, _) : _) = alts
mkDefaultLitAlt ((LitAlt{}, [], rhs) : alts) = (DEFAULT, [], rhs) : alts
mkDefaultLitAlt alts = pprPanic "mkDefaultLitAlt" (text "Not a lit alt:" <+> pprPanicAlts alts)
-- | Render a list of alternatives for use in panic messages.
pprPanicAlts :: (Outputable a, Outputable b, OutputablePass pass) => [(a,b,GenStgExpr pass)] -> SDoc
pprPanicAlts = ppr . map pprPanicAlt

-- | Render one alternative as a triple, pretty-printing its right-hand side.
pprPanicAlt :: (Outputable a, Outputable b, OutputablePass pass) => (a,b,GenStgExpr pass) -> SDoc
pprPanicAlt (con, bndrs, rhs) = ppr (con, bndrs, pprStgExpr panicStgPprOpts rhs)