{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016-2017, Myrtle Software Ltd,
                    2017-2018, Google Inc.,
                    2021-2023, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  Transformations for specialization
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module Clash.Normalize.Transformations.Specialize
  ( appProp
  , constantSpec
  , specialize
  , nonRepSpec
  , typeSpec
  , zeroWidthSpec
  ) where

import Control.Arrow ((***), (&&&))
import Control.DeepSeq (deepseq)
import Control.Exception (throw)
import Control.Lens ((%=))
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import Control.Monad.Extra (orM)
import qualified Control.Monad.Writer as Writer (listen)
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
import qualified Data.Either as Either
import Data.Functor.Const (Const(..))
import qualified Data.Map.Strict as Map
import qualified Data.Monoid as Monoid (getAny)
import qualified Data.Set.Ordered as OSet
import qualified Data.Set.Ordered.Extra as OSet
import qualified Data.Text as Text
import qualified Data.Text.Extra as Text
import GHC.Stack (HasCallStack)

#if MIN_VERSION_ghc(9,0,0)
import GHC.Types.Basic (InlineSpec (..))
#else
import BasicTypes (InlineSpec (..))
#endif

import qualified Clash.Sized.Internal.BitVector as BV (BitVector, fromInteger#)
import qualified Clash.Sized.Internal.Index as I (Index, fromInteger#)
import qualified Clash.Sized.Internal.Signed as S (Signed, fromInteger#)
import qualified Clash.Sized.Internal.Unsigned as U (Unsigned, fromInteger#)

import Clash.Core.DataCon (DataCon(dcArgTys))
import Clash.Core.FreeVars (freeLocalVars, termFreeTyVars, typeFreeVars)
import Clash.Core.HasType
import Clash.Core.Literal (Literal(..))
import Clash.Core.Name
  (NameSort(..), Name(..), appendToName, mkUnsafeInternalName, mkUnsafeSystemName)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
import Clash.Core.Term
  ( Term(..), TickInfo, collectArgs, collectArgsTicks, mkApps, mkTmApps, mkTicks, patIds, Bind(..)
  , patVars, mkAbstraction, PrimInfo(..), WorkInfo(..), IsMultiPrim(..), PrimUnfolding(..), stripAllTicks)
import Clash.Core.TermInfo (isLocalVar, isVar, isPolyFun)
import Clash.Core.TyCon (TyConMap, tyConDataCons)
import Clash.Core.Type
  (LitTy(NumTy), Type(LitTy,VarTy), applyFunTy, splitTyConAppM, normalizeType
  , mkPolyFunTy, mkTyConApp)
import Clash.Core.TysPrim
import Clash.Core.Util (listToLets)
import Clash.Core.Var (Var(..), Id, TyVar, mkTyVar)
import Clash.Core.VarEnv
  ( InScopeSet, extendInScopeSet, extendInScopeSetList, lookupVarEnv
  , mkInScopeSet, mkVarSet, unionInScope, elemVarSet)
import qualified Clash.Data.UniqMap as UniqMap
import Clash.Debug (traceIf, traceM)
import Clash.Driver.Types (Binding(..), TransformationInfo(..), hasTransformationInfo)
import Clash.Netlist.Util (representableType)
import Clash.Rewrite.Combinators (topdownR)
import Clash.Rewrite.Types
  ( TransformContext(..), bindings, censor, curFun, customReprs, extra, tcCache
  , typeTranslator, workFreeBinders, debugOpts, topEntities, specializationLimit)
import Clash.Rewrite.Util
  ( mkBinderFor, mkDerivedName, mkFunction, mkTmBinderFor, setChanged, changed
  , normalizeTermTypes, normalizeId)
import Clash.Rewrite.WorkFree (isWorkFree)
import Clash.Normalize.Types
  ( NormRewrite, NormalizeSession, specialisationCache, specialisationHistory)
import Clash.Normalize.Util
  (constantSpecInfo, csrFoundConstant, csrNewBindings, csrNewTerm)
import Clash.Unique (Unique)
import Clash.Util (ClashException(..))

-- | Propagate arguments of application inwards; except for 'Lam' where the
-- argument becomes let-bound. 'appProp' tries to propagate as many arguments
-- as possible, down as many levels as possible; and should be called in a
-- top-down traversal.
--
-- The idea is that this reduces the number of traversals, which hopefully leads
-- to shorter compile times.
--
-- Note [AppProp no shadowing]
--
-- Case 1.
--
-- Imagine:
--
-- @
-- (case x of
--    D a b -> h a) (f x y)
-- @
--
-- rewriting this to:
--
-- @
-- let b = f x y
-- in  case x of
--       D a b -> h a b
-- @
--
-- is very bad because @b@ in @h a b@ is now bound by the pattern instead of the
-- newly introduced let-binding.
--
-- instead we must deshadow w.r.t. the new variable and rewrite to:
--
-- @
-- let b = f x y
-- in  case x of
--       D a b1 -> h a b
-- @
--
-- Case 2.
--
-- Imagine
--
-- @
-- (\\x -> e) u
-- @
--
-- where @u@ has a free variable named @x@, rewriting this to:
--
-- @
-- let x = u
-- in  e
-- @
--
-- would be very bad, because the let-binding suddenly captures the free
-- variable in @u@. To prevent this from happening we over-approximate and check
-- whether @x@ is in the current InScopeSet, and deshadow if that's the case,
-- i.e. we then rewrite to:
--
-- @
-- let x1 = u
-- in  e [x:=x1]
-- @
--
-- Case 3.
--
-- The same for:
--
-- @
-- (let x = w in e) u
-- @
--
-- where @u@ again has a free variable @x@, rewriting this to:
--
-- @
-- let x = w in (e u)
-- @
--
-- would be bad because the let-binding now captures the free variable in @u@.
--
-- To prevent this from happening, we unconditionally deshadow the function part
-- of the application w.r.t. the free variables in the argument part of the
-- application. It is okay to over-approximate in this case and deshadow w.r.t.
-- the current InScopeSet.
appProp :: HasCallStack => NormRewrite
appProp ctx@(TransformContext is _) = \case
  e@App {}   -> propagate e
  e@TyApp {} -> propagate e
  e          -> return e
 where
  -- Split an application into its head, arguments and ticks, and push the
  -- arguments into the head. The head is first deshadowed w.r.t. the current
  -- InScopeSet (see Note [AppProp no shadowing], cases 2 and 3). When 'go'
  -- did not report a change, the original term is returned, so that the
  -- deshadowing on its own never yields a new term.
  propagate :: Term -> NormalizeSession Term
  propagate e = do
    let (fun, args, ticks) = collectArgsTicks e
    (eN, hasChanged) <- Writer.listen (go is (deShadowTerm is fun) args ticks)
    return (if Monoid.getAny hasChanged then eN else e)

  go :: InScopeSet -> Term -> [Either Term Type] -> [TickInfo] -> NormalizeSession Term
  -- The head is itself an application: flatten it, so that all arguments and
  -- ticks end up in a single list.
  go is0 (collectArgsTicks -> (fun,args0@(_:_),ticks0)) args1 ticks1 =
    go is0 fun (args0 ++ args1) (ticks0 ++ ticks1)

  -- Beta reduction. A variable or work-free argument is substituted directly;
  -- any other argument is let-bound (Note [AppProp no shadowing], case 2).
  go is0 (Lam v e) (Left arg:args) ticks = do
    setChanged
    bndrs <- Lens.use bindings
    orM [pure (isVar arg), isWorkFree workFreeBinders bndrs arg] >>= \case
      True ->
        let subst = extendIdSubst (mkSubst is0) v arg in
        (`mkTicks` ticks) <$> go is0 (substTm "appProp.AppLam" subst e) args []
      False ->
        let is1 = extendInScopeSet is0 v in
        Let (NonRec v arg) <$> go is1 (deShadowTerm is1 e) args ticks

  -- Push the arguments into the body of a let
  -- (Note [AppProp no shadowing], case 3).
  go is0 (Let (NonRec i x) e) args@(_:_) ticks = do
    setChanged
    let is1 = extendInScopeSet is0 i
    -- XXX: binding should already be deshadowed w.r.t. 'is0'
    Let (NonRec i x) <$> go is1 e args ticks

  go is0 (Let (Rec vs) e) args@(_:_) ticks = do
    setChanged
    let vbs = map fst vs
        is1 = extendInScopeSetList is0 vbs
    -- XXX: 'vs' should already be deshadowed w.r.t. 'is0'
    Let (Rec vs) <$> go is1 e args ticks

  -- Type-level beta reduction.
  go is0 (TyLam tv e) (Right t:args) ticks = do
    setChanged
    let subst = extendTvSubst (mkSubst is0) tv t
    (`mkTicks` ticks) <$> go is0 (substTm "appProp.TyAppTyLam" subst e) args []

  -- Push the arguments into every alternative. Term arguments that are
  -- neither variables nor work-free are let-bound outside the case (see
  -- 'goCaseArg'), so only a variable reference is copied into each alternative
  -- (Note [AppProp no shadowing], case 1).
  go is0 (Case scrut ty0 alts) args0@(_:_) ticks = do
    setChanged
    let isA1 = unionInScope
                 is0
                 ((mkInScopeSet . mkVarSet . concatMap (patVars . fst)) alts)
    (ty1,vs,args1) <- goCaseArg isA1 ty0 [] args0
    case vs of
      [] -> (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is0 args1) alts
      _  -> do
        let vbs   = map fst vs
            is1   = extendInScopeSetList is0 vbs
            alts1 = map (deShadowAlt is1) alts
        -- TODO I should have a mkNonRecLets :: [LetBinding] -> Term -> Term
        -- function which makes a chain of non-recursive let expressions without
        -- needing to first take the SCCs of all the binders.
        listToLets vs . (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is1 args1) alts1

  -- Ticks are collected and re-attached around the rebuilt head.
  go is0 (Tick sp e) args ticks = do
    setChanged
    go is0 e args (sp:ticks)

  -- Nothing to propagate into: rebuild the application.
  go _ fun args ticks = return (mkApps (mkTicks fun ticks) args)

  -- Push the (already let-bound, where necessary) arguments into a single
  -- case alternative, with the pattern's binders added to the InScopeSet.
  goAlt is0 args0 (p,e) = do
    let (tvs,ids) = patIds p
        is1       = extendInScopeSetList (extendInScopeSetList is0 tvs) ids
    (p,) <$> go is1 e args0 []

  -- Walk the arguments of a case-application. Returns the type of the case
  -- expression after the arguments have been applied to its alternatives, the
  -- new let-bindings (fresh binders derived from "app_arg"), and the argument
  -- list in which every let-bound argument is replaced by its binder.
  goCaseArg isA ty0 ls0 (Right t:args0) = do
    tcm <- Lens.view tcCache
    let ty1 = piResultTy tcm ty0 t
    (ty2,ls1,args1) <- goCaseArg isA ty1 ls0 args0
    return (ty2,ls1,Right t:args1)

  goCaseArg isA0 ty0 ls0 (Left arg:args0) = do
    tcm <- Lens.view tcCache
    bndrs <- Lens.use bindings
    let argTy = inferCoreTypeOf tcm arg
        ty1   = applyFunTy tcm ty0 argTy
    orM [pure (isVar arg), isWorkFree workFreeBinders bndrs arg] >>= \case
      True -> do
        (ty2,ls1,args1) <- goCaseArg isA0 ty1 ls0 args0
        return (ty2,ls1,Left arg:args1)
      False -> do
        boundArg <- mkTmBinderFor isA0 tcm (mkDerivedName ctx "app_arg") arg
        let isA1 = extendInScopeSet isA0 boundArg
        (ty2,ls1,args1) <- goCaseArg isA1 ty1 ls0 args0
        return (ty2,(boundArg,arg):ls1,Left (Var boundArg):args1)

  goCaseArg _ ty ls [] = return (ty,ls,[])
{-# SCC appProp #-}

-- | Specialize functions on arguments which are constant, except when they
-- are clock, reset generators.
--
-- Only fires on an application @f a1 .. an e2@ whose head @f@ is a variable,
-- whose earlier arguments are all term arguments (no type arguments), and whose
-- last argument @e2@ has no free type variables. 'constantSpecInfo' then
-- decides how much of @e2@ is constant:
--
--   * all of it ('csrNewBindings' is empty): specialize on @e2@ directly;
--   * only parts of it: specialize on the rewritten argument ('csrNewTerm'),
--     and wrap the result in the bindings from 'csrNewBindings'. These lets
--     are only kept when specialization actually happened;
--   * none of it: the term is returned unchanged.
constantSpec :: HasCallStack => NormRewrite
constantSpec ctx@(TransformContext is0 tfCtx) e@(App e1 e2)
  | (Var {}, args) <- collectArgs e1
  , (_, []) <- Either.partitionEithers args
  , null (Lens.toListOf termFreeTyVars e2)
  = do
      specInfo <- constantSpecInfo ctx e2
      if not (csrFoundConstant specInfo)
        -- e2 has no constant parts
        then return e
        else case csrNewBindings specInfo of
          -- Whole of e2 is constant
          [] -> specialize ctx e
          -- Parts of e2 are constant
          newBindings -> do
            let is1 = extendInScopeSetList is0 (map fst newBindings)
            (body, isSpec) <- Writer.listen $
              specialize (TransformContext is1 tfCtx) (App e1 (csrNewTerm specInfo))
            -- Only introduce the new let-bindings when specialization fired;
            -- otherwise keep the original term.
            if Monoid.getAny isSpec
              then changed (listToLets newBindings body)
              else return e
constantSpec _ e = return e
{-# SCC constantSpec #-}

-- | Specialize an application on its argument.
--
-- Only the outermost application node is inspected. For @f a1 .. an x@, the
-- function part @f a1 .. an@ is split with 'collectArgsTicks', and @x@ (a term
-- argument or a type argument) is the argument to specialize on. Any term that
-- is not an application is returned unchanged.
specialize :: NormRewrite
specialize ctx e = case e of
  TyApp e1 ty -> specialize' ctx e (collectArgsTicks e1) (Right ty)
  App e1 e2   -> specialize' ctx e (collectArgsTicks e1) (Left e2)
  _           -> return e

{-
Note [ticks and specialization]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
As Clash now distinguishes between ticks in expressions when comparing for
alpha equality, this has a knock-on effect when accessing the specialization
cache. Consider these applications which differ only by ticks:

    f[GlobalId] (\x -> ... x[LocalId])
    f[GlobalId] <tick>(\x -> ... x[LocalId])
    f[GlobalId] (\x -> ... <tick>x[LocalId])

If one of these had been specialized, the other two would hit that term in the
specialization cache, saving Clash from having to re-do work which is in effect
the same. To preserve this behaviour, we use 'stripAllTicks' on the keys for
the specialization cache.

TODO While this preserves the old behaviour, the old behaviour is likely not
quite what we want. Using a value from the specialization cache may change the
ticks present, which can affect naming / debugging information in generated HDL.
We may also not want to look at ticks, as then the specialization cache will
miss on virtually every lookup which could add to normalization time.
-}

-- | Given two 'InlineSpec's, return the \"strongest\" one. I.e., the one that's
-- closest to @NoInline@ (or @Opaque@ for newer GHCs). When both are equally
-- strong, the first argument is returned.
preferNoInline :: InlineSpec -> InlineSpec -> InlineSpec
preferNoInline is0 is1
  | strength is1 > strength is0 = is1
  | otherwise                   = is0
 where
  -- Rank of an 'InlineSpec': a higher number is closer to "never inline".
  -- Having no user pragma at all ranks below an explicit INLINE.
  strength :: InlineSpec -> Int
  strength = \case
#if MIN_VERSION_ghc(9,2,0)
    NoUserInlinePrag {} -> -1
#else
    NoUserInline {} -> -1
#endif
    Inline {} -> 0
    Inlinable {} -> 1
    NoInline {} -> 2
#if MIN_VERSION_ghc(9,4,0)
    Opaque {} -> 3
#endif

-- | Specialize an application on its argument
--
-- * When the head of the application is a variable @f@ that is not a top
--   entity and has a binding, the application is replaced by a call to a copy
--   of @f@ that has the argument built in. The free variables of the argument
--   are passed to that copy instead. Copies are cached per
--   @(f, number of preceding arguments, argument)@ and reused on later hits.
-- * When @f@ is a top entity, term arguments are left alone. Type arguments
--   are dropped, and the type of @f@ is adjusted to match.
-- * When the head is not a variable and the argument is a term (used by
--   @nonRepANF@), the argument is lifted into a global function instead.
--   An existing alpha-equivalent binding is reused if there is one.
-- * Every other combination returns the original term unchanged.
specialize'
  :: TransformContext
  -- ^ Transformation context
  -> Term
  -- ^ Original term
  -> (Term, [Either Term Type], [TickInfo])
  -- ^ Function part of the term, split into root and applied arguments
  -> Either Term Type
  -- ^ Argument to specialize on
  -> NormalizeSession Term
specialize' (TransformContext is0 _) e (Var f, args, ticks) specArgIn = do
  opts <- Lens.view debugOpts
  tcm <- Lens.view tcCache

  -- Don't specialize TopEntities
  topEnts <- Lens.view topEntities
  if f `elemVarSet` topEnts
  then do
    case specArgIn of
      Left _ -> do
        traceM ("Not specializing TopEntity: " ++ showPpr (varName f))
        return e
      Right tyArg ->
        traceIf (hasTransformationInfo AppliedTerm opts) ("Dropping type application on TopEntity: " ++ showPpr (varName f) ++ "\ntype:\n" ++ showPpr tyArg) $
        -- TopEntities aren't allowed to be semantically polymorphic.
        -- But using type equality constraints they may be syntactically polymorphic.
        -- > topEntity :: forall dom . (dom ~ "System") => Signal dom Bool -> Signal dom Bool
        -- The TyLam's in the body will have been removed by 'Clash.Normalize.Util.substWithTyEq'.
        -- So we drop the TyApp ("specializing" on it) and change the varType to match.
        let newVarTy = piResultTy tcm (coreTypeOf f) tyArg
        in  changed (mkApps (mkTicks (Var f{varType = newVarTy}) ticks) args)
  else do -- NondecreasingIndentation

  -- Normalize the types occurring in the argument before it is used to build
  -- the cache key and the specialized body.
  let specArg = bimap (normalizeTermTypes tcm) (normalizeType tcm) specArgIn
      -- Create binders and variable references for free variables in 'specArg'
      -- (specBndrsIn,specVars) :: ([Either Id TyVar], [Either Term Type])
      (specBndrsIn,specVars) = specArgBndrsAndVars specArg
      -- Number of arguments applied before the specialized one. This is part
      -- of the cache key.
      argLen  = length args
      specBndrs :: [Either Id TyVar]
      specBndrs = map (Lens.over Lens._Left (normalizeId tcm)) specBndrsIn

      -- Cache key for the argument. A term argument is abstracted over its
      -- free variables and stripped of all ticks. A type argument is used
      -- as-is.
      -- See Note [ticks and specialization]
      specAbs :: Either Term Type
      specAbs = either (Left . stripAllTicks . (`mkAbstraction` specBndrs)) Right specArg
  -- Determine if 'f' has already been specialized on (a type-normalized) 'specArg'
  specM <- Map.lookup (f,argLen,specAbs) <$> Lens.use (extra.specialisationCache)
  case specM of
    -- Use previously specialized function
    Just f' ->
      traceIf (hasTransformationInfo AppliedTerm opts)
        ("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++
          (either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) $
        changed $ mkApps (mkTicks (Var f') ticks) (args ++ specVars)
    -- Create new specialized function
    Nothing -> do
      -- Determine if we can specialize f: only functions with a binding can be
      -- specialized. Anything else is returned unchanged (see below).
      bodyMaybe <- fmap (UniqMap.lookup f) $ Lens.use bindings
      case bodyMaybe of
        Just (Binding _ sp inl _ bodyTm _) -> do
          -- Determine if we see a sequence of specializations on a growing argument
          specHistM <- UniqMap.lookup f <$> Lens.use (extra.specialisationHistory)
          specLim   <- Lens.view specializationLimit
          -- Abort once 'f' has been specialized more than 'specLim' times.
          if maybe False (> specLim) specHistM
            then throw (ClashException
                        sp
                        (unlines [ "Hit specialization limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n"
                                 , "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
                                 , "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n"
                                 , "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg
                                 , "Run with '-fclash-spec-limit=N' to increase the specialization limit to N."
                                 ])
                        Nothing)
            else do
              -- Preferred names for the binders of the existing arguments are
              -- taken from the leading binders of 'f's body. When those run
              -- out, the fresh names pTS0, pTS1, ... are used.
              let existingNames = collectBndrsMinusApps bodyTm
                  newNames      = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n
                                  | n <- [(0::Unique)..]
                                  ]
              -- Make new binders for existing arguments
              (boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $
                                     Monad.zipWithM
                                       (mkBinderFor is0 tcm)
                                       (existingNames ++ newNames)
                                       args
              -- Determine name the resulting specialized function, and the
              -- form of the specialized-on argument
              (fId,inl',specArg') <- case specArg of
                Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a
                    then do
                      -- In case we are specializing on an argument that is a
                      -- global function then we use that function's name as the
                      -- name of the specialized higher-order function.
                      -- Additionally, we will return the body of the global
                      -- function, instead of a variable reference to the
                      -- global function.
                      --
                      -- This will turn things like @mealy g k@ into a new
                      -- binding @g'@ where both the body of @mealy@ and @g@
                      -- are inlined, meaning the state-transition-function
                      -- and the memory element will be in a single function.
                      --
                      -- Finally, we must make sure we do not inline the bodies
                      -- of functions with a Synthesize annotation, as that would
                      -- duplicate Clash compiler work. See also issue #3024
                      gTmM <- fmap (UniqMap.lookup g) $ Lens.use bindings
                      let gBody = if g `elemVarSet` topEnts then
                                    Nothing
                                  else
                                    fmap bindingTerm gTmM
                      return
                        ( g
                        , preferNoInline inl (maybe noUserInline bindingSpec gTmM)
                        , maybe specArg (Left . (`mkApps` gArgs)) gBody
                        )
                    else return (f,inl,specArg)
                _ -> return (f,inl,specArg)
              -- Create specialized functions: 'f's body applied to the
              -- binders of the existing arguments and to the specialized-on
              -- argument, abstracted over those binders and the free
              -- variables of the argument.
              let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
              newf <- mkFunction (varName fId) sp inl' newBody
              -- Remember specialization
              (extra.specialisationHistory) %= UniqMap.insertWith (+) f 1
              (extra.specialisationCache)  %= Map.insert (f,argLen,specAbs) newf
              -- use specialized function
              let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars)
              newf `deepseq` changed newExpr
        Nothing -> return e
  where
    noUserInline :: InlineSpec
#if MIN_VERSION_ghc(9,2,0)
    noUserInline = NoUserInlinePrag
#else
    noUserInline = NoUserInline
#endif

    -- Names of the leading (type) lambda binders of a term, outermost first.
    -- When the term is an application of a lambda, the outermost binder of
    -- that lambda is dropped (the application consumes it) and the remaining
    -- ones are kept.
    collectBndrsMinusApps :: Term -> [Name a]
    collectBndrsMinusApps = reverse . go []
      where
        go bs (Lam v e')    = go (coerce (varName v):bs)  e'
        go bs (TyLam tv e') = go (coerce (varName tv):bs) e'
        -- 'init' is safe here: the empty case is matched first.
        go bs (App e' _) = case go [] e' of
          []  -> bs
          bs' -> init bs' ++ bs
        go bs (TyApp e' _) = case go [] e' of
          []  -> bs
          bs' -> init bs' ++ bs
        go bs _ = bs

-- Specializing non Var's is used by nonRepANF
specialize' _ctx _ (appE,args,ticks) (Left specArg) = do
  -- Create binders and variable references for free variables in 'specArg'
  let (specBndrs,specVars) = specArgBndrsAndVars (Left specArg)
  -- Create specialized function
      newBody = mkAbstraction specArg specBndrs
  -- See if there's an existing binder that's alpha-equivalent to the
  -- specialized function
  existing <- UniqMap.filter ((`aeqTerm` newBody) . bindingTerm) <$> Lens.use bindings
  -- Create a new function if an alpha-equivalent binder doesn't exist
  newf <- case UniqMap.elems existing of
    [] -> do (cf,sp) <- Lens.use curFun
#if MIN_VERSION_ghc(9,2,0)
             mkFunction (appendToName (varName cf) "_specF") sp NoUserInlinePrag newBody
#else
             mkFunction (appendToName (varName cf) "_specF") sp NoUserInline newBody
#endif
    (b:_) -> return (bindingId b)
  -- Create specialized argument
  let newArg  = Left $ mkApps (Var newf) specVars
  -- Use specialized argument
  let newExpr = mkApps (mkTicks appE ticks) (args ++ [newArg])
  changed newExpr

specialize' _ e _ _ = return e

-- Note [Collect free-variables in an insertion-ordered set]
--
-- In order for the specialization cache to work, 'specArgBndrsAndVars' should
-- yield (alpha equivalent) results for the same specialization. While collecting
-- free variables in a given term or type it should therefore keep a stable
-- ordering based on the order in which it finds free vars. To see why,
-- consider the following two pseudo-code calls to 'specialize':
--
--     specialize {f ('a', x[123], y[456])}
--     specialize {f ('b', x[456], y[123])}
--
-- Collecting the binders in a VarSet would yield the following (unique ordered)
-- sets:
--
--     {x[123], y[456]}
--     {y[123], x[456]}
--
-- ..thereby breaking specialization caching. We therefore track them in
-- insertion-ordered sets, yielding:
--
--     {x[123], y[456]}
--     {x[456], y[123]}
--

-- | Create binders and variable references for free variables in 'specArg'
--
-- Returns a pair of lists:
--
-- 1. binders for the free variables of the argument, type variables before
--    term variables;
-- 2. references to those same variables, in the same order, to be applied to
--    the specialized function in place of the argument.
specArgBndrsAndVars
  :: Either Term Type
  -> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars specArg =
  -- See Note [Collect free-variables in an insertion-ordered set]
  let -- Put a single free variable into the (type variables, term variables)
      -- pair of insertion-ordered sets. It is wrapped in 'Const' so that it
      -- can serve as the callback of 'Lens.foldMapOf'.
      unitFV :: Var a -> Const (OSet.OLSet TyVar, OSet.OLSet Id) (Var a)
      unitFV v@(Id {}) = Const (mempty, coerce (OSet.singleton (coerce v)))
      unitFV v@(TyVar {}) = Const (coerce (OSet.singleton (coerce v)), mempty)

      (specFTVs,specFVs) = case specArg of
        Left tm  -> (OSet.toListL *** OSet.toListL) . getConst $
                    Lens.foldMapOf freeLocalVars unitFV tm
        -- NB: free type variables of a type argument are gathered in a
        -- 'UniqMap', i.e. ordered by unique rather than by insertion order.
        Right ty -> ( UniqMap.elems
                        (Lens.foldMapOf typeFreeVars (\x -> UniqMap.singletonUnique (coerce x)) ty)
                    , [] :: [Id])

      specTyBndrs = map Right specFTVs
      specTmBndrs = map Left  specFVs

      specTyVars  = map (Right . VarTy) specFTVs
      specTmVars  = map (Left . Var) specFVs

  in  (specTyBndrs ++ specTmBndrs,specTyVars ++ specTmVars)

-- | Specialize functions on their non-representable argument
--
-- The rule fires on an application @e1 e2@ when all of the following hold:
--
-- * @e1@ is a variable applied only to term arguments;
-- * @e2@ has no free type variables;
-- * @e2@ is not a local variable;
-- * the type of @e2@ is not representable.
--
-- In that case the application is specialized on @e2@. If @e2@ is headed by a
-- compiler-generated (internal) function, that function is inlined first.
-- Every other term is returned unchanged.
nonRepSpec :: HasCallStack => NormRewrite
nonRepSpec ctx e@(App e1 e2)
  | (Var {}, args) <- collectArgs e1
  , (_, [])     <- Either.partitionEithers args
  , null $ Lens.toListOf termFreeTyVars e2
  = do tcm <- Lens.view tcCache
       let e2Ty     = inferCoreTypeOf tcm e2
           localVar = isLocalVar e2
       -- 'tcm' was already read from the environment above; reuse it.
       nonRepE2 <- not <$> (representableType <$> Lens.view typeTranslator
                                              <*> Lens.view customReprs
                                              <*> pure False
                                              <*> pure tcm
                                              <*> pure e2Ty)
       -- Test the cheap 'localVar' check first, so the representability
       -- check is only evaluated when it matters.
       if not localVar && nonRepE2
         then do
           e2' <- inlineInternalSpecialisationArgument e2
           specialize ctx (App e1 e2')
         else return e
  where
    -- | If the argument on which we're specializing is an internal function,
    -- one created by the compiler, then inline that function before we
    -- specialize.
    --
    -- We need to do this because otherwise the specialization history won't
    -- recognize the new specialization argument as something the function has
    -- already been specialized on
    inlineInternalSpecialisationArgument
      :: Term
      -> NormalizeSession Term
    inlineInternalSpecialisationArgument app
      | (Var f,fArgs,ticks) <- collectArgsTicks app
      = do
        fTmM <- lookupVarEnv f <$> Lens.use bindings
        case fTmM of
          Just b
            | nameSort (varName (bindingId b)) == Internal
            -- Inline the body, then run 'appProp' over the result.
            -- 'censor (const mempty)' discards the 'changed' flag that this
            -- inner rewrite writes.
            -> censor (const mempty)
                      (topdownR appProp ctx
                        (mkApps (mkTicks (bindingTerm b) ticks) fArgs))
          _ -> return app
      | otherwise = return app

nonRepSpec _ e = return e
{-# SCC nonRepSpec #-}

-- | Specialize functions on their type.
--
-- Fires on a type application @e1 ty@ when:
--
--   * @e1@ is a variable applied only to term arguments (no earlier type
--     arguments), and
--   * @ty@ has no free type variables.
--
-- In that case the whole application is handed to 'specialize'; every other
-- term is returned unchanged.
typeSpec :: HasCallStack => NormRewrite
typeSpec ctx e@(TyApp e1 ty)
  | (Var {}, args) <- collectArgs e1
    -- 'nullOf' stops at the first free type variable instead of building a
    -- list of all of them.
  , Lens.nullOf typeFreeVars ty
    -- Every argument collected so far must be a term argument.
  , all Either.isLeft args
  = specialize ctx e

typeSpec _ e = return e
{-# SCC typeSpec #-}

-- | Specialize functions on arguments which are zero-width. These arguments
-- can have only one possible value, and specializing on this value may create
-- additional opportunities for transformations to fire.
--
-- As we can't remove zero-width arguments (as transformations cannot change
-- the type of a term), we instead substitute all occurrences of a lambda-bound
-- variable with a zero-width type with the only value of that type.
--
-- The rewrite only reports a change when the binder actually occurs in the
-- lambda body. After a substitution the binder stays (its type is unchanged)
-- but no longer occurs, so a later visit to the same lambda is a no-op
-- rather than a spurious "changed".
zeroWidthSpec :: HasCallStack => NormRewrite
zeroWidthSpec (TransformContext is _) e@(Lam i x0)
  | Lens.elemOf freeLocalVars i x0
  = do
    tcm <- Lens.view tcCache
    -- Normalize the binder's type (via 'normalizeType') before asking whether
    -- it is zero-width.
    let bndrTy = normalizeType tcm (coreTypeOf i)

    case zeroWidthTypeElem tcm bndrTy of
      Just tm ->
        -- Keep the binder so the lambda's type is preserved; only its
        -- occurrences in the body are replaced by the unique value.
        let subst = extendIdSubst (mkSubst is) i tm
            x1 = substTm "zeroWidthSpec" subst x0
         in changed (Lam i x1)

      Nothing ->
        return e

zeroWidthSpec _ e = return e
{-# SCC zeroWidthSpec #-}

-- | Get the only element of a type, if it is zero-width.
--
-- The Clash sized-number types are handled directly at their zero-width
-- sizes: @BitVector 0@, @Index 1@, @Signed 0@ and @Unsigned 0@. Any other
-- type must have exactly one data constructor whose fields all (recursively)
-- have a zero-width element. Returns 'Nothing' when no element can be built
-- this way.
zeroWidthTypeElem :: TyConMap -> Type -> Maybe Term
zeroWidthTypeElem tcm ty = do
  (tcNm, args) <- splitTyConAppM ty

  if | nameOcc tcNm == Text.showt ''BV.BitVector
     , [LitTy (NumTy 0)] <- args
     -> return (bitVectorZW tcNm args)

     | nameOcc tcNm == Text.showt ''I.Index
     , [LitTy (NumTy 1)] <- args
     -> return (indexZW tcNm args)

     | nameOcc tcNm == Text.showt ''S.Signed
     , [LitTy (NumTy 0)] <- args
     -> return (signedZW tcNm args)

     | nameOcc tcNm == Text.showt ''U.Unsigned
     , [LitTy (NumTy 0)] <- args
     -> return (unsignedZW tcNm args)

     -- Any other zero-width type should only have a single data constructor
     -- where all fields are also zero-width.
     | otherwise
     -> do
       tc <- UniqMap.lookup tcNm tcm

       case tyConDataCons tc of
         [dc] -> do
           zwArgs <- traverse (zeroWidthTypeElem tcm) (dcArgTys dc)
           -- Apply the constructor to the type's own arguments before its
           -- fields. For a polymorphic constructor (e.g. @Proxy@) a bare
           -- 'Data' would otherwise have the quantified type instead of @ty@.
           return (mkApps (Data dc) (fmap Right args <> fmap Left zwArgs))

         _ ->
           Nothing
 where
  -- Type variable used as the size/bound in the primitive types below.
  nNm = mkUnsafeSystemName "n" 0
  nTv = mkTyVar typeNatKind nNm

  -- Primitive info for 'BV.fromInteger#' returning the given BitVector type
  -- constructor applied to @n@.
  mkBitVector tcNm =
    let prTy = mkPolyFunTy (mkTyConApp tcNm [VarTy nTv])
                 [ Left nTv
                 , Right naturalPrimTy
                 , Right naturalPrimTy
                 , Right integerPrimTy
                 ]
     in PrimInfo (Text.showt 'BV.fromInteger#) prTy WorkNever SingleResult NoUnfolding

  -- The zero-width BitVector: the primitive applied to the type arguments,
  -- two Natural literals 0 and the Integer literal 0.
  bitVectorZW tcNm tyArgs =
    mkApps (Prim (mkBitVector tcNm)) $
      fmap Right tyArgs <>
        [ Left (Literal (NaturalLiteral 0))
        , Left (Literal (NaturalLiteral 0))
        , Left (Literal (IntegerLiteral 0))
        ]

  -- Primitive info for a sized-number @fromInteger#@ called @primNm@,
  -- returning the given type constructor applied to @n@.
  mkSizedNum tcNm primNm =
    let prTy = mkPolyFunTy (mkTyConApp tcNm [VarTy nTv])
                 [Left nTv, Right naturalPrimTy, Right integerPrimTy]
     in PrimInfo primNm prTy WorkNever SingleResult NoUnfolding

  -- Shared shape of the Index/Signed/Unsigned zero-width values: the
  -- primitive applied to the type arguments, the Natural literal @natArg@
  -- and the Integer literal 0.
  sizedNumZW primNm natArg tcNm tyArgs =
    mkApps (Prim (mkSizedNum tcNm primNm)) $
      fmap Right tyArgs <>
        [ Left (Literal (NaturalLiteral natArg))
        , Left (Literal (IntegerLiteral 0))
        ]

  -- Index 1 takes its bound (1) as the Natural argument.
  indexZW = sizedNumZW (Text.showt 'I.fromInteger#) 1
  signedZW = sizedNumZW (Text.showt 'S.fromInteger#) 0
  unsignedZW = sizedNumZW (Text.showt 'U.fromInteger#) 0