{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations.Cast
( argCastSpec
, caseCast
, elimCastCast
, letCast
, splitCastWork
) where
import Control.Exception (throw)
import qualified Control.Lens as Lens
import Control.Monad.Writer (listen)
import qualified Data.Monoid as Monoid (Any(..))
import GHC.Stack (HasCallStack)
import Clash.Core.Name (nameOcc)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term (LetBinding, Term(..), collectArgs, stripTicks)
import Clash.Core.TermInfo (isCast)
import Clash.Core.Type (normalizeType)
import Clash.Core.Var (isGlobalId, varName)
import Clash.Core.VarEnv (InScopeSet)
import Clash.Debug (trace)
import Clash.Normalize.Transformations.Specialize (specialize)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Types
(TransformContext(..), bindings, curFun, tcCache, workFreeBinders)
import Clash.Rewrite.Util (changed, mkDerivedName, mkTmBinderFor)
import Clash.Rewrite.WorkFree (isWorkFree)
import Clash.Util (ClashException(..), curLoc)
-- | Push a cast that appears as a function argument into the function, by
-- specializing the function on that casted argument.
--
-- Only fires when:
--
--   * the argument is @cast e'@ (ticks stripped) where @e'@ is not itself a
--     cast, and
--   * the head of the applied function is a global binder.
--
-- When the cast subject is not work-free (as judged by 'isWorkFree' with the
-- 'workFreeBinders' cache), a warning is emitted via 'trace' before
-- specializing anyway, because the generated HDL might duplicate work.
-- Any other term is returned unchanged.
argCastSpec :: HasCallStack => NormRewrite
argCastSpec ctx e@(App f (stripTicks -> Cast castee _ _))
  | not (isCast castee)
  , (Var g, _) <- collectArgs f
  , isGlobalId g
  = do
      bndrs <- Lens.use bindings
      workFree <- isWorkFree workFreeBinders bndrs castee
      if workFree
        then specialized
        else trace nonWorkFreeMsg specialized
  where
    -- The actual rewrite: specialize the applied function on its argument.
    specialized = specialize ctx e

    nonWorkFreeMsg = unwords
      [ "WARNING:", $(curLoc), "specializing a function on a non work-free"
      , "cast. Generated HDL implementation might contain duplicate work."
      , "Please report this as a bug.", "\n\nExpression where this occured:"
      , "\n\n" ++ showPpr e
      ]
argCastSpec _ e = return e
{-# SCC argCastSpec #-}
-- | Push a cast over a case-expression into each of its alternatives:
--
-- @
-- cast (case s of p_i -> e_i) :: ty1 -> ty2
-- ==>
-- case s of p_i -> (cast e_i :: ty1 -> ty2)
-- @
--
-- The 'Case' constructor's 'Type' field is (per "Clash.Core.Term") the type
-- of its alternatives. After the push every alternative has the cast's target
-- type @ty2@, so the rebuilt case must carry @ty2@. Keeping the old field
-- (which equals @ty1@) would change the type of the rewritten term.
-- Any other term is returned unchanged.
caseCast :: HasCallStack => NormRewrite
caseCast _ (Cast (stripTicks -> Case subj _ alts) ty1 ty2) =
  changed (Case subj ty2 (map castAlt alts))
 where
  -- Wrap the right-hand side of a single alternative in the pushed cast.
  castAlt (pat, rhs) = (pat, Cast rhs ty1 ty2)
caseCast _ e = return e
{-# SCC caseCast #-}
-- | Eliminate two back-to-back casts where the outer cast undoes the inner one:
--
-- @
-- (cast :: b' -> c) ((cast :: a -> b) x)  ==>  x     when b == b' and a == c
-- @
--
-- Types are compared after 'normalizeType' under the current 'tcCache'.
-- When the types do not line up, a 'ClashException' is thrown. It is located
-- at the source span of the function currently being normalized ('curFun').
-- Any other term is returned unchanged.
elimCastCast :: HasCallStack => NormRewrite
elimCastCast _ c@(Cast (stripTicks -> Cast inner tyA tyB) tyB' tyC) = do
  tcm <- Lens.view tcCache
  let norm = normalizeType tcm
      -- Inner target must match outer source, and outer target must match
      -- inner source, for the pair of casts to be an identity.
      linesUp = norm tyB == norm tyB' && norm tyA == norm tyC
  if linesUp
    then changed inner
    else mismatch
 where
  mismatch :: NormalizeSession a
  mismatch = do
    (nm, sp) <- Lens.use curFun
    let msg = $(curLoc) ++ showPpr nm
                ++ ": Found 2 nested casts whose types don't line up:\n"
                ++ showPpr c
    throw (ClashException sp msg Nothing)
elimCastCast _ e = return e
{-# SCC elimCastCast #-}
-- | Push a cast over a let-expression into its body:
--
-- @
-- cast (let bs in body) :: ty1 -> ty2
-- ==>
-- let bs in (cast body :: ty1 -> ty2)
-- @
--
-- Any other term is returned unchanged.
letCast :: HasCallStack => NormRewrite
letCast _ (Cast (stripTicks -> Let binds body) ty1 ty2) =
  changed (Let binds castBody)
 where
  castBody = Cast body ty1 ty2
letCast _ e = return e
{-# SCC letCast #-}
-- | Make casts in letrec-bindings work-free by moving the cast's subject into
-- a fresh binding of its own:
--
-- @
-- letrec x = cast (f a b) in e
-- ==>
-- letrec x' = f a b
--        x  = cast x'
-- in e
-- @
--
-- A binding is left untouched when:
--
--   * its cast subject is already a variable, or
--   * its cast subject is itself a cast, or
--   * it is not a cast at all.
--
-- The letrec is rebuilt (and reported as changed) only if at least one
-- binding was split; otherwise the original term is returned as is.
splitCastWork :: HasCallStack => NormRewrite
splitCastWork ctx@(TransformContext is0 _) unchanged@(Letrec binds body) = do
  -- 'listen' exposes whether any per-binding split reported a change.
  (splitGroups, Monoid.Any anySplit) <- listen (mapM (splitBinding is0) binds)
  if anySplit
    then changed (Letrec (concat splitGroups) body)
    else return unchanged
 where
  splitBinding :: InScopeSet -> LetBinding -> NormalizeSession [LetBinding]
  splitBinding isN bnd@(nm, rhs) = case stripTicks rhs of
    -- Cast of a variable: already work-free.
    Cast (Var _) _ _ -> return [bnd]
    -- Cast of a cast: left alone here (presumably removed by 'elimCastCast').
    Cast (Cast {}) _ _ -> return [bnd]
    Cast subj ty1 ty2 -> do
      tcm <- Lens.view tcCache
      let freshName = mkDerivedName ctx (nameOcc (varName nm))
      nm' <- mkTmBinderFor isN tcm freshName subj
      changed [ (nm', subj)
              , (nm, Cast (Var nm') ty1 ty2)
              ]
    _ -> return [bnd]
splitCastWork _ e = return e
{-# SCC splitCastWork #-}