Skip to content

Commit

Permalink
Rename rfromS to match the Ranked naming prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 9, 2023
1 parent d534fa6 commit a2645f9
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 27 deletions.
4 changes: 2 additions & 2 deletions bench/common/BenchProdTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ benchProd ~(_l, list, vec) =
let f :: DynamicExists OD.Array -> Flip OR.Array Double 0
f = (\(DynamicExists @r2 d) ->
gcastWith (unsafeCoerce Refl :: r2 :~: Double) $
tfromD d)
rfromD d)
in nf (V.map f . fst
. crevOnDomains @Double Nothing rankedVecDProd)
(V.map (DynamicExists . dfromR) vec)
Expand Down Expand Up @@ -133,7 +133,7 @@ rankedVecDProd :: forall r ranked.
=> Domains (DynamicOf ranked) -> ranked r 0
rankedVecDProd = V.foldl' (\acc (DynamicExists @r2 d) ->
gcastWith (unsafeCoerce Refl :: r2 :~: r) $
tfromD d * acc) 0
rfromD d * acc) 0

rankedNoShareListProd :: GoodScalar r
=> [ADVal (Flip OR.Array) r 0]
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/AstEnv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ extendEnvDR :: forall ranked shaped s. ConvertTensor ranked shaped
-> AstEnv ranked shaped
extendEnvDR (AstDynamicVarName @_ @sh @r @y var, DynamicExists @r2 d) !env =
-- We don't need to manually pick a specialization for the existential
-- variable r2, because tfromD does not depend on r2.
-- variable r2, because rfromD does not depend on r2.
case testEquality (typeRep @r) (typeRep @r2) of
Just Refl ->
let n = length $ Sh.shapeT @sh
in case someNatVal $ toInteger n of
Just (SomeNat @n _) -> gcastWith (unsafeCoerce Refl :: n :~: y) $
extendEnvR var (tfromD d) env
extendEnvR var (rfromD d) env
Nothing -> error "extendEnvDR: impossible someNatVal error"
_ -> error "extendEnvDR: type mismatch"

Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ interpretAst !env = \case
AstFromIntegral v ->
rfromIntegral $ rconstant $ interpretAstPrimalRuntimeSpecialized env v
AstConst a -> rconst a
AstSToR v -> tfromS $ interpretAstS env v
AstSToR v -> rfromS $ interpretAstS env v
AstConstant a -> rconstant $ interpretAstPrimal env a
AstPrimalPart a -> interpretAst env a
-- This is correct, because @s@ must be @PrimalSpan@ and so @ranked@ must
Expand Down
12 changes: 6 additions & 6 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ buildFinMaps s0 deltaDt =
case sameShape @sh @sh2 of
Just Refl -> evalS s c d
_ -> error "buildFinMaps: different shapes in RToS(SToR)"
RToS d -> evalR s (tfromS c) d
RToS d -> evalR s (rfromS c) d

{-
-- The general case is given as the last one below,
Expand Down Expand Up @@ -1048,7 +1048,7 @@ buildFinMaps s0 deltaDt =
-> DeltaD (Clown (DynamicOf ranked)) ranked shaped r y
-> EvalState ranked shaped
evalD s !c = \case
RToD d -> evalR s (tfromD c) d
RToD d -> evalR s (rfromD c) d
SToD d -> evalS s (sfromD c) d

evalFromnMap :: EvalState ranked shaped -> EvalState ranked shaped
Expand All @@ -1060,7 +1060,7 @@ buildFinMaps s0 deltaDt =
DeltaBindingR @_ @r1 d -> case dMap EM.! n of
DynamicExists @r2 e ->
case testEquality (typeRep @r1) (typeRep @r2) of
Just Refl -> let c = tfromD e
Just Refl -> let c = rfromD e
in evalRRuntimeSpecialized s2 c d
_ -> error "buildFinMaps: type mismatch"
DeltaBindingS @_ @r1 d -> case dMap EM.! n of
Expand Down Expand Up @@ -1125,7 +1125,7 @@ buildDerivative dimR deltaDt params = do
then case params V.! i of
DynamicExists @r2 e ->
case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> return $! tfromD @ranked @shaped @r e
Just Refl -> return $! rfromD @ranked @shaped @r e
_ -> error "buildDerivative: type mismatch"
else error "buildDerivative': wrong index for an input"
ScaleR k d -> (* k) <$> evalR d
Expand All @@ -1138,7 +1138,7 @@ buildDerivative dimR deltaDt params = do
case dm EM.! n of
DynamicExists @r2 t ->
case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> return $! tfromD @ranked @shaped @r t
Just Refl -> return $! rfromD @ranked @shaped @r t
_ -> error "buildDerivative: type mismatch"
Nothing -> do
cRaw <- evalR d
Expand Down Expand Up @@ -1192,7 +1192,7 @@ buildDerivative dimR deltaDt params = do
Just Refl -> evalR (SToR d)
_ -> error "buildDerivative: different ranks in DToR(SToD)"
SToR (RToS d) -> evalR d -- no information lost, so no checks
SToR d -> tfromS <$> evalS d
SToR d -> rfromS <$> evalS d

evalS
:: forall sh r. (Sh.Shape sh, GoodScalar r)
Expand Down
8 changes: 4 additions & 4 deletions src/HordeAd/Core/TensorADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ dToR :: forall r ranked shaped n.
~ DeltaD (Clown (DynamicOf ranked)) ranked shaped
, KnownNat n, GoodScalar r )
=> ADVal (Clown (DynamicOf ranked)) r '() -> ADVal ranked r n
dToR (D l u u') = dDnotShared l (tfromD $ runClown u) (dDToR u')
dToR (D l u u') = dDnotShared l (rfromD $ runClown u) (dDToR u')
where
dDToR (RToD @n1 d) =
case sameNat (Proxy @n1) (Proxy @n) of
Expand Down Expand Up @@ -378,12 +378,12 @@ instance ( Dual ranked ~ DeltaR ranked shaped
~ DeltaD (Clown (DynamicOf ranked)) ranked shaped
, ConvertTensor ranked shaped )
=> ConvertTensor (ADVal ranked) (ADVal shaped) where
tfromD = dToR . runFlip
tfromS = sToR
rfromD = dToR . runFlip
rfromS = sToR
where
sToR :: forall r sh. (GoodScalar r, Sh.Shape sh)
=> ADVal shaped r sh -> ADVal ranked r (Sh.Rank sh)
sToR (D l u u') = dDnotShared l (tfromS u) (dSToR u')
sToR (D l u u') = dDnotShared l (rfromS u) (dSToR u')
where
dSToR (RToS d) = d -- no information lost, so no checks
dSToR d = SToR d
Expand Down
6 changes: 3 additions & 3 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ instance ( GoodScalar r, KnownNat n
Just (DynamicExists @r2 a, rest) ->
if isTensorDummyAst a then Just (rzero (rshape aInit), rest) else
case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> let !t = tfromD @(AstRanked s) @(AstShaped s) @r a
Just Refl -> let !t = rfromD @(AstRanked s) @(AstShaped s) @r a
in Just (t, rest)
_ -> error $ "fromDomains: type mismatch: "
++ show (typeRep @r) ++ " " ++ show (typeRep @r2)
Expand Down Expand Up @@ -641,8 +641,8 @@ astBuild1VectorizeS f =
-- * ConvertTensor and DomainsTensor instances

instance AstSpan s => ConvertTensor (AstRanked s) (AstShaped s) where
tfromD = astFromDynamic
tfromS = astSToR
rfromD = astFromDynamic
rfromS = astSToR
dfromR = AstRToD
dfromS = AstSToD
sfromR = astRToS
Expand Down
10 changes: 5 additions & 5 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,9 @@ class ( DynamicOf ranked ~ DynamicOf shaped
=> ConvertTensor (ranked :: RankedTensorKind)
(shaped :: ShapedTensorKind)
| ranked -> shaped, shaped -> ranked where
tfromD :: (GoodScalar r, KnownNat n)
rfromD :: (GoodScalar r, KnownNat n)
=> DynamicOf ranked r -> ranked r n
tfromS :: (GoodScalar r, Sh.Shape sh)
rfromS :: (GoodScalar r, Sh.Shape sh)
=> shaped r sh -> ranked r (Sh.Rank sh)
dfromR :: (GoodScalar r, KnownNat n)
=> ranked r n -> DynamicOf ranked r
Expand Down Expand Up @@ -767,7 +767,7 @@ instance (GoodScalar r, KnownNat n)
Just (DynamicExists @r2 a, rest) ->
if isTensorDummyD a then Just (rzero (rshape aInit), rest) else
case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> let !aR = tfromD @(Flip OR.Array) @(Flip OS.Array) @r a
Just Refl -> let !aR = rfromD @(Flip OR.Array) @(Flip OS.Array) @r a
in Just (aR, rest)
_ -> error $ "fromDomains: type mismatch: "
++ show (typeRep @r) ++ " " ++ show (typeRep @r2)
Expand Down Expand Up @@ -925,8 +925,8 @@ instance {-# OVERLAPS #-} {-# OVERLAPPING #-}
-- The DomainsTensor instance requires ADVal instance, so it's given elsewhere.

instance ConvertTensor (Flip OR.Array) (Flip OS.Array) where
tfromD = Flip . Data.Array.Convert.convert
tfromS = Flip . Data.Array.Convert.convert . runFlip
rfromD = Flip . Data.Array.Convert.convert
rfromS = Flip . Data.Array.Convert.convert . runFlip
dfromR = Data.Array.Convert.convert . runFlip
dfromS = Data.Array.Convert.convert . runFlip
sfromR = Flip . Data.Array.Convert.convert . runFlip
Expand Down
6 changes: 3 additions & 3 deletions src/HordeAd/External/CommonShapedOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ sminIndexN :: ( ADReadyS shaped, GoodScalar r
sminIndexN t =
ShapedList.fromLinearIdx
(sshape t)
(ShapedList.shapedNat $ tfromS $ sprimalPart $ sminIndex (sflatten t))
(ShapedList.shapedNat $ rfromS $ sprimalPart $ sminIndex (sflatten t))

smaxIndexN :: ( ADReadyS shaped, GoodScalar r
, Sh.Shape sh, KnownNat (Sh.Size sh) )
=> shaped r sh -> IndexSh shaped sh
smaxIndexN t =
ShapedList.fromLinearIdx
(sshape t)
(ShapedList.shapedNat $ tfromS $ sprimalPart $ smaxIndex (sflatten t))
(ShapedList.shapedNat $ rfromS $ sprimalPart $ smaxIndex (sflatten t))

sminimum :: forall r sh shaped.
(ADReadyS shaped, GoodScalar r, Sh.Shape sh, KnownNat (Sh.Size sh))
Expand All @@ -69,7 +69,7 @@ sletIx :: forall r sh n shaped.
=> IndexOf shaped n -> (IndexOf shaped n -> shaped r sh) -> shaped r sh
sletIx ix0 f = slet (sfromR @(RankedOf shaped) @shaped @Int64 @'[n]
$ rint64FromIndex1 ix0) $ \ixT ->
f $ rint64ToIndex1 $ tfromS ixT
f $ rint64ToIndex1 $ rfromS ixT

scaleS :: forall shaped r sh.
(Sh.Shape sh, ADReadyS shaped, GoodScalar r)
Expand Down
2 changes: 1 addition & 1 deletion test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1584,7 +1584,7 @@ testBarReluMax3FwdFrom =
assertEqualUpToEpsilon 1e-10
(Flip $ OS.fromList @'[2, 1, 2] [0.45309153191767404,0.9060427799711201,-2.8186426018387007,40.02498898648793])
(fwd @Double @'[2, 1, 2]
(sfromR . barReluMax . tfromS)
(sfromR . barReluMax . rfromS)
(Flip $ OS.fromList @'[2, 1, 2] [1.1, 2, 3, 4.2])
(Flip $ OS.fromList @'[2, 1, 2] [0.1, 0.2, 0.3, 0.42]))

Expand Down

0 comments on commit a2645f9

Please sign in to comment.