{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations.EtaExpand
( etaExpandSyn
, etaExpansionTL
) where
import qualified Control.Lens as Lens
import qualified Data.Maybe as Maybe
import GHC.Stack (HasCallStack)
import Clash.Core.HasType
import Clash.Core.Term (Bind(..), CoreContext(..), Term(..), collectArgs, mkLams)
import Clash.Core.TermInfo (isFun)
import Clash.Core.Type (splitFunTy)
import Clash.Core.Util (mkInternalVar)
import Clash.Core.Var (Id)
import Clash.Core.VarEnv (elemVarSet, extendInScopeSet, extendInScopeSetList)
import Clash.Normalize.Types (NormRewrite)
import Clash.Rewrite.Types (TransformContext(..), tcCache, topEntities)
import Clash.Rewrite.Util (changed)
import Clash.Util (curLoc)
-- | Eta-expand a term whose head is a top entity.
--
-- The rewrite fires only when all of the following hold:
--
--   * the head of the term, as returned by 'collectArgs', is a variable that
--     is an element of the 'topEntities' set;
--   * 'splitFunTy' succeeds on the inferred type of the term;
--   * the term does not sit directly in the function position of an
--     application ('AppFun'), looking through any enclosing 'TickC' entries
--     of the context.
--
-- In that case the term @e@ is replaced by
--
-- > \arg -> e arg
--
-- where @arg@ is a fresh internal variable whose type is the first component
-- of the 'splitFunTy' result, and the rewrite is reported through 'changed'.
-- In every other case the term is returned unmodified.
etaExpandSyn :: HasCallStack => NormRewrite
etaExpandSyn (TransformContext is0 ctx) e
  -- A pattern guard is used rather than a view pattern so that this
  -- definition does not depend on the ViewPatterns extension, which is not
  -- among the LANGUAGE pragmas of this module.
  | (Var f, _) <- collectArgs e = do
      topEnts <- Lens.view topEntities
      tcm <- Lens.view tcCache
      let isTopEnt = f `elemVarSet` topEnts
          argTyM = fst <$> splitFunTy tcm (inferCoreTypeOf tcm e)
      case argTyM of
        Just argTy
          | isTopEnt && not (isAppFunCtx ctx) -> do
              newId <- mkInternalVar is0 "arg" argTy
              changed (Lam newId (App e (Var newId)))
        _ -> return e
  | otherwise = return e
 where
  -- True when the innermost non-tick context entry is 'AppFun'.
  isAppFunCtx = \case
    AppFun : _ -> True
    TickC _ : c -> isAppFunCtx c
    _ -> False
{-# SCC etaExpandSyn #-}
-- | Peel the leading run of 'Lam' constructors off a term.
--
-- Returns the binders of those lambdas, outermost first, together with the
-- first sub-term that is not itself a 'Lam'. A term that does not start with
-- a 'Lam' yields an empty binder list and the term unchanged.
stripLambda :: Term -> ([Id], Term)
stripLambda (Lam bndr e) =
  let (bndrs, e') = stripLambda e
   in (bndr : bndrs, e')
stripLambda e = ([], e)
-- | Eta-expand a term, descending through leading lambdas and let-bindings.
--
-- The four equations handle, in order:
--
--   * @Lam@: recurse into the body with the binder added to the in-scope set
--     and a 'LamBody' entry pushed onto the context, then rebuild the lambda.
--   * @Let (NonRec ..)@ and @Let (Rec ..)@: recurse into the let body with the
--     let binders in scope. If the result starts with one or more lambdas
--     (according to 'stripLambda'), the let is rebuilt around the stripped
--     body and the binders are handed to 'mkLams', and the rewrite is
--     reported through 'changed'. Otherwise the let is rebuilt around the
--     recursive result as is.
--     NOTE(review): 'mkLams' presumably re-wraps the let in those lambdas,
--     so that they end up outside the let - confirm against its definition.
--   * anything else: when 'isFun' holds for the term, apply it to a fresh
--     internal variable named @arg@, recurse on that application and wrap
--     the result in a lambda over the fresh variable, reporting 'changed'.
--     Otherwise the term is returned unmodified.
etaExpansionTL :: HasCallStack => NormRewrite
etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) = do
  let ctx' = TransformContext (extendInScopeSet is0 bndr) (LamBody bndr : ctx)
  e' <- etaExpansionTL ctx' e
  return $ Lam bndr e'

etaExpansionTL (TransformContext is0 ctx) (Let (NonRec i x) e) = do
  let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [(i, x)] : ctx)
  e' <- etaExpansionTL ctx' e
  case stripLambda e' of
    -- Only a non-empty binder list counts as a change.
    (bs@(_ : _), e2) -> do
      let e3 = Let (NonRec i x) e2
      changed (mkLams e3 bs)
    _ -> return (Let (NonRec i x) e')

etaExpansionTL (TransformContext is0 ctx) (Let (Rec xes) e) = do
  let bndrs = map fst xes
      ctx' = TransformContext (extendInScopeSetList is0 bndrs) (LetBody xes : ctx)
  e' <- etaExpansionTL ctx' e
  case stripLambda e' of
    -- Only a non-empty binder list counts as a change.
    (bs@(_ : _), e2) -> do
      let e3 = Let (Rec xes) e2
      changed (mkLams e3 bs)
    _ -> return (Let (Rec xes) e')

etaExpansionTL (TransformContext is0 ctx) e = do
  tcm <- Lens.view tcCache
  if isFun tcm e
    then do
      -- The first component of the 'splitFunTy' result becomes the type of
      -- the fresh binder. A failing 'splitFunTy' is treated as an internal
      -- error; the binding is lazy, so the error is raised only if the type
      -- is actually demanded.
      let argTy = ( fst
                  . Maybe.fromMaybe (error $ $(curLoc) ++ "etaExpansion splitFunTy")
                  . splitFunTy tcm
                  . inferCoreTypeOf tcm
                  ) e
      newId <- mkInternalVar is0 "arg" argTy
      let ctx' = TransformContext (extendInScopeSet is0 newId) (LamBody newId : ctx)
      e' <- etaExpansionTL ctx' (App e (Var newId))
      changed (Lam newId e')
    else return e
{-# SCC etaExpansionTL #-}