{-|
  Copyright   :  (C) 2012-2016, University of Twente,
                     2016     , Myrtle Software Ltd,
                     2017     , Google Inc.,
                     2021-2023, QBayLogic B.V.
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  QBayLogic B.V. <devops@qbaylogic.com>

  Turn CoreHW terms into normalized CoreHW terms
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Normalize where

import           Control.Exception                (throw)
import qualified Control.Lens                     as Lens
import           Control.Monad                    (when)
import           Control.Monad.State.Strict       (State)
import           Data.Default                     (def)
import           Data.Either                      (lefts,partitionEithers)
import qualified Data.IntMap                      as IntMap
import           Data.List
  (intersect, mapAccumL)
import qualified Data.Map                         as Map
import qualified Data.Maybe                       as Maybe
import qualified Data.Set                         as Set
import qualified Data.Set.Lens                    as Lens

#if MIN_VERSION_prettyprinter(1,7,0)
import           Prettyprinter                    (vcat)
#else
import           Data.Text.Prettyprint.Doc        (vcat)
#endif

import           GHC.BasicTypes.Extra             (isNoInline)

import           Clash.Annotations.BitRepresentation.Internal
  (CustomReprs)
import           Clash.Core.Evaluator.Types as WHNF (Evaluator)
import           Clash.Core.FreeVars
  (freeLocalIds, globalIds)
import           Clash.Core.HasFreeVars           (notElemFreeVars)
import           Clash.Core.HasType
import           Clash.Core.PartialEval as PE     (Evaluator)
import           Clash.Core.Pretty                (PrettyOptions(..), showPpr, showPpr', ppr)
import           Clash.Core.Subst
  (extendGblSubstList, mkSubst, substTm)
import           Clash.Core.Term                  (Term (..), collectArgsTicks
                                                  ,mkApps, mkTicks)
import           Clash.Core.Type                  (Type, splitCoreFunForallTy)
import           Clash.Core.TyCon (TyConMap)
import           Clash.Core.Type                  (isPolyTy)
import           Clash.Core.Var                   (Id, varName, varType)
import           Clash.Core.VarEnv
  (VarEnv, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv,
   extendVarEnv, lookupVarEnv, mapVarEnv, mapMaybeVarEnv,
   mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, unionVarEnv)
import           Clash.Debug                      (traceIf)
import           Clash.Driver.Types
  (BindingMap, Binding(..), DebugOpts(..), ClashEnv(..))
import           Clash.Netlist.Types
  (HWMap, FilteredHWType(..))
import           Clash.Netlist.Util
  (splitNormalized)
import           Clash.Normalize.Strategy
import           Clash.Normalize.Transformations
import           Clash.Normalize.Types
import           Clash.Normalize.Util
import           Clash.Rewrite.Combinators
  ((>->), (!->), bottomupR, repeatR, topdownR)
import           Clash.Rewrite.Types
  (RewriteEnv (..), RewriteState (..), bindings, debugOpts, extra,
   tcCache, topEntities, newInlineStrategy)
import           Clash.Rewrite.Util
  (apply, isUntranslatableType, runRewriteSession)
import           Clash.Util
import           Clash.Util.Interpolate           (i)
import           Clash.Util.Supply                (Supply)

import           Data.Binary                      (encode)
import qualified Data.ByteString                  as BS
import qualified Data.ByteString.Lazy             as BL

import           System.IO.Unsafe                 (unsafePerformIO)
import           Clash.Rewrite.Types (RewriteStep(..))


-- | Run a NormalizeSession in a given environment.
--
-- Builds the 'RewriteEnv' and an initial 'RewriteState' (with a fresh
-- 'NormalizeState') from the arguments and runs the session with
-- 'runRewriteSession'.
runNormalization
  :: ClashEnv
  -> Supply
  -- ^ UniqueSupply
  -> BindingMap
  -- ^ Global Binders
  -> (CustomReprs -> TyConMap -> Type ->
      State HWMap (Maybe (Either String FilteredHWType)))
  -- ^ Hardcoded Type -> HWType translator
  -> PE.Evaluator
  -- ^ Hardcoded evaluator for partial evaluation
  -> WHNF.Evaluator
  -- ^ Hardcoded evaluator for WHNF (old evaluator)
  -> VarEnv Bool
  -- ^ Map telling whether a component is part of a recursive group
  -> [Id]
  -- ^ topEntities
  -> NormalizeSession a
  -- ^ NormalizeSession to run
  -> IO a
runNormalization env supply globals typeTrans peEval eval rcsMap topEnts =
  runRewriteSession rwEnv rwState
  where
    -- TODO The RewriteEnv should just take ClashOpts.
    rwEnv :: RewriteEnv
    rwEnv = RewriteEnv
              env
              typeTrans
              peEval
              eval
              (mkVarSet topEnts)

    rwState :: RewriteState NormalizeState
    rwState = RewriteState
                0
                mempty       -- transformCounters Map
                globals
                supply
                -- The current function must be set before it is ever read
                (error $ $(curLoc) ++ "Report as bug: no curFun",noSrcSpan)
                0
                (IntMap.empty, 0)
                emptyVarEnv
                normState

    -- Every field starts out empty, except the last one, which is seeded
    -- with the recursive-component map passed in by the caller.
    normState :: NormalizeState
    normState = NormalizeState
                  emptyVarEnv
                  Map.empty
                  emptyVarEnv
                  emptyVarEnv
                  Map.empty
                  rcsMap

-- | Normalize the given global binders and, transitively, every global
-- binder referenced by their normalized terms that 'normalize'' reports as
-- still needing normalization (i.e. not a top entity and not normalized
-- before).
--
-- Returns a map from every binder normalized this way to its normalized
-- binding.
normalize
  :: [Id]
  -> NormalizeSession BindingMap
normalize []  = return emptyVarEnv
normalize top = do
  (new,topNormalized) <- unzip <$> mapM normalize' top
  -- Normalize whatever the freshly normalized terms still depend on
  newNormalized <- normalize (concat new)
  return (unionVarEnv (mkVarEnv topNormalized) newNormalized)

-- | Normalize a single global binder.
--
-- On success, returns the global binders referenced by the normalized term
-- that still need to be normalized (those that are not top entities, not the
-- binder itself, and not already in the normalized cache), paired with the
-- normalized binding.
--
-- Throws a 'ClashException' when the binder's type is polymorphic (for top
-- entities this is checked after substituting type equality constraints), or
-- when a top entity or function has a non-representable result type. Other
-- binders with a non-representable result type are returned unnormalized,
-- with nothing further to normalize.
--
-- Errors when the binder is not present in the global 'bindings'.
normalize' :: Id -> NormalizeSession ([Id], (Id, Binding Term))
normalize' nm = do
  exprM <- lookupVarEnv nm <$> Lens.use bindings
  let nmS = showPpr (varName nm)
  case exprM of
    Just (Binding nm' sp inl pr tm r) -> do
      tcm <- Lens.view tcCache
      topEnts <- Lens.view topEntities
      let isTopEnt = nm `elemVarSet` topEnts
          ty0 = coreTypeOf nm'
          -- Type equality constraints are only substituted for top entities
          ty1 = if isTopEnt then tvSubstWithTyEq ty0 else ty0

      -- check for polymorphic types
      when (isPolyTy ty1) $
        let msg = $(curLoc) ++ [i|
              Clash can only normalize monomorphic functions, but this is polymorphic:
              #{showPpr' def{displayUniques=False\} nm'}
              |]
            msgExtra | ty0 == ty1 = Nothing
                     | otherwise = Just $ [i|
              Even after applying type equality constraints it remained polymorphic:
              #{showPpr' def{displayUniques=False\} nm'{varType=ty1\}}
                         |]
        in throw (ClashException sp msg msgExtra)

      -- check for unrepresentable result type
      let (args,resTy) = splitCoreFunForallTy tcm ty1
          -- NOTE(review): 'lefts' selects the forall-bound type variables
          -- returned by 'splitCoreFunForallTy', not the argument types. As
          -- polymorphic binders were rejected above, this is presumably
          -- (almost) always False; confirm whether "has value arguments"
          -- ('rights') was intended before changing it.
          isFunction = not $ null $ lefts args
      resTyRep <- not <$> isUntranslatableType False resTy
      if resTyRep
         then do
            tmNorm <- normalizeTopLvlBndr isTopEnt nm (Binding nm' sp inl pr tm r)
            let usedBndrs = Lens.toListOf globalIds (bindingTerm tmNorm)
            traceIf (bindingRecursive tmNorm)
                    (concat [ $(curLoc),"Expr belonging to bndr: ",nmS ," (:: "
                            , showPpr (coreTypeOf (bindingId tmNorm))
                            , ") remains recursive after normalization:\n"
                            , showPpr (bindingTerm tmNorm) ])
                    (return ())
            prevNorm <- mapVarEnv bindingId <$> Lens.use (extra.normalized)
            -- Only queue binders that are not top entities, not 'nm' itself,
            -- and that have not been normalized before.
            let toNormalize = filter (`notElemVarSet` topEnts)
                            $ filter (`notElemVarEnv` (extendVarEnv nm nm prevNorm)) usedBndrs
            return (toNormalize,(nm,tmNorm))
         else
           do
            -- Throw an error for unrepresentable topEntities and functions
            when (isTopEnt || isFunction) $
              let msg = $(curLoc) ++ [i|
                    This bndr has a non-representable return type and can't be normalized:
                    #{showPpr' def{displayUniques=False\} nm'}
                    |]
              in throw (ClashException sp msg Nothing)

            -- But allow the compilation to proceed for nonrepresentable values.
            -- This can happen for example when GHC decides to create a toplevel binder
            -- for the ByteArray# inside of a Natural constant.
            -- (GHC-8.4 does this with tests/shouldwork/Numbers/Exp.hs)
            -- It will later be inlined by flattenCallTree.
            opts <- Lens.view debugOpts
            traceIf (dbg_invariants opts)
                    (concat [$(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                            , showPpr (coreTypeOf nm')
                            , ") has a non-representable return type."
                            , " Not normalising:\n", showPpr tm] )
                    (return ([],(nm,(Binding nm' sp inl pr tm r))))


    Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"

-- | Check whether the normalized bindings are non-recursive. Errors when one
-- of the components is recursive.
--
-- A binding counts as recursive when its 'bindingRecursive' flag is set. When
-- no binding is flagged, the input is returned unchanged; otherwise the error
-- lists every flagged binder together with its term.
checkNonRecursive
  :: BindingMap
  -- ^ Map of normalized binders
  -> BindingMap
checkNonRecursive norm
  | nullVarEnv rcs = norm
  | otherwise =
      error $ $(curLoc) ++ "Callgraph after normalization contains following recursive components: "
                        ++ show (vcat [ ppr a <> ppr b
                                      | (a,b) <- eltsVarEnv rcs
                                      ])
 where
  -- The binders flagged as recursive, paired with their terms
  rcs = mapMaybeVarEnv recursiveTerm norm

  recursiveTerm :: Binding b -> Maybe (Id, b)
  recursiveTerm bnd
    | bindingRecursive bnd = Just (bindingId bnd, bindingTerm bnd)
    | otherwise            = Nothing

-- | Perform general \"clean up\" of the normalized (non-recursive) function
-- hierarchy. This includes:
--
--   * Inlining functions that simply \"wrap\" another function
--
-- When @topEntity@ is present in the binding map, the result is rebuilt from
-- the call tree rooted at it after 'flattenCallTree'; otherwise the bindings
-- are returned unchanged.
--
-- NOTE(review): in the first case the result presumably only contains the
-- binders in that call tree (i.e. reachable from @topEntity@); confirm
-- against 'callTreeToList'.
cleanupGraph
  :: Id
  -> BindingMap
  -> NormalizeSession BindingMap
cleanupGraph topEntity norm =
  case mkCallTree [] norm topEntity of
    Just ct -> do
      ctFlat <- flattenCallTree ct
      return (mkVarEnv $ snd $ callTreeToList [] ctFlat)
    Nothing -> return norm

-- | A call graph unfolded into a tree. Every node is a global binder paired
-- with its binding; a branch additionally holds the trees of the global
-- binders that this binding uses. See 'Clash.Driver.Types.Binding'.
--
data CallTree
  = CLeaf (Id, Binding Term)
  -- ^ A binder without child trees
  | CBranch (Id, Binding Term) [CallTree]
  -- ^ A binder together with the call trees of the binders it uses

-- | Unfold the call graph rooted at @root@ into a 'CallTree'.
--
-- The children of a node are the global binders referenced by its term,
-- excluding the node itself and the binders already on the path from the
-- original root (the visited list). This ensures recursive references are
-- never unfolded again. Referenced binders that are missing from
-- @bindingMap@ are dropped.
--
-- A node becomes a 'CLeaf' when its term references no global binders at
-- all, and a 'CBranch' otherwise.
--
-- Returns 'Nothing' when @root@ itself is not in @bindingMap@.
mkCallTree
  :: [Id]
  -- ^ Visited
  -> BindingMap
  -- ^ Global binders
  -> Id
  -- ^ Root of the call graph
  -> Maybe CallTree
mkCallTree visited bindingMap root
  | Just rootTm <- lookupVarEnv root bindingMap
  = let -- The path including this node. It is used both for filtering the
        -- children and for the recursive calls, so a self-reference is not
        -- unfolded as its own child.
        visited' = root:visited
        used     = Set.toList $ Lens.setOf globalIds $ bindingTerm rootTm
        other    = Maybe.mapMaybe (mkCallTree visited' bindingMap)
                                  (filter (`notElem` visited') used)
    in  case used of
          [] -> Just (CLeaf   (root,rootTm))
          _  -> Just (CBranch (root,rootTm) other)
mkCallTree _ _ _ = Nothing

-- | Match the leading arguments against @ids@, pairwise and in order. Each
-- argument must be a term argument that is exactly the corresponding
-- variable.
--
-- On success, returns the remaining (unmatched) arguments, provided none of
-- them has a free local variable from @allIds@.
--
-- Returns 'Nothing' when:
--
--   * there are fewer arguments than @ids@,
--   * an argument does not match its identifier, or
--   * a remaining argument mentions one of @allIds@.
--
-- 'flattenNode' passes the lambda binders and the arguments reversed, so that
-- this strips trailing arguments equal to the binders (eta-reduction).
stripArgs
  :: [Id]
  -- ^ Identifiers that must not occur free in the remaining arguments
  -> [Id]
  -- ^ Identifiers to match against the leading arguments
  -> [Either Term Type]
  -- ^ Arguments
  -> Maybe [Either Term Type]
stripArgs _      (_:_) []   = Nothing
stripArgs allIds []    args
  | any mentionsId args = Nothing
  | otherwise           = Just args
  where
    -- Type arguments never mention (term) identifiers
    mentionsId :: Either Term b -> Bool
    mentionsId t = not $ null (either (Lens.toListOf freeLocalIds) (const []) t
                              `intersect`
                              allIds)
stripArgs allIds (id_:ids) (Left (Var nm):args)
  | id_ == nm = stripArgs allIds ids args
  | otherwise = Nothing
stripArgs _ _ _ = Nothing

-- | Decide how a call-tree node is treated by its parent in 'flattenCallTree'.
--
-- * @Left node@: keep the node as a separate component. This happens when:
--
--     - the binder is marked NOINLINE, or
--     - it is a top entity, or
--     - it is a branch whose term is not a single-let wrapper, unless the
--       new inline strategy is enabled or the term is cheap
--       ('isCheapFunction').
--
-- * @Right ((nm, e), us)@: inline the node, i.e. substitute @e@ for @nm@ in
--   the parent and adopt the node's children @us@. When the node's term
--   (in split-normal form) consists of a single let-binding, and that binding
--   applies some function whose trailing arguments are exactly the lambda
--   binders (checked by 'stripArgs'), @e@ is the eta-reduced call to that
--   function rather than the full term.
flattenNode
  :: CallTree
  -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
flattenNode node = case node of
  CLeaf (nm, Binding _ _ spec _ e _)
    | isNoInline spec -> return (Left node)
    | otherwise ->
        -- Leaves are inlined even when they are not wrappers
        flattenWrapper nm e [] (return (Right ((nm,e),[])))
  CBranch (nm, Binding _ _ spec _ e _) us
    | isNoInline spec -> return (Left node)
    | otherwise -> flattenWrapper nm e us $ do
        newInlineStrat <- Lens.view newInlineStrategy
        if newInlineStrat || isCheapFunction e
           then return (Right ((nm,e),us))
           else return (Left node)
 where
  -- Logic shared by both node kinds once NOINLINE has been ruled out:
  -- top entities are kept, single-let terms are inlined (eta-reduced when
  -- possible), and any other term is handled by @notSingleLet@.
  flattenWrapper
    :: Id
    -> Term
    -> [CallTree]
    -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
    -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
  flattenWrapper nm e us notSingleLet = do
    isTopEntity <- elemVarSet nm <$> Lens.view topEntities
    if isTopEntity then return (Left node) else do
      tcm <- Lens.view tcCache
      case splitNormalized tcm e of
        Right (ids,[(bId,bExpr)],_) -> do
          let (fun,args,ticks) = collectArgsTicks bExpr
          case stripArgs ids (reverse ids) (reverse args) of
            Just remainder | bId `notElemFreeVars` bExpr ->
              return (Right ((nm,mkApps (mkTicks fun ticks) (reverse remainder)),us))
            _ -> return (Right ((nm,e),us))
        _ -> notSingleLet

flattenCallTree
  :: CallTree
  -> NormalizeSession CallTree
flattenCallTree :: CallTree -> NormalizeSession CallTree
flattenCallTree c :: CallTree
c@(CLeaf (Id, Binding Term)
_) = CallTree -> NormalizeSession CallTree
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return CallTree
c
flattenCallTree (CBranch (Id
nm,(Binding Id
nm' SrcSpan
sp InlineSpec
inl IsPrim
pr Term
tm Bool
r)) [CallTree]
used) = do
  flattenedUsed   <- (CallTree -> NormalizeSession CallTree)
-> [CallTree] -> RewriteMonad NormalizeState [CallTree]
forall (t :: Type -> Type) (m :: Type -> Type) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: Type -> Type) a b.
Monad m =>
(a -> m b) -> [a] -> m [b]
mapM CallTree -> NormalizeSession CallTree
flattenCallTree [CallTree]
used
  (newUsed,il_ct) <- partitionEithers <$> mapM flattenNode flattenedUsed
  let (toInline,il_used) = unzip il_ct
      subst = Subst -> [(Id, Term)] -> Subst
extendGblSubstList (InScopeSet -> Subst
mkSubst InScopeSet
emptyInScopeSet) [(Id, Term)]
toInline
  newExpr <- case toInline of
    [] -> Term -> RewriteMonad NormalizeState Term
forall a. a -> RewriteMonad NormalizeState a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Term
tm
    [(Id, Term)]
_  -> do
      let tm1 :: Term
tm1 = HasCallStack => Doc () -> Subst -> Term -> Term
Doc () -> Subst -> Term -> Term
substTm Doc ()
"flattenCallTree.flattenExpr" Subst
subst Term
tm

      -- NB: When -fclash-debug-history is on, emit binary data holding the recorded rewrite steps
      opts <- Getting DebugOpts RewriteEnv DebugOpts
-> RewriteMonad NormalizeState DebugOpts
forall s (m :: Type -> Type) a.
MonadReader s m =>
Getting a s a -> m a
Lens.view Getting DebugOpts RewriteEnv DebugOpts
Getter RewriteEnv DebugOpts
debugOpts
      let rewriteHistFile = DebugOpts -> Maybe [Char]
dbg_historyFile DebugOpts
opts
      when (Maybe.isJust rewriteHistFile) $
        let !_ = unsafePerformIO
             $ BS.appendFile (Maybe.fromJust rewriteHistFile)
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx    = []
                 , t_name   = "INLINE"
                 , t_bndrS  = showPpr (varName nm')
                 , t_before = tm
                 , t_after  = tm1
                 }
        in pure ()
      rewriteExpr ("flattenExpr",flatten) (showPpr nm, tm1) (nm', sp)
  let allUsed = [CallTree]
newUsed [CallTree] -> [CallTree] -> [CallTree]
forall a. [a] -> [a] -> [a]
++ [[CallTree]] -> [CallTree]
forall (t :: Type -> Type) a. Foldable t => t [a] -> [a]
concat [[CallTree]]
il_used
  -- inline all components when the resulting expression after flattening
  -- is still considered "cheap". This happens often at the topEntity which
  -- wraps another functions and has some selectors and data-constructors.
  if not (isNoInline inl) && isCheapFunction newExpr
     then do
        let (toInline',allUsed') = unzip (map goCheap allUsed)
            subst' = Subst -> [(Id, Term)] -> Subst
extendGblSubstList (InScopeSet -> Subst
mkSubst InScopeSet
emptyInScopeSet)
                                        ([Maybe (Id, Term)] -> [(Id, Term)]
forall a. [Maybe a] -> [a]
Maybe.catMaybes [Maybe (Id, Term)]
toInline')
        let tm1 = HasCallStack => Doc () -> Subst -> Term -> Term
Doc () -> Subst -> Term -> Term
substTm Doc ()
"flattenCallTree.flattenCheap" Subst
subst' Term
newExpr
        newExpr' <- rewriteExpr ("flattenCheap",flatten) (showPpr nm, tm1) (nm', sp)
        return (CBranch (nm,(Binding nm' sp inl pr newExpr' r)) (concat allUsed'))
     else return (CBranch (nm,(Binding nm' sp inl pr newExpr r)) allUsed)
  where
    flatten :: NormRewrite
flatten =
      NormRewrite -> NormRewrite
forall m. Rewrite m -> Rewrite m
repeatR (NormRewrite -> NormRewrite
forall m. Rewrite m -> Rewrite m
topdownR ([Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"appProp" HasCallStack => NormRewrite
NormRewrite
appProp NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
                 [Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"bindConstantVar" HasCallStack => NormRewrite
NormRewrite
bindConstantVar NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
                 [Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"caseCon" HasCallStack => NormRewrite
NormRewrite
caseCon NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
                 ([Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"reduceConst" HasCallStack => NormRewrite
NormRewrite
reduceConst NormRewrite -> NormRewrite -> NormRewrite
forall m. Rewrite m -> Rewrite m -> Rewrite m
!-> [Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"deadcode" HasCallStack => NormRewrite
NormRewrite
deadCode) NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
                 [Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"reduceNonRepPrim" HasCallStack => NormRewrite
NormRewrite
reduceNonRepPrim NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
                 [Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"removeUnusedExpr" HasCallStack => NormRewrite
NormRewrite
removeUnusedExpr) NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
               NormRewrite -> NormRewrite
forall (m :: Type -> Type). Monad m => Transform m -> Transform m
bottomupR ([Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"flattenLet" HasCallStack => NormRewrite
NormRewrite
flattenLet)) NormRewrite -> NormRewrite -> NormRewrite
forall m. Rewrite m -> Rewrite m -> Rewrite m
!->
      NormRewrite -> NormRewrite
forall m. Rewrite m -> Rewrite m
topdownSucR ([Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"topLet" HasCallStack => NormRewrite
NormRewrite
topLet) NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
      -- See [Note] relation `collapseRHSNoops` and `inlineCleanup`
      -- Note that we do this as the very last step, after all constant propagation
      -- has been done to avoid #3036.
      NormRewrite -> NormRewrite
forall m. Rewrite m -> Rewrite m
topdownSucR ([Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"collapseRHSNoops" HasCallStack => NormRewrite
NormRewrite
collapseRHSNoops) NormRewrite -> NormRewrite -> NormRewrite
forall (m :: Type -> Type).
Monad m =>
Transform m -> Transform m -> Transform m
>->
      NormRewrite -> NormRewrite
forall m. Rewrite m -> Rewrite m
topdownSucR ([Char] -> NormRewrite -> NormRewrite
forall extra. [Char] -> Rewrite extra -> Rewrite extra
apply [Char]
"inlineCleanup" HasCallStack => NormRewrite
NormRewrite
inlineCleanup)

    goCheap :: CallTree -> (Maybe (Id, Term), [CallTree])
goCheap c :: CallTree
c@(CLeaf   (Id
nm2,(Binding Id
_ SrcSpan
_ InlineSpec
inl2 IsPrim
_ Term
e Bool
_)))
      | InlineSpec -> Bool
isNoInline InlineSpec
inl2  = (Maybe (Id, Term)
forall a. Maybe a
Nothing     ,[CallTree
c])
      | Bool
otherwise        = ((Id, Term) -> Maybe (Id, Term)
forall a. a -> Maybe a
Just (Id
nm2,Term
e),[])
    goCheap c :: CallTree
c@(CBranch (Id
nm2,(Binding Id
_ SrcSpan
_ InlineSpec
inl2 IsPrim
_ Term
e Bool
_)) [CallTree]
us)
      | InlineSpec -> Bool
isNoInline InlineSpec
inl2  = (Maybe (Id, Term)
forall a. Maybe a
Nothing, [CallTree
c])
      | Bool
otherwise        = ((Id, Term) -> Maybe (Id, Term)
forall a. a -> Maybe a
Just (Id
nm2,Term
e),[CallTree]
us)

-- | Flatten a 'CallTree' into a list of @(binder, binding)@ pairs.
--
-- Nodes are emitted in pre-order: a node's own binding comes first, followed
-- by the bindings of its children, in order. Any node whose binder is already
-- in the visited list is dropped together with its entire subtree.
--
-- The first argument is the list of binders seen so far. The first component
-- of the result is that list extended with every binder emitted here, with the
-- most recently visited binder at the head. Threading it through siblings
-- (via 'mapAccumL') ensures each binder is emitted at most once overall.
callTreeToList :: [Id] -> CallTree -> ([Id], [(Id, Binding Term)])
callTreeToList seen tree = case tree of
  CLeaf entry@(nm, _)
    | nm `elem` seen -> (seen, [])
    | otherwise      -> (nm : seen, [entry])
  CBranch entry@(nm, _) children
    | nm `elem` seen -> (seen, [])
    | otherwise      ->
        -- Mark this node as visited before descending, so the children
        -- (and their siblings) see it and cannot re-emit it.
        let (seen', childBinds) = mapAccumL callTreeToList (nm : seen) children
        in  (seen', entry : concat childBinds)