{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations.ANF
( makeANF
, nonRepANF
) where
import Control.Arrow ((***))
import Control.Lens (_2)
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import Control.Monad.State (StateT, lift, modify, runStateT)
import Control.Monad.Writer (listen)
import Data.Bifunctor (second)
import qualified Data.Monoid as Monoid (Any(..))
import qualified Data.Text.Extra as Text (showt)
import GHC.Stack (HasCallStack)
import Clash.Signal.Internal (Signal(..))
import Clash.Core.DataCon (DataCon(..))
import Clash.Core.HasFreeVars (disjointFreeVars)
import Clash.Core.HasType
import Clash.Core.Name (mkUnsafeSystemName, nameOcc)
import Clash.Core.Subst (deshadowLetExpr, freshenTm)
import Clash.Core.Term
( Alt, CoreContext(..), LetBinding, Pat(..), PrimInfo(..), Term(..)
, collectArgs, collectTicks, mkTicks, partitionTicks, stripTicks)
import Clash.Core.TermInfo (isCon, isLocalVar, isPrim, isVar)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (Type, TypeView(..), coreView, tyView)
import Clash.Core.Util (mkSelectorCase)
import Clash.Core.Var (Id)
import Clash.Core.VarEnv (InScopeSet, extendInScopeSet, extendInScopeSetList, mkVarSet)
import Clash.Netlist.Util (bindsExistentials)
import Clash.Normalize.Transformations.Specialize (specialize)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Combinators (bottomupR)
import Clash.Rewrite.Types
(Transform, TransformContext(..), tcCache)
import Clash.Rewrite.Util
(changed, isUntranslatable, mkDerivedName, mkTmBinderFor)
import Clash.Rewrite.WorkFree (isConstant, isConstantNotClockReset)
import Clash.Util (curLoc)
-- | Turn an expression into a modified A-normal form (ANF).
--
-- As opposed to standard ANF, sub-terms that are already local variables or
-- constants are not let-bound. The transformation descends under term
-- lambdas and leaves type lambdas untouched. All bindings collected from the
-- body (by 'collectANF') are gathered into a single 'Letrec' wrapped around
-- the remaining expression.
makeANF :: HasCallStack => NormRewrite
makeANF (TransformContext is0 ctx) (Lam bndr e) = do
  -- Descend under the lambda: its binder becomes in scope and the context
  -- records that we are inside the lambda body.
  let ctx' = TransformContext (extendInScopeSet is0 bndr) (LamBody bndr : ctx)
  e' <- makeANF ctx' e
  return (Lam bndr e')

makeANF _ e@(TyLam {}) = return e

makeANF ctx@(TransformContext is0 _) e0 = do
  -- Every collected binding is floated into one 'Letrec', so binders that
  -- were only kept apart by scoping could clash once lifted. Freshen all
  -- binders first and start from the resulting InScopeSet, so the binders
  -- created by 'collectANF' are chosen relative to all existing names.
  let (is2, e1) = freshenTm is0 e0
  ((e2, (bndrs, _)), anyChanged) <-
    listen (runStateT (bottomupR collectANF ctx e1) ([], is2))
  let hasChanged = Monoid.getAny anyChanged
  case bndrs of
    -- Nothing was let-bound: keep the rewritten term only when some inner
    -- rewrite reported a change, otherwise hand back the original
    -- (un-freshened) term.
    [] -> return (if hasChanged then e2 else e0)
    _ -> do
      let (e3, ticks) = collectTicks e2
          (srcTicks, nmTicks) = partitionTicks ticks
      -- Name ticks are placed outside the new 'Letrec' so they still scope
      -- over the entire expression; source ticks stay on the body.
      changed (mkTicks (Letrec bndrs (mkTicks e3 srcTicks)) nmTicks)
{-# SCC makeANF #-}
-- | A normalisation rewrite that additionally threads a state consisting of
-- the let-bindings collected so far and the 'InScopeSet' handed to
-- 'mkTmBinderFor' when new binders are created.
type NormRewriteW =
  Transform (StateT ([LetBinding], InScopeSet) NormalizeSession)
-- | Record new let-bindings. The bindings are prepended to the collected
-- list, and their binders are added to the 'InScopeSet' that later calls to
-- 'mkTmBinderFor' receive.
tellBinders :: [LetBinding] -> StateT ([LetBinding], InScopeSet) NormalizeSession ()
tellBinders bs = do
  modify (\(binds, is) -> (bs ++ binds, is))
  -- Extending the InScopeSet is exactly what 'notifyBinders' does; reuse it
  -- rather than duplicating the logic.
  notifyBinders bs
-- | Add the binders of the given bindings to the 'InScopeSet' without
-- recording the bindings themselves.
--
-- This is for binders that must be in scope before their binding is emitted
-- with 'tellBinders'. An example is a let-bound case scrutinee, whose binding
-- is only recorded after the case alternatives have been processed.
notifyBinders :: Monad m => [LetBinding] -> StateT ([LetBinding], InScopeSet) m ()
notifyBinders bs = modify (second extend)
  where
    extend is = extendInScopeSetList is (map fst bs)
-- | Whether a type, first normalised with 'coreView', is a simulation-IO
-- type. That is, it is one of:
--
-- * the @Clash.Explicit.SimIO.SimIO@ type constructor;
-- * an unboxed pair @(# , #)@ whose third type argument satisfies
--   'isStateTokenTy';
-- * a function type whose result is one of the above (checked recursively).
isSimIOTy
  :: TyConMap
  -> Type
  -> Bool
isSimIOTy tcm ty = case tyView (coreView tcm ty) of
  TyConApp tcNm args
    | nameOcc tcNm == "Clash.Explicit.SimIO.SimIO"
    -> True
    -- NOTE(review): presumably the first two arguments of the unboxed pair
    -- are its runtime-rep kind arguments, so the third is the state token
    -- type - confirm.
    | nameOcc tcNm == "GHC.Prim.(#,#)"
    , [_, _, st, _] <- args
    -> isStateTokenTy tcm st
  -- Look through arrows: only the final result type matters.
  FunTy _ res -> isSimIOTy tcm res
  _ -> False
-- | Whether a type, first normalised with 'coreView', is an application of
-- the @GHC.Prim.State#@ type constructor.
isStateTokenTy
  :: TyConMap
  -> Type
  -> Bool
isStateTokenTy tcm ty = case tyView (coreView tcm ty) of
  TyConApp tcNm _ -> nameOcc tcNm == "GHC.Prim.State#"
  _ -> False
-- | Bottom-up worker of 'makeANF'. Each call inspects one node of the term,
-- replaces non-trivial sub-terms by fresh variables, and accumulates the
-- corresponding bindings in the state (see 'tellBinders'). 'makeANF' then
-- wraps the accumulated bindings in a single 'Letrec'.
--
-- * An application headed by a data constructor, primitive or variable gets
--   its last argument let-bound. This is skipped when the argument is a
--   local variable, satisfies 'isConstantNotClockReset', is untranslatable
--   ('isUntranslatable'), or when the head is the @bindSimIO#@ primitive.
--   An untranslatable argument that is a 'Letrec' has its bindings floated.
-- * A 'Letrec' has its bindings floated. Its body is let-bound as well,
--   unless it is a local variable, untranslatable, or of a SimIO type
--   ('isSimIOTy').
-- * A case scrutinee and case alternatives are let-bound unless they are
--   local variables or constants ('isConstant'). Pattern variables of data
--   alternatives without existentials are bound to selector cases.
collectANF :: HasCallStack => NormRewriteW
collectANF ctx e@(App appf arg)
  | (conVarPrim, _) <- collectArgs e
  , isCon conVarPrim || isPrim conVarPrim || isVar conVarPrim
  = do
    tcm <- Lens.view tcCache
    untranslatable <- lift (isUntranslatable False arg)
    let localVar = isLocalVar arg
        constantNoCR = isConstantNotClockReset tcm arg
    case (untranslatable, localVar || constantNoCR, isSimBind conVarPrim, arg) of
      -- Translatable, non-trivial argument: bind it to a fresh variable.
      (False, False, False, _) -> do
        is1 <- Lens.use _2
        argId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "app_arg") arg)
        tellBinders [(argId, arg)]
        return (App appf (Var argId))
      -- Untranslatable let-expression argument: float its bindings out and
      -- apply to its body instead.
      (True, False, _, Letrec binds body) -> do
        tellBinders binds
        return (App appf body)
      _ -> return e
  where
    isSimBind :: Term -> Bool
    isSimBind (Prim p) = primName p == "Clash.Explicit.SimIO.bindSimIO#"
    isSimBind _ = False

collectANF _ (Letrec binds body) = do
  tcm <- Lens.view tcCache
  let isSimIO = isSimIOTy tcm (inferCoreTypeOf tcm body)
  untranslatable <- lift (isUntranslatable False body)
  let localVar = isLocalVar body
  -- The bindings are always floated out. The body stays in place when it is
  -- a local variable, untranslatable, or a SimIO action; otherwise it is
  -- bound to a fresh "result" variable too.
  if localVar || untranslatable || isSimIO
    then do
      tellBinders binds
      return body
    else do
      is1 <- Lens.use _2
      argId <- lift (mkTmBinderFor is1 tcm (mkUnsafeSystemName "result" 0) body)
      tellBinders [(argId, body)]
      tellBinders binds
      return (Var argId)

-- A case with a single alternative matching Signal's ':-' constructor is left
-- untouched.
-- NOTE(review): presumably because the selector cases built by
-- 'mkSelectorCase' cannot project the fields of ':-' correctly - confirm
-- before changing.
collectANF _ e@(Case _ _ [(DataPat dc _ _, _)])
  | nameOcc (dcName dc) == Text.showt '(:-) = return e

collectANF ctx (Case subj ty alts) = do
  tcm <- Lens.view tcCache
  -- Let-bind the scrutinee unless it is a local variable or a constant. The
  -- new binder is only announced to the InScopeSet here; its binding is
  -- recorded after the alternatives have been processed.
  (subj', subjBinders) <-
    if isLocalVar subj || isConstant subj
      then return (subj, [])
      else do
        is1 <- Lens.use _2
        argId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_scrut") subj)
        notifyBinders [(argId, subj)]
        return (Var argId, [(argId, subj)])

  let isSimIOAlt = isSimIOTy tcm ty

  alts' <- mapM (doAlt isSimIOAlt subj') alts
  tellBinders subjBinders

  case alts' of
    -- A single data alternative without type variables reduces to its
    -- right-hand side when its pattern variables do not occur free in it,
    -- or when it is a SimIO alternative.
    [(DataPat _ [] xs, altExpr)]
      | mkVarSet xs `disjointFreeVars` altExpr || isSimIOAlt
      -> return altExpr
    _ -> return (Case subj' ty alts')
  where
    -- Whether the case has exactly one alternative. It is computed once here
    -- instead of calling 'length' for every alternative.
    singleAlt :: Bool
    singleAlt = case alts of
      [_] -> True
      _ -> False

    -- ANF a single alternative. Data alternatives without existentials also
    -- bind their pattern variables to selector cases on the scrutinee.
    doAlt
      :: Bool -> Term -> Alt
      -> StateT ([LetBinding], InScopeSet) NormalizeSession Alt
    doAlt isSimIOAlt subj' alt@(DataPat dc exts xs, altExpr)
      | not (bindsExistentials exts xs) = do
          let lv = isLocalVar altExpr
              altExprIsConstant = isConstant altExpr
              usesXs (Var n) = n `elem` xs
              usesXs _ = False
          patSels <- Monad.zipWithM (doPatBndr subj' dc) xs [0 ..]
          if isSimIOAlt
               || (lv && (not (usesXs altExpr) || singleAlt))
               || altExprIsConstant
            then do
              tellBinders patSels
              return alt
            else do
              tcm <- Lens.view tcCache
              is1 <- Lens.use _2
              altId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_alt") altExpr)
              tellBinders (patSels ++ [(altId, altExpr)])
              return (DataPat dc exts xs, Var altId)
    -- Data alternatives that bind existentials are left as they are.
    doAlt _ _ alt@(DataPat {}, _) = return alt
    doAlt isSimIOAlt _ alt@(pat, altExpr)
      | isSimIOAlt || isLocalVar altExpr || isConstant altExpr = return alt
      | otherwise = do
          tcm <- Lens.view tcCache
          is1 <- Lens.use _2
          altId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_alt") altExpr)
          tellBinders [(altId, altExpr)]
          return (pat, Var altId)

    -- Bind pattern variable @pId@ to a selector case on the scrutinee built
    -- by 'mkSelectorCase' (presumably projecting field @i@ of constructor
    -- @dc@). 'pId' is not told to the state here; the caller records the
    -- returned binding.
    doPatBndr
      :: Term -> DataCon -> Id -> Int
      -> StateT ([LetBinding], InScopeSet) NormalizeSession LetBinding
    doPatBndr subj' dc pId i = do
      tcm <- Lens.view tcCache
      is1 <- Lens.use _2
      patExpr <- lift (mkSelectorCase ($(curLoc) ++ "doPatBndr") is1 tcm subj' (dcTag dc) i)
      return (pId, patExpr)

collectANF _ e = return e
{-# SCC collectANF #-}
-- | Bring an application of a data constructor or primitive into ANF when its
-- last argument is untranslatable (see 'isUntranslatable'):
--
-- * If the argument, with ticks stripped, is a @let@, the let is floated
--   over the application. Its bindings and body are first deshadowed with
--   'deshadowLetExpr' against the current 'InScopeSet'.
-- * If the argument is a @case@, lambda or type lambda, the application is
--   handed to 'specialize'.
--
-- All other terms are returned unchanged.
nonRepANF :: HasCallStack => NormRewrite
nonRepANF ctx@(TransformContext is0 _) e@(App appConPrim arg)
  | (conPrim, _) <- collectArgs e
  , isCon conPrim || isPrim conPrim
  = do
    untranslatable <- isUntranslatable False arg
    if not untranslatable
      then return e
      else case stripTicks arg of
        Let binds body ->
          let (binds1, body1) = deshadowLetExpr is0 binds body
          in changed (Let binds1 (App appConPrim body1))
        Case {} -> specialize ctx e
        Lam {} -> specialize ctx e
        TyLam {} -> specialize ctx e
        _ -> return e

nonRepANF _ e = return e
{-# SCC nonRepANF #-}