Skip to content

Commit

Permalink
Remove rsharePrimal & Co
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Mar 30, 2024
1 parent 3bd283a commit b3da1c6
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 61 deletions.
24 changes: 0 additions & 24 deletions src/HordeAd/Core/AstFreshId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
-- with @unsafePerformIO@ outside, so some of it escapes.
module HordeAd.Core.AstFreshId
( unRawHVector, rawHVector
, astRegisterADShare, astRegisterADShareS
, funToAstIOR, funToAstR, funToAstIOS, funToAstS
, fun1RToAst, fun1SToAst, fun1XToAst
, fun1DToAst, fun1HToAst, fun1LToAst
Expand All @@ -29,7 +28,6 @@ import GHC.TypeLits (KnownNat, Nat)
import System.IO.Unsafe (unsafePerformIO)

import HordeAd.Core.Ast
import HordeAd.Core.AstTools
import HordeAd.Core.HVector
import HordeAd.Core.Types
import qualified HordeAd.Util.ShapedList as ShapedList
Expand Down Expand Up @@ -71,28 +69,6 @@ unsafeGetFreshAstVarName :: IO (AstVarName f r y)
unsafeGetFreshAstVarName =
AstVarName . intToAstVarId <$> atomicAddCounter_ unsafeAstVarCounter 1

astRegisterADShare :: (GoodScalar r, KnownNat n)
=> AstRaw PrimalSpan r n -> ADShare
-> (ADShare, AstRaw PrimalSpan r n)
{-# NOINLINE astRegisterADShare #-}
astRegisterADShare !r !l | astIsSmall True (unAstRaw r) = (l, r)
astRegisterADShare (AstRaw r) l = unsafePerformIO $ do
freshId <- unsafeGetFreshAstVarId
let !l2 = insertADShare freshId (AstBindingsSimple $ DynamicRanked r) l
!r2 = AstVar (shapeAst r) $ AstVarName freshId
return (l2, AstRaw r2)

astRegisterADShareS :: (GoodScalar r, Sh.Shape sh)
=> AstRawS PrimalSpan r sh -> ADShare
-> (ADShare, AstRawS PrimalSpan r sh)
{-# NOINLINE astRegisterADShareS #-}
astRegisterADShareS !r !l | astIsSmallS True (unAstRawS r) = (l, r)
astRegisterADShareS (AstRawS r) l = unsafePerformIO $ do
freshId <- unsafeGetFreshAstVarId
let !l2 = insertADShare freshId (AstBindingsSimple $ DynamicShaped r) l
!r2 = AstVarS $ AstVarName freshId
return (l2, AstRawS r2)

funToAstIOR :: forall n m s r r2. GoodScalar r
=> ShapeInt n -> (AstRanked s r n -> AstRanked s r2 m)
-> IO ( AstVarName (AstRanked s) r n
Expand Down
1 change: 0 additions & 1 deletion src/HordeAd/Core/TensorADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ instance ADReadyBoth ranked shaped
sletInHVector (D l u u') f =
let !var2 = sshare u
in f (dDnotShared l var2 u')
dsharePrimal d l = (l, d)
dbuild1 k f =
ravelHVector $ map (f . fromIntegral) [0 .. (sNatValue k :: Int) - 1]
rrev :: (GoodScalar r, KnownNat n)
Expand Down
28 changes: 0 additions & 28 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -541,15 +541,6 @@ instance forall s. AstSpan s => HVectorTensor (AstRanked s) (AstShaped s) where
-- These and many similar bangs are necessary to ensure variable IDs
-- are generated in the expected order, resulting in nesting of lets
-- occuring in the correct order and so no scoping errors.
dsharePrimal !r !l | Just Refl <- sameAstSpan @s @PrimalSpan =
fun1DToAst (shapeAstHVector r) $ \ !vars !asts -> case vars of
[] -> (l, V.empty)
!var : _ -> -- vars are fresh, so var uniquely represent vars
( insertADShare (dynamicVarNameToAstVarId var)
(AstBindingsHVector vars r)
l
, asts )
dsharePrimal _ _ = error "dsharePrimal: wrong span"
dshare a@(AstShareHVector{}) = a
dshare a =
let shs = shapeAstHVector a
Expand Down Expand Up @@ -871,10 +862,6 @@ instance AstSpan s => RankedTensor (AstRaw s) where
_ -> (emptyADShare, u)
-- For convenience and simplicity we define this for all spans,
-- but it can only ever be used for PrimalSpan.
rsharePrimal =
case sameAstSpan @s @PrimalSpan of
Just Refl -> astRegisterADShare
_ -> error "rsharePrimal: used not at PrimalSpan"
rshare a@(AstRaw (AstShare{})) = a
rshare a | astIsSmall True (unAstRaw a) = a
rshare a = AstRaw $ fun1RToAst $ \ !var -> AstShare var (unAstRaw a)
Expand Down Expand Up @@ -1015,10 +1002,6 @@ instance AstSpan s => ShapedTensor (AstRawS s) where
AstLetADShareS l t -> (l, AstRawS t)
AstConstantS (AstLetADShareS l t) -> (l, AstRawS $ AstConstantS t)
_ -> (emptyADShare, u)
ssharePrimal =
case sameAstSpan @s @PrimalSpan of
Just Refl -> astRegisterADShareS
_ -> error "ssharePrimal: used not at PrimalSpan"
sshare a@(AstRawS (AstShareS{})) = a
sshare a | astIsSmallS True (unAstRawS a) = a
sshare a = AstRawS $ fun1SToAst $ \ !var -> AstShareS var (unAstRawS a)
Expand Down Expand Up @@ -1064,15 +1047,6 @@ instance AstSpan s => HVectorTensor (AstRaw s) (AstRawS s) where
Just Refl -> \l t ->
AstRawWrap $ unletAstHVector6 l $ unAstRawWrap t
_ -> error "dunlet: used not at PrimalSpan"
dsharePrimal !(AstRawWrap r) !l | Just Refl <- sameAstSpan @s @PrimalSpan =
fun1DToAst (shapeAstHVector r) $ \ !vars !asts -> case vars of
[] -> (l, V.empty)
!var : _ -> -- vars are fresh, so var uniquely represent vars
( insertADShare (dynamicVarNameToAstVarId var)
(AstBindingsHVector vars r)
l
, rawHVector asts )
dsharePrimal _ _ = error "dsharePrimal: wrong span"
dshare a@(AstRawWrap (AstShareHVector{})) = a
dshare (AstRawWrap a) =
let shs = shapeAstHVector a
Expand Down Expand Up @@ -1214,7 +1188,6 @@ instance AstSpan s => HVectorTensor (AstNoVectorize s) (AstNoVectorizeS s) where
AstNoVectorizeWrap
$ sletInHVector (unAstNoVectorizeS u)
(unAstNoVectorizeWrap . f . AstNoVectorizeS)
dsharePrimal = error "dsharePrimal for AstNoVectorize"
dbuild1 k f =
AstNoVectorizeWrap
$ AstBuildHVector1 k $ funToAstI (unAstNoVectorizeWrap . f . AstNoVectorize)
Expand Down Expand Up @@ -1391,7 +1364,6 @@ instance AstSpan s => HVectorTensor (AstNoSimplify s) (AstNoSimplifyS s) where
AstNoSimplifyWrap
$ astLetInHVectorFunRawS (unAstNoSimplifyS u)
(unAstNoSimplifyWrap . f . AstNoSimplifyS)
dsharePrimal = error "dsharePrimal for AstNoSimplify"
dbuild1 k f = AstNoSimplifyWrap
$ astBuildHVector1Vectorize
k (unAstNoSimplifyWrap . f . AstNoSimplify)
Expand Down
7 changes: 0 additions & 7 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,6 @@ class ( Integral (IntOf ranked), CRanked ranked Num
rletWrap _l u = u
rletUnwrap :: ranked r n -> (ADShare, ranked r n)
rletUnwrap u = (emptyADShare, u)
rsharePrimal :: (GoodScalar r, KnownNat n)
=> ranked r n -> ADShare -> (ADShare, ranked r n)
rsharePrimal r l = (l, r)
rshare :: KnownNat n => ranked r n -> ranked r n
rshare = id

Expand Down Expand Up @@ -695,9 +692,6 @@ class ( Integral (IntOf shaped), CShaped shaped Num
sletWrap _l u = u
sletUnwrap :: shaped r sh -> (ADShare, shaped r sh)
sletUnwrap u = (emptyADShare, u)
ssharePrimal :: (GoodScalar r, Sh.Shape sh)
=> shaped r sh -> ADShare -> (ADShare, shaped r sh)
ssharePrimal r l = (l, r)
sshare :: Sh.Shape sh => shaped r sh -> shaped r sh
sshare = id

Expand Down Expand Up @@ -756,7 +750,6 @@ class HVectorTensor (ranked :: RankedTensorType)
-> HVectorOf ranked
dunlet :: ADShare -> HVectorOf ranked -> HVectorOf ranked
dunlet l = assert (nullADShare l)
dsharePrimal :: HVectorOf ranked -> ADShare -> (ADShare, HVector ranked)
dshare :: HVectorOf ranked -> HVectorOf ranked
dshare = id
dbuild1 :: SNat k
Expand Down
1 change: 0 additions & 1 deletion src/HordeAd/Core/TensorConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ instance HVectorTensor (Flip OR.Array) (Flip OS.Array) where
dletHFunInHVector = (&)
rletInHVector = (&)
sletInHVector = (&)
dsharePrimal d l = (l, d)
dbuild1 k f =
ravelHVector $ map (f . fromIntegral) [0 .. (sNatValue k :: Int) - 1]
rrev :: (GoodScalar r, KnownNat n)
Expand Down

0 comments on commit b3da1c6

Please sign in to comment.