{-# LANGUAGE CPP #-}
module GHC.Core.Opt.FloatOut ( floatOutwards ) where
import GHC.Prelude
import GHC.Core
import GHC.Core.Utils
import GHC.Core.Make
import GHC.Core.Opt.Arity ( exprArity, etaExpand )
import GHC.Core.Opt.Monad ( FloatOutSwitches(..) )
import GHC.Driver.Session
import GHC.Utils.Error ( dumpIfSet_dyn, DumpFormat (..) )
import GHC.Types.Id ( Id, idArity, idType, isDeadEndId,
isJoinId, isJoinId_maybe )
import GHC.Core.Opt.SetLevels
import GHC.Types.Unique.Supply ( UniqSupply )
import GHC.Data.Bag
import GHC.Utils.Misc
import GHC.Data.Maybe
import GHC.Utils.Outputable
import GHC.Core.Type
import qualified Data.IntMap as M
import Data.List ( partition )
#include "HsVersions.h"
-- | Entry point of the float-out pass.
--
-- The program is first annotated with levels by 'setLevels'; every resulting
-- top-level binding is then processed by 'floatTopBind'.  The levelled
-- program is dumped under @Opt_D_verbose_core2core@ and the accumulated
-- statistics under @Opt_D_dump_simpl_stats@.  The result is the
-- concatenation, in order, of the bindings produced for each input binding.
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> CoreProgram -> IO CoreProgram
floatOutwards float_sws dflags us pgm
  = do
      let annotated_w_levels = setLevels float_sws pgm us
          (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)

      dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                    FormatCore
                    (vcat (map ppr annotated_w_levels))

      let (tlets, ntlets, lams) = get_stats (sum_stats fss)

      dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
                    FormatText
                    (hcat [ int tlets,  text " Lets floated to top level; "
                          , int ntlets, text " Lets floated elsewhere; from "
                          , int lams,   text " Lambda groups" ])

      return (bagToList (unionManyBags binds_s'))
-- | Float one top-level binding.
--
-- Runs 'floatBind' and flattens the resulting floats with
-- 'flattenTopFloats'.  For a recursive group the floats are merged into the
-- group via 'addTopFloatPairs'; for a non-recursive binding the floats are
-- emitted first and the binding itself last.  Any other shape returned by
-- 'floatBind' is reported with a panic.
floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind)
floatTopBind bind
  = case floatBind bind of { (fs, floats, bind') ->
    let float_bag = flattenTopFloats floats
    in case bind' of
         [Rec prs]    -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
         [NonRec b e] -> (fs, float_bag `snocBag` NonRec b e)
         _            -> pprPanic "floatTopBind" (ppr bind') }
-- | Float a single (levelled) binding.
--
-- The returned list of bindings is either
--
--   * a single non-recursive binding (for a 'NonRec' input), or
--
--   * for a 'Rec' input, in this order: zero or more non-recursive bindings
--     (the unlifted, non-join bindings collected by 'splitRecFloats'),
--     followed by one or two recursive groups.  Join points and other
--     bindings are never put in the same group; when both exist the
--     non-join group comes first.
floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
floatBind (NonRec (TB var _) rhs)
  = case floatRhs var rhs of { (fs, rhs_floats, rhs') ->

        -- If the binder is a dead-end Id whose declared arity exceeds the
        -- arity of the floated right-hand side, eta-expand the right-hand
        -- side up to the declared arity.
    let rhs'' | isDeadEndId var
              , exprArity rhs' < idArity var = etaExpand (idArity var) rhs'
              | otherwise                    = rhs'

    in (fs, rhs_floats, [NonRec var rhs'']) }

floatBind (Rec pairs)
  = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
    let (new_ul_pairss, new_other_pairss) = unzip new_pairs
        (new_join_pairs, new_l_pairs)     = partition (isJoinId . fst)
                                                      (concat new_other_pairss)
        -- Join points and ordinary values go into separate Rec groups
        new_rec_binds | null new_join_pairs = [ Rec new_l_pairs    ]
                      | null new_l_pairs    = [ Rec new_join_pairs ]
                      | otherwise           = [ Rec new_l_pairs
                                              , Rec new_join_pairs ]
        new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
    in
    (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds) }
  where
    -- Process one pair of the recursive group.  The result carries the
    -- floats that keep travelling outwards plus two lists of pairs: those to
    -- be emitted as non-recursive bindings, and those that stay in a Rec.
    do_pair :: (LevelledBndr, LevelledExpr)
            -> (FloatStats, FloatBinds,
                ([(Id,CoreExpr)],  -- Emitted as NonRec bindings
                 [(Id,CoreExpr)])) -- Emitted inside a Rec group
    do_pair (TB name spec, rhs)
      | isTopLvl dest_lvl
        -- Top-level destination: every float of the right-hand side is
        -- merged into the group itself, so nothing floats further.
      = case floatRhs name rhs of { (fs, rhs_floats, rhs') ->
        (fs, emptyFloats, ([], addTopFloatPairs (flattenTopFloats rhs_floats)
                                                [(name, rhs')])) }
      | otherwise
        -- Floats destined for this level are split: let-bindings join the
        -- group, the remaining floats are re-installed under the lambdas of
        -- the right-hand side.
      = case floatRhs name rhs                 of { (fs, rhs_floats, rhs') ->
        case partitionByLevel dest_lvl rhs_floats of { (rhs_floats', heres) ->
        case splitRecFloats heres              of { (ul_pairs, pairs, case_heres) ->
        let pairs' = (name, installUnderLambdas case_heres rhs') : pairs in
        (fs, rhs_floats', (ul_pairs, pairs')) }}}
      where
        dest_lvl = floatSpecLevel spec
-- | Split a bag of floats into its leading run of let-bindings and the rest.
--
-- Walking the floats in order, every leading 'FloatLet' is consumed:
--
--   * a non-recursive binding of unlifted type that is not a join point goes
--     into the first result list (returned in original order);
--
--   * every other non-recursive binding, and all pairs of a recursive
--     binding, are prepended to the second result list.
--
-- The walk stops at the first float that is not a 'FloatLet' (or at the end
-- of the input); the unconsumed floats are returned as the third component.
splitRecFloats :: Bag FloatBind
               -> ([(Id,CoreExpr)],
                   [(Id,CoreExpr)],
                   Bag FloatBind)
splitRecFloats fs
  = go [] [] (bagToList fs)
  where
    go :: [(Id,CoreExpr)] -> [(Id,CoreExpr)] -> [FloatBind]
       -> ([(Id,CoreExpr)], [(Id,CoreExpr)], Bag FloatBind)
    go ul_prs prs (FloatLet (NonRec b r) : rest)
      | isUnliftedType (idType b)
      , not (isJoinId b)
      = go ((b,r):ul_prs) prs rest
      | otherwise
      = go ul_prs ((b,r):prs) rest
    go ul_prs prs (FloatLet (Rec prs') : rest)
      = go ul_prs (prs' ++ prs) rest
    go ul_prs prs rest
      = (reverse ul_prs, prs, listToBag rest)
-- | Install floats around the body of an expression, beneath all of its
-- leading lambdas.  With an empty bag the expression is returned untouched
-- (no traversal takes place).
installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
installUnderLambdas floats e
  | isEmptyBag floats = e
  | otherwise         = go e
  where
    -- Skip over the leading lambdas, then wrap what is left
    go (Lam b body) = Lam b (go body)
    go body         = install floats body
-- | Apply a floating function to every element of a list, combining the
-- statistics with 'add_stats' and the floats with 'plusFloats' (left to
-- right) and collecting the per-element results in order.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList _ []     = (zeroStats, emptyFloats, [])
floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
                     case floatList f as of { (fs_as, binds_as, bs) ->
                     (fs_a `add_stats` fs_as, binds_a `plusFloats` binds_as, b:bs) }}
-- | Float an expression and then install, around the result, those floats
-- that 'partitionByLevel' selects for the given level.  The remaining floats
-- are returned to the caller.
floatBody :: Level
          -> LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatBody lvl arg
  = case floatExpr arg              of { (fsa, floats, arg') ->
    case partitionByLevel lvl floats of { (floats', heres) ->
        -- 'heres' are the bindings that get bound at this point
    (fsa, floats', install heres arg') }}
-- | Float bindings out of a levelled expression.
--
-- Returns the statistics gathered, the bindings that are floating outwards
-- past this expression, and the expression with level annotations removed
-- and those bindings taken out.
floatExpr :: LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
-- Atoms: nothing to float
floatExpr (Var v)       = (zeroStats, emptyFloats, Var v)
floatExpr (Type ty)     = (zeroStats, emptyFloats, Type ty)
floatExpr (Coercion co) = (zeroStats, emptyFloats, Coercion co)
floatExpr (Lit lit)     = (zeroStats, emptyFloats, Lit lit)

floatExpr (App e a)
  = case (atJoinCeiling $ floatExpr e) of { (fse, floats_e, e') ->
    case (atJoinCeiling $ floatExpr a) of { (fsa, floats_a, a') ->
    (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}

floatExpr lam@(Lam (TB _ lam_spec) _)
  = let (bndrs_w_lvls, body) = collectBinders lam
        bndrs                = [b | TB b _ <- bndrs_w_lvls]
        -- The level of the first binder is used for the whole lambda group,
        -- converted with asJoinCeilLvl
        bndr_lvl             = asJoinCeilLvl (floatSpecLevel lam_spec)
    in
    case floatBody bndr_lvl body of { (fs, floats, body') ->
    (add_to_stats fs floats, floats, mkLams bndrs body') }

floatExpr (Tick tickish expr)
  | tickish `tickishScopesLike` SoftScope
    -- Soft-scoped tick: floats pass it unchanged
  = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Tick tickish expr') }

  | not (tickishCounts tickish) || tickishCanSplit tickish
  = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
    let -- Bindings floating past the tick are wrapped in a non-counting
        -- copy of it
        annotated_defns = wrapTick (mkNoCount tickish) floating_defns
    in
    (fs, annotated_defns, Tick tickish expr') }

  | Breakpoint{} <- tickish
    -- Breakpoints: floats pass unannotated, and atJoinCeiling is not applied
  = case floatExpr expr of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Tick tickish expr') }

  | otherwise
  = pprPanic "floatExpr tick" (ppr tickish)

floatExpr (Cast expr co)
  = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Cast expr' co) }

floatExpr (Let bind body)
  = case bind_spec of
      FloatMe dest_lvl
        -- The binding moves: it joins the floats at its destination level
        -- and only the body remains here
        -> case floatBind bind of { (fsb, bind_floats, binds') ->
           case floatExpr body of { (fse, body_floats, body') ->
           let new_bind_floats = foldr plusFloats emptyFloats
                                   (map (unitLetFloat dest_lvl) binds') in
           ( add_stats fsb fse
           , bind_floats `plusFloats` new_bind_floats
                         `plusFloats` body_floats
           , body') }}

      StayPut bind_lvl
        -- The binding stays: rebuild the let around the floated body
        -> case floatBind bind          of { (fsb, bind_floats, binds') ->
           case floatBody bind_lvl body of { (fse, body_floats, body') ->
           ( add_stats fsb fse
           , bind_floats `plusFloats` body_floats
           , foldr Let body' binds' ) }}
  where
    -- The float spec of the (first) binder decides for the whole binding
    bind_spec = case bind of
                  NonRec (TB _ s) _     -> s
                  Rec ((TB _ s, _) : _) -> s
                  Rec []                -> panic "floatExpr:rec"

floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
  = case case_spec of
      FloatMe dest_lvl
        -- The case itself moves; only a single data-constructor
        -- alternative is supported
        | [(con@(DataAlt {}), bndrs, rhs)] <- alts
        -> case atJoinCeiling $ floatExpr scrut of { (fse, fde, scrut') ->
           case                 floatExpr rhs   of { (fsb, fdb, rhs') ->
           let
             float = unitCaseFloat dest_lvl scrut'
                          case_bndr con [b | TB b _ <- bndrs]
           in
           (add_stats fse fsb, fde `plusFloats` float `plusFloats` fdb, rhs') }}
        | otherwise
        -> pprPanic "Floating multi-case" (ppr alts)

      StayPut bind_lvl
        -- The case stays put; each alternative is floated as a body
        -> case atJoinCeiling $ floatExpr scrut     of { (fse, fde, scrut') ->
           case floatList (float_alt bind_lvl) alts of { (fsa, fda, alts')  ->
           (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
           }}
  where
    float_alt bind_lvl (con, bs, rhs)
        = case floatBody bind_lvl rhs of { (fs, rhs_floats, rhs') ->
          (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
-- | Float the right-hand side of a binding.
--
-- When the binder is a join point and the right-hand side starts with at
-- least as many lambdas as its join arity, exactly that many lambdas are
-- peeled off, the body is floated with 'floatBody' at the level of the first
-- peeled binder, and the lambdas are put back.  (With join arity zero the
-- right-hand side goes straight to 'floatExpr'.)  In every other case the
-- result of 'floatExpr' is passed through 'atJoinCeiling'.
floatRhs :: CoreBndr
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)
floatRhs bndr rhs
  | Just join_arity <- isJoinId_maybe bndr
  , Just (bndrs, body) <- try_collect join_arity rhs []
  = case bndrs of
      []                  -> floatExpr rhs
      (TB _ lam_spec) : _ ->
        let lvl = floatSpecLevel lam_spec in
        case floatBody lvl body of { (fs, floats, body') ->
        (fs, floats, mkLams [b | TB b _ <- bndrs] body') }
  | otherwise
  = atJoinCeiling $ floatExpr rhs
  where
    -- Peel exactly n leading lambdas; Nothing if there are fewer than n
    try_collect :: Int -> Expr b -> [b] -> Maybe ([b], Expr b)
    try_collect 0 expr      acc = Just (reverse acc, expr)
    try_collect n (Lam b e) acc = try_collect (n-1) e (b:acc)
    try_collect _ _         _   = Nothing
-- | Counters gathered by the pass and reported in the "FloatOut stats:" dump.
data FloatStats
  = FlS Int  -- Incremented by the number of top-level floats at each lambda group
        Int  -- Incremented by the number of other floats at each lambda group
        Int  -- Number of lambda groups seen
-- | Unpack the three counters of a 'FloatStats', in constructor order.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS a b c) = (a, b, c)
-- | Statistics with every counter at zero; the unit of 'add_stats'.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0
-- | Add up a list of statistics; the empty list yields 'zeroStats'.
sum_stats :: [FloatStats] -> FloatStats
sum_stats xs = foldr add_stats zeroStats xs
-- | Component-wise sum of two sets of statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
  = FlS (a1 + a2) (b1 + b2) (c1 + c2)
-- | Account for a set of floats in the statistics: the first counter grows
-- by the number of top-level floats, the second by the number of all other
-- floats (ceiling floats plus those in the level map), and the third by one.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS a b c) (FB tops ceils others)
  = FlS (a + lengthBag tops)
        (b + lengthBag ceils + lengthBag (flattenMajor others))
        (c + 1)
-- | A let-binding that has been floated; just a Core binding group.
type FloatLet = CoreBind
-- | Floats indexed by the major component of their destination level
-- (see 'unitLetFloat' / 'unitCaseFloat', which build
-- @M.singleton major (M.singleton minor floats)@).
type MajorEnv = M.IntMap MinorEnv
-- | Floats indexed by the minor component of their destination level.
type MinorEnv = M.IntMap (Bag FloatBind)
-- | The collection of bindings currently being floated outwards.
-- All three fields are strict.
data FloatBinds = FB !(Bag FloatLet)   -- floats destined for top level
                     !(Bag FloatBind)  -- floats destined for the nearest join ceiling
                     !MajorEnv         -- every other float, indexed by major then minor level
-- | Debug rendering of a 'FloatBinds': the three components are printed on
-- separate lines inside braces, labelled @tops@, @ceils@ and @non-tops@.
instance Outputable FloatBinds where
  ppr (FB fbs ceils defs)
    = text "FB" <+> (braces $ vcat
        [ text "tops ="     <+> ppr fbs
        , text "ceils ="    <+> ppr ceils
        , text "non-tops =" <+> ppr defs ])
-- | Extract the top-level floats from a 'FloatBinds'.
--
-- By the time this is called nothing should remain in the join-ceiling bag
-- or in the level-indexed map; both conditions are checked with the
-- @ASSERT2@ macro from @HsVersions.h@ (presumably only active in debug
-- builds -- confirm against the header).
flattenTopFloats :: FloatBinds -> Bag CoreBind
flattenTopFloats (FB tops ceils defs)
  = ASSERT2( isEmptyBag (flattenMajor defs), ppr defs )
    ASSERT2( isEmptyBag ceils, ppr ceils )
    tops
-- | Prepend every binder\/right-hand-side pair found in a bag of floated
-- bindings to an existing list of pairs.
--
-- A 'NonRec' binding contributes its single pair; a 'Rec' group contributes
-- all of its pairs.  The bag is traversed with a right fold, so pairs from
-- earlier bag elements end up nearer the front of the result.
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
addTopFloatPairs float_bag prs
  = foldr add prs float_bag
  where
    add :: Bind b -> [(b, Expr b)] -> [(b, Expr b)]
    add (NonRec b r) acc = (b, r) : acc
    add (Rec group)  acc = group ++ acc
-- | Collapse the whole two-level float map into a single bag, flattening
-- each minor map with 'flattenMinor' and taking the union of the results.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor = M.foldr (unionBags . flattenMinor) emptyBag
-- | Union together all the bags stored in one minor-level map.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor = M.foldr unionBags emptyBag
-- | A 'FloatBinds' containing nothing: no top-level floats, no join-ceiling
-- floats and an empty level map.
emptyFloats :: FloatBinds
emptyFloats = FB emptyBag emptyBag M.empty
-- | Build a 'FloatBinds' holding a single floated case
-- (@FloatCase e b con bs@) destined for the given level.
--
-- If the level's type is 'JoinCeilLvl' the float goes in the join-ceiling
-- bag; otherwise it is filed in the level map under the level's major and
-- minor indices.  Unlike 'unitLetFloat' there is no top-level case here.
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
unitCaseFloat (Level major minor t) e b con bs
  | t == JoinCeilLvl
  = FB emptyBag floats M.empty
  | otherwise
  = FB emptyBag emptyBag (M.singleton major (M.singleton minor floats))
  where
    floats :: Bag FloatBind
    floats = unitBag (FloatCase e b con bs)
-- | Build a 'FloatBinds' holding a single floated let-binding destined for
-- the given level.
--
-- The guards are tried in order:
--
-- * a top level ('isTopLvl') puts the raw binding in the top-level bag;
-- * a 'JoinCeilLvl' level puts it, wrapped in 'FloatLet', in the
--   join-ceiling bag;
-- * anything else files it in the level map under the level's major and
--   minor indices.
unitLetFloat :: Level -> FloatLet -> FloatBinds
unitLetFloat lvl@(Level major minor t) b
  | isTopLvl lvl     = FB (unitBag b) emptyBag M.empty
  | t == JoinCeilLvl = FB emptyBag floats M.empty
  | otherwise        = FB emptyBag emptyBag
                          (M.singleton major (M.singleton minor floats))
  where
    floats :: Bag FloatBind
    floats = unitBag (FloatLet b)
-- | Merge two 'FloatBinds' component-wise: the top-level bags and the
-- join-ceiling bags are unioned, and the level maps are merged with
-- 'plusMajor'.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB t1 c1 l1) (FB t2 c2 l2)
  = FB (t1 `unionBags` t2) (c1 `unionBags` c2) (l1 `plusMajor` l2)
-- | Merge two major-level maps; where both have an entry for the same major
-- index the minor maps are merged with 'plusMinor'.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor = M.unionWith plusMinor
-- | Merge two minor-level maps; where both have an entry for the same minor
-- index the bags are unioned.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor = M.unionWith unionBags
-- | Wrap an expression in every float from the bag, applying 'wrapFloat'
-- once per float via a right fold (so the first float in the bag ends up
-- outermost).  An empty bag returns the expression unchanged.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldr wrapFloat expr defn_groups
-- | Split a 'FloatBinds' at a given level into the floats that must keep
-- travelling outwards (first component) and the floats that should be
-- installed here (second component).
--
-- Using 'M.splitLookup' on the major index and then on the minor index:
--
-- * entries with a smaller major index, or the same major index and a
--   smaller minor index, stay in the outgoing 'FloatBinds';
-- * entries at exactly this major\/minor level, at the same major index
--   with a larger minor index, or at any larger major index are returned
--   to be installed here.
--
-- Top-level floats always stay in the outgoing 'FloatBinds'.  Join-ceiling
-- floats are installed here only when the level's type is 'JoinCeilLvl';
-- otherwise they are passed on untouched.
partitionByLevel
        :: Level                -- ^ the level to partition at
        -> FloatBinds           -- ^ everything floated so far
        -> (FloatBinds,         -- floats that continue outwards
            Bag FloatBind)      -- floats to install at this level
partitionByLevel (Level major minor typ) (FB tops ceils defns)
  = (FB tops ceils' (outer_maj `plusMajor` M.singleton major outer_min),
     here_min `unionBags` here_ceil
              `unionBags` flattenMinor inner_min
              `unionBags` flattenMajor inner_maj)
  where
    -- Split on the major index: strictly smaller / equal / strictly larger.
    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns

    -- Within the matching major entry (if any), split on the minor index.
    (outer_min, mb_here_min, inner_min)
      = case mb_here_maj of
          Nothing        -> (M.empty, Nothing, M.empty)
          Just min_defns -> M.splitLookup minor min_defns

    here_min = mb_here_min `orElse` emptyBag

    -- Join-ceiling floats stop here only at a join ceiling.
    (here_ceil, ceils')
      | typ == JoinCeilLvl = (ceils, emptyBag)
      | otherwise          = (emptyBag, ceils)
-- | Separate out the join-ceiling floats: returns the input with its
-- join-ceiling bag emptied, paired with the bag that was removed.  The
-- top-level floats and the level map are left as they were.
partitionAtJoinCeiling :: FloatBinds -> (FloatBinds, Bag FloatBind)
partitionAtJoinCeiling (FB tops ceils defs)
  = (FB tops emptyBag defs, ceils)
-- | Discharge the join-ceiling floats at this point: they are removed from
-- the accumulated floats ('partitionAtJoinCeiling') and wrapped around the
-- expression with 'install'.  The statistics are passed through unchanged.
atJoinCeiling :: (FloatStats, FloatBinds, CoreExpr)
              -> (FloatStats, FloatBinds, CoreExpr)
atJoinCeiling (fs, floats, expr')
  = (fs, floats', install ceils expr')
  where
    (floats', ceils) = partitionAtJoinCeiling floats
-- | Apply a tick to every expression carried by a 'FloatBinds': the
-- right-hand sides of all floated let-bindings (top-level, join-ceiling and
-- level-indexed alike) and the scrutinee of every floated case.  Binders,
-- alternative constructors and case-bound variables are left untouched.
wrapTick :: Tickish Id -> FloatBinds -> FloatBinds
wrapTick t (FB tops ceils defns)
  = FB (mapBag wrap_bind tops) (wrap_defns ceils)
       (M.map (M.map wrap_defns) defns)
  where
    wrap_defns :: Bag FloatBind -> Bag FloatBind
    wrap_defns = mapBag wrap_one

    wrap_bind :: CoreBind -> CoreBind
    wrap_bind (NonRec binder rhs) = NonRec binder (maybe_tick rhs)
    wrap_bind (Rec pairs)         = Rec (mapSnd maybe_tick pairs)

    wrap_one :: FloatBind -> FloatBind
    wrap_one (FloatLet bind)      = FloatLet (wrap_bind bind)
    wrap_one (FloatCase e b c bs) = FloatCase (maybe_tick e) b c bs

    -- An expression already in head-normal form is not wrapped in the tick
    -- itself; 'tickHNFArgs' is used on it instead.  Anything else gets the
    -- tick via 'mkTick'.
    maybe_tick :: CoreExpr -> CoreExpr
    maybe_tick e | exprIsHNF e = tickHNFArgs t e
                 | otherwise   = mkTick t e