{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Normalize.Transformations.Cast
  ( argCastSpec
  , caseCast
  , elimCastCast
  , letCast
  , splitCastWork
  ) where

import Control.Exception (throw)
import qualified Control.Lens as Lens
import Control.Monad.Writer (listen)
import qualified Data.Monoid as Monoid (Any(..))
import GHC.Stack (HasCallStack)

import Clash.Core.Name (nameOcc)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term (LetBinding, Term(..), collectArgs, stripTicks)
import Clash.Core.TermInfo (isCast)
import Clash.Core.Type (normalizeType)
import Clash.Core.Var (isGlobalId, varName)
import Clash.Core.VarEnv (InScopeSet)
import Clash.Debug (trace)
import Clash.Normalize.Transformations.Specialize (specialize)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Types
  (TransformContext(..), bindings, curFun, tcCache, workFreeBinders)
import Clash.Rewrite.Util (changed, mkDerivedName, mkTmBinderFor)
import Clash.Rewrite.WorkFree (isWorkFree)
import Clash.Util (ClashException(..), curLoc)

-- | Push cast over an argument to a function into that function
--
-- This is done by specializing on the casted argument.
-- Example:
-- @
--   y = f (cast a)
--     where f x = g x
-- @
-- transforms to:
-- @
--   y = f' a
--     where f' x' = (\\x -> g x) (cast x')
-- @
--
-- The reason d'etre for this transformation is that we hope to end up with
-- and expression where two casts are "back-to-back" after which we can
-- eliminate them in 'eliminateCastCast'.
argCastSpec :: HasCallStack => NormRewrite
argCastSpec :: HasCallStack => NormRewrite
argCastSpec TransformContext
ctx e :: Term
e@(App Term
f (Term -> Term
stripTicks -> Cast Term
e' Type
_ Type
_))
 -- Don't specialise when the arguments are casts-of-casts, these casts-of-casts
 -- will be eliminated by 'eliminateCastCast' during the normalization of the
 -- "current" function. We thus prevent the unnecessary introduction of a
 -- specialized version of 'f'.
 | Bool -> Bool
not (Term -> Bool
isCast Term
e')
 -- We can only push casts into global binders
 , (Var Id
g, [Either Term Type]
_) <- Term -> (Term, [Either Term Type])
collectArgs Term
f
 , Id -> Bool
forall a. Var a -> Bool
isGlobalId Id
g = do
  bndrs <- Getting BindingMap (RewriteState NormalizeState) BindingMap
-> RewriteMonad NormalizeState BindingMap
forall s (m :: Type -> Type) a.
MonadState s m =>
Getting a s a -> m a
Lens.use Getting BindingMap (RewriteState NormalizeState) BindingMap
forall extra (f :: Type -> Type).
Functor f =>
(BindingMap -> f BindingMap)
-> RewriteState extra -> f (RewriteState extra)
bindings
  isWorkFree workFreeBinders bndrs e' >>= \case
    Bool
True -> RewriteMonad NormalizeState Term
go
    Bool
False -> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall {a}. a -> a
warn RewriteMonad NormalizeState Term
go
 where
  go :: RewriteMonad NormalizeState Term
go = NormRewrite
specialize TransformContext
ctx Term
e
  warn :: a -> a
warn = String -> a -> a
forall a. String -> a -> a
trace ([String] -> String
unwords
    [ String
"WARNING:", $(curLoc), String
"specializing a function on a non work-free"
    , String
"cast. Generated HDL implementation might contain duplicate work."
    , String
"Please report this as a bug.", String
"\n\nExpression where this occured:"
    , String
"\n\n" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Term -> String
forall p. PrettyPrec p => p -> String
showPpr Term
e
    ])
argCastSpec TransformContext
_ Term
e = Term -> RewriteMonad NormalizeState Term
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
{-# SCC argCastSpec #-}

-- | Push a cast over a case into it's alternatives.
caseCast :: HasCallStack => NormRewrite
caseCast :: HasCallStack => NormRewrite
caseCast TransformContext
_ (Cast (Term -> Term
stripTicks -> Case Term
subj Type
ty [Alt]
alts) Type
ty1 Type
ty2) = do
  let alts' :: [Alt]
alts' = (Alt -> Alt) -> [Alt] -> [Alt]
forall a b. (a -> b) -> [a] -> [b]
map (\(Pat
p,Term
e) -> (Pat
p, Term -> Type -> Type -> Term
Cast Term
e Type
ty1 Type
ty2)) [Alt]
alts
  Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> Type -> [Alt] -> Term
Case Term
subj Type
ty [Alt]
alts')
caseCast TransformContext
_ Term
e = Term -> RewriteMonad NormalizeState Term
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
{-# SCC caseCast #-}

-- | Eliminate two back to back casts where the type going in and coming out are the same
--
-- @
--   (cast :: b -> a) $ (cast :: a -> b) x   ==> x
-- @
elimCastCast :: HasCallStack => NormRewrite
elimCastCast :: HasCallStack => NormRewrite
elimCastCast TransformContext
_ c :: Term
c@(Cast (Term -> Term
stripTicks -> Cast Term
e Type
tyA Type
tyB) Type
tyB' Type
tyC) = do
  tcm <- Getting TyConMap RewriteEnv TyConMap
-> RewriteMonad NormalizeState TyConMap
forall s (m :: Type -> Type) a.
MonadReader s m =>
Getting a s a -> m a
Lens.view Getting TyConMap RewriteEnv TyConMap
Getter RewriteEnv TyConMap
tcCache
  let ntyA  = TyConMap -> Type -> Type
normalizeType TyConMap
tcm Type
tyA
      ntyB  = TyConMap -> Type -> Type
normalizeType TyConMap
tcm Type
tyB
      ntyB' = TyConMap -> Type -> Type
normalizeType TyConMap
tcm Type
tyB'
      ntyC  = TyConMap -> Type -> Type
normalizeType TyConMap
tcm Type
tyC
  if ntyB == ntyB' && ntyA == ntyC then changed e
                                   else throwError
  where throwError :: RewriteMonad NormalizeState b
throwError = do
          (nm,sp) <- Getting (Id, SrcSpan) (RewriteState NormalizeState) (Id, SrcSpan)
-> RewriteMonad NormalizeState (Id, SrcSpan)
forall s (m :: Type -> Type) a.
MonadState s m =>
Getting a s a -> m a
Lens.use Getting (Id, SrcSpan) (RewriteState NormalizeState) (Id, SrcSpan)
forall extra (f :: Type -> Type).
Functor f =>
((Id, SrcSpan) -> f (Id, SrcSpan))
-> RewriteState extra -> f (RewriteState extra)
curFun
          throw (ClashException sp ($(curLoc) ++ showPpr nm
                  ++ ": Found 2 nested casts whose types don't line up:\n"
                  ++ showPpr c)
                Nothing)

elimCastCast TransformContext
_ Term
e = Term -> RewriteMonad NormalizeState Term
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
{-# SCC elimCastCast #-}

-- | Push a cast over a Let into it's body
letCast :: HasCallStack => NormRewrite
letCast :: HasCallStack => NormRewrite
letCast TransformContext
_ (Cast (Term -> Term
stripTicks -> Let Bind Term
binds Term
body) Type
ty1 Type
ty2) =
  Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> RewriteMonad NormalizeState Term)
-> Term -> RewriteMonad NormalizeState Term
forall a b. (a -> b) -> a -> b
$ Bind Term -> Term -> Term
Let Bind Term
binds (Term -> Type -> Type -> Term
Cast Term
body Type
ty1 Type
ty2)
letCast TransformContext
_ Term
e = Term -> RewriteMonad NormalizeState Term
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
{-# SCC letCast #-}

-- | Make a cast work-free by splitting the work of to a separate binding
--
-- @
-- let x = cast (f a b)
-- ==>
-- let x  = cast x'
--     x' = f a b
-- @
splitCastWork :: HasCallStack => NormRewrite
splitCastWork :: HasCallStack => NormRewrite
splitCastWork ctx :: TransformContext
ctx@(TransformContext InScopeSet
is0 Context
_) unchanged :: Term
unchanged@(Letrec [LetBinding]
vs Term
e') = do
  (vss', Monoid.getAny -> hasChanged) <- RewriteMonad NormalizeState [[LetBinding]]
-> RewriteMonad NormalizeState ([[LetBinding]], Any)
forall a.
RewriteMonad NormalizeState a
-> RewriteMonad NormalizeState (a, Any)
forall w (m :: Type -> Type) a. MonadWriter w m => m a -> m (a, w)
listen ((LetBinding -> RewriteMonad NormalizeState [LetBinding])
-> [LetBinding] -> RewriteMonad NormalizeState [[LetBinding]]
forall (t :: Type -> Type) (m :: Type -> Type) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: Type -> Type) a b.
Monad m =>
(a -> m b) -> [a] -> m [b]
mapM (InScopeSet
-> LetBinding -> RewriteMonad NormalizeState [LetBinding]
splitCastLetBinding InScopeSet
is0) [LetBinding]
vs)
  let vs' = [[LetBinding]] -> [LetBinding]
forall (t :: Type -> Type) a. Foldable t => t [a] -> [a]
concat [[LetBinding]]
vss'
  if hasChanged then changed (Letrec vs' e')
                else return unchanged
  where
    splitCastLetBinding
      :: InScopeSet
      -> LetBinding
      -> NormalizeSession [LetBinding]
    splitCastLetBinding :: InScopeSet
-> LetBinding -> RewriteMonad NormalizeState [LetBinding]
splitCastLetBinding InScopeSet
isN x :: LetBinding
x@(Id
nm, Term
e) = case Term -> Term
stripTicks Term
e of
      Cast (Var {}) Type
_ Type
_  -> [LetBinding] -> RewriteMonad NormalizeState [LetBinding]
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return [LetBinding
x]  -- already work-free
      Cast (Cast {}) Type
_ Type
_ -> [LetBinding] -> RewriteMonad NormalizeState [LetBinding]
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return [LetBinding
x]  -- casts will be eliminated
      Cast Term
e0 Type
ty1 Type
ty2 -> do
        tcm <- Getting TyConMap RewriteEnv TyConMap
-> RewriteMonad NormalizeState TyConMap
forall s (m :: Type -> Type) a.
MonadReader s m =>
Getting a s a -> m a
Lens.view Getting TyConMap RewriteEnv TyConMap
Getter RewriteEnv TyConMap
tcCache
        nm' <- mkTmBinderFor isN tcm (mkDerivedName ctx (nameOcc $ varName nm)) e0
        changed [(nm',e0)
                ,(nm, Cast (Var nm') ty1 ty2)
                ]
      Term
_ -> [LetBinding] -> RewriteMonad NormalizeState [LetBinding]
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return [LetBinding
x]

splitCastWork TransformContext
_ Term
e = Term -> RewriteMonad NormalizeState Term
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
e
{-# SCC splitCastWork #-}