From a2645f92a059ba912dbae4176bb57c3767e65ab9 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Sat, 9 Dec 2023 23:30:28 +0100 Subject: [PATCH] Rename rfromS to match the Ranked naming prefix --- bench/common/BenchProdTools.hs | 4 ++-- src/HordeAd/Core/AstEnv.hs | 4 ++-- src/HordeAd/Core/AstInterpret.hs | 2 +- src/HordeAd/Core/Delta.hs | 12 ++++++------ src/HordeAd/Core/TensorADVal.hs | 8 ++++---- src/HordeAd/Core/TensorAst.hs | 6 +++--- src/HordeAd/Core/TensorClass.hs | 10 +++++----- src/HordeAd/External/CommonShapedOps.hs | 6 +++--- test/simplified/TestAdaptorSimplified.hs | 2 +- 9 files changed, 27 insertions(+), 27 deletions(-) diff --git a/bench/common/BenchProdTools.hs b/bench/common/BenchProdTools.hs index 2607d366e..91621ab8f 100644 --- a/bench/common/BenchProdTools.hs +++ b/bench/common/BenchProdTools.hs @@ -85,7 +85,7 @@ benchProd ~(_l, list, vec) = let f :: DynamicExists OD.Array -> Flip OR.Array Double 0 f = (\(DynamicExists @r2 d) -> gcastWith (unsafeCoerce Refl :: r2 :~: Double) $ - tfromD d) + rfromD d) in nf (V.map f . fst . crevOnDomains @Double Nothing rankedVecDProd) (V.map (DynamicExists . dfromR) vec) @@ -133,7 +133,7 @@ rankedVecDProd :: forall r ranked. => Domains (DynamicOf ranked) -> ranked r 0 rankedVecDProd = V.foldl' (\acc (DynamicExists @r2 d) -> gcastWith (unsafeCoerce Refl :: r2 :~: r) $ - tfromD d * acc) 0 + rfromD d * acc) 0 rankedNoShareListProd :: GoodScalar r => [ADVal (Flip OR.Array) r 0] diff --git a/src/HordeAd/Core/AstEnv.hs b/src/HordeAd/Core/AstEnv.hs index 512b4e00f..d86a6c9d3 100644 --- a/src/HordeAd/Core/AstEnv.hs +++ b/src/HordeAd/Core/AstEnv.hs @@ -73,13 +73,13 @@ extendEnvDR :: forall ranked shaped s. ConvertTensor ranked shaped -> AstEnv ranked shaped extendEnvDR (AstDynamicVarName @_ @sh @r @y var, DynamicExists @r2 d) !env = -- We don't need to manually pick a specialization for the existential - -- variable r2, because tfromD does not depend on r2. + -- variable r2, because rfromD does not depend on r2. case testEquality (typeRep @r) (typeRep @r2) of Just Refl -> let n = length $ Sh.shapeT @sh in case someNatVal $ toInteger n of Just (SomeNat @n _) -> gcastWith (unsafeCoerce Refl :: n :~: y) $ - extendEnvR var (tfromD d) env + extendEnvR var (rfromD d) env Nothing -> error "extendEnvDR: impossible someNatVal error" _ -> error "extendEnvDR: type mismatch" diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index 3a414f3aa..9b4ec802e 100644 --- a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -420,7 +420,7 @@ interpretAst !env = \case AstFromIntegral v -> rfromIntegral $ rconstant $ interpretAstPrimalRuntimeSpecialized env v AstConst a -> rconst a - AstSToR v -> tfromS $ interpretAstS env v + AstSToR v -> rfromS $ interpretAstS env v AstConstant a -> rconstant $ interpretAstPrimal env a AstPrimalPart a -> interpretAst env a -- This is correct, because @s@ must be @PrimalSpan@ and so @ranked@ must diff --git a/src/HordeAd/Core/Delta.hs b/src/HordeAd/Core/Delta.hs index a248e646f..d3f19236c 100644 --- a/src/HordeAd/Core/Delta.hs +++ b/src/HordeAd/Core/Delta.hs @@ -1009,7 +1009,7 @@ buildFinMaps s0 deltaDt = case sameShape @sh @sh2 of Just Refl -> evalS s c d _ -> error "buildFinMaps: different shapes in RToS(SToR)" - RToS d -> evalR s (tfromS c) d + RToS d -> evalR s (rfromS c) d {- -- The general case is given as the last one below, @@ -1048,7 +1048,7 @@ buildFinMaps s0 deltaDt = -> DeltaD (Clown (DynamicOf ranked)) ranked shaped r y -> EvalState ranked shaped evalD s !c = \case - RToD d -> evalR s (tfromD c) d + RToD d -> evalR s (rfromD c) d SToD d -> evalS s (sfromD c) d evalFromnMap :: EvalState ranked shaped -> EvalState ranked shaped @@ -1060,7 +1060,7 @@ buildFinMaps s0 deltaDt = DeltaBindingR @_ @r1 d -> case dMap EM.! n of DynamicExists @r2 e -> case testEquality (typeRep @r1) (typeRep @r2) of - Just Refl -> let c = tfromD e + Just Refl -> let c = rfromD e in evalRRuntimeSpecialized s2 c d _ -> error "buildFinMaps: type mismatch" DeltaBindingS @_ @r1 d -> case dMap EM.! n of @@ -1125,7 +1125,7 @@ buildDerivative dimR deltaDt params = do then case params V.! i of DynamicExists @r2 e -> case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> return $! tfromD @ranked @shaped @r e + Just Refl -> return $! rfromD @ranked @shaped @r e _ -> error "buildDerivative: type mismatch" else error "buildDerivative': wrong index for an input" ScaleR k d -> (* k) <$> evalR d @@ -1138,7 +1138,7 @@ buildDerivative dimR deltaDt params = do case dm EM.! n of DynamicExists @r2 t -> case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> return $! tfromD @ranked @shaped @r t + Just Refl -> return $! rfromD @ranked @shaped @r t _ -> error "buildDerivative: type mismatch" Nothing -> do cRaw <- evalR d @@ -1192,7 +1192,7 @@ buildDerivative dimR deltaDt params = do Just Refl -> evalR (SToR d) _ -> error "buildDerivative: different ranks in DToR(SToD)" SToR (RToS d) -> evalR d -- no information lost, so no checks - SToR d -> tfromS <$> evalS d + SToR d -> rfromS <$> evalS d evalS :: forall sh r. (Sh.Shape sh, GoodScalar r) diff --git a/src/HordeAd/Core/TensorADVal.hs b/src/HordeAd/Core/TensorADVal.hs index 155d4e3af..59591d6f5 100644 --- a/src/HordeAd/Core/TensorADVal.hs +++ b/src/HordeAd/Core/TensorADVal.hs @@ -117,7 +117,7 @@ dToR :: forall r ranked shaped n. ~ DeltaD (Clown (DynamicOf ranked)) ranked shaped , KnownNat n, GoodScalar r ) => ADVal (Clown (DynamicOf ranked)) r '() -> ADVal ranked r n -dToR (D l u u') = dDnotShared l (tfromD $ runClown u) (dDToR u') +dToR (D l u u') = dDnotShared l (rfromD $ runClown u) (dDToR u') where dDToR (RToD @n1 d) = case sameNat (Proxy @n1) (Proxy @n) of @@ -378,12 +378,12 @@ instance ( Dual ranked ~ DeltaR ranked shaped ~ DeltaD (Clown (DynamicOf ranked)) ranked shaped , ConvertTensor ranked shaped ) => ConvertTensor (ADVal ranked) (ADVal shaped) where - tfromD = dToR . runFlip - tfromS = sToR + rfromD = dToR . runFlip + rfromS = sToR where sToR :: forall r sh. (GoodScalar r, Sh.Shape sh) => ADVal shaped r sh -> ADVal ranked r (Sh.Rank sh) - sToR (D l u u') = dDnotShared l (tfromS u) (dSToR u') + sToR (D l u u') = dDnotShared l (rfromS u) (dSToR u') where dSToR (RToS d) = d -- no information lost, so no checks dSToR d = SToR d diff --git a/src/HordeAd/Core/TensorAst.hs b/src/HordeAd/Core/TensorAst.hs index 22031120c..1c4bc1da2 100644 --- a/src/HordeAd/Core/TensorAst.hs +++ b/src/HordeAd/Core/TensorAst.hs @@ -415,7 +415,7 @@ instance ( GoodScalar r, KnownNat n Just (DynamicExists @r2 a, rest) -> if isTensorDummyAst a then Just (rzero (rshape aInit), rest) else case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !t = tfromD @(AstRanked s) @(AstShaped s) @r a + Just Refl -> let !t = rfromD @(AstRanked s) @(AstShaped s) @r a in Just (t, rest) _ -> error $ "fromDomains: type mismatch: " ++ show (typeRep @r) ++ " " ++ show (typeRep @r2) @@ -641,8 +641,8 @@ astBuild1VectorizeS f = -- * ConvertTensor and DomainsTensor instances instance AstSpan s => ConvertTensor (AstRanked s) (AstShaped s) where - tfromD = astFromDynamic - tfromS = astSToR + rfromD = astFromDynamic + rfromS = astSToR dfromR = AstRToD dfromS = AstSToD sfromR = astRToS diff --git a/src/HordeAd/Core/TensorClass.hs b/src/HordeAd/Core/TensorClass.hs index 80ab5b59d..5f0cce30a 100644 --- a/src/HordeAd/Core/TensorClass.hs +++ b/src/HordeAd/Core/TensorClass.hs @@ -575,9 +575,9 @@ class ( DynamicOf ranked ~ DynamicOf shaped => ConvertTensor (ranked :: RankedTensorKind) (shaped :: ShapedTensorKind) | ranked -> shaped, shaped -> ranked where - tfromD :: (GoodScalar r, KnownNat n) + rfromD :: (GoodScalar r, KnownNat n) => DynamicOf ranked r -> ranked r n - tfromS :: (GoodScalar r, Sh.Shape sh) + rfromS :: (GoodScalar r, Sh.Shape sh) => shaped r sh -> ranked r (Sh.Rank sh) dfromR :: (GoodScalar r, KnownNat n) => ranked r n -> DynamicOf ranked r @@ -767,7 +767,7 @@ instance (GoodScalar r, KnownNat n) Just (DynamicExists @r2 a, rest) -> if isTensorDummyD a then Just (rzero (rshape aInit), rest) else case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !aR = tfromD @(Flip OR.Array) @(Flip OS.Array) @r a + Just Refl -> let !aR = rfromD @(Flip OR.Array) @(Flip OS.Array) @r a in Just (aR, rest) _ -> error $ "fromDomains: type mismatch: " ++ show (typeRep @r) ++ " " ++ show (typeRep @r2) @@ -925,8 +925,8 @@ instance {-# OVERLAPS #-} {-# OVERLAPPING #-} -- The DomainsTensor instance requires ADVal instance, so it's given elsewhere. instance ConvertTensor (Flip OR.Array) (Flip OS.Array) where - tfromD = Flip . Data.Array.Convert.convert - tfromS = Flip . Data.Array.Convert.convert . runFlip + rfromD = Flip . Data.Array.Convert.convert + rfromS = Flip . Data.Array.Convert.convert . runFlip dfromR = Data.Array.Convert.convert . runFlip dfromS = Data.Array.Convert.convert . runFlip sfromR = Flip . Data.Array.Convert.convert . runFlip diff --git a/src/HordeAd/External/CommonShapedOps.hs b/src/HordeAd/External/CommonShapedOps.hs index a574c611d..96b941e51 100644 --- a/src/HordeAd/External/CommonShapedOps.hs +++ b/src/HordeAd/External/CommonShapedOps.hs @@ -34,7 +34,7 @@ sminIndexN :: ( ADReadyS shaped, GoodScalar r sminIndexN t = ShapedList.fromLinearIdx (sshape t) - (ShapedList.shapedNat $ tfromS $ sprimalPart $ sminIndex (sflatten t)) + (ShapedList.shapedNat $ rfromS $ sprimalPart $ sminIndex (sflatten t)) smaxIndexN :: ( ADReadyS shaped, GoodScalar r , Sh.Shape sh, KnownNat (Sh.Size sh) ) @@ -42,7 +42,7 @@ smaxIndexN :: ( ADReadyS shaped, GoodScalar r smaxIndexN t = ShapedList.fromLinearIdx (sshape t) - (ShapedList.shapedNat $ tfromS $ sprimalPart $ smaxIndex (sflatten t)) + (ShapedList.shapedNat $ rfromS $ sprimalPart $ smaxIndex (sflatten t)) sminimum :: forall r sh shaped. (ADReadyS shaped, GoodScalar r, Sh.Shape sh, KnownNat (Sh.Size sh)) @@ -69,7 +69,7 @@ sletIx :: forall r sh n shaped. => IndexOf shaped n -> (IndexOf shaped n -> shaped r sh) -> shaped r sh sletIx ix0 f = slet (sfromR @(RankedOf shaped) @shaped @Int64 @'[n] $ rint64FromIndex1 ix0) $ \ixT -> - f $ rint64ToIndex1 $ tfromS ixT + f $ rint64ToIndex1 $ rfromS ixT scaleS :: forall shaped r sh. (Sh.Shape sh, ADReadyS shaped, GoodScalar r) diff --git a/test/simplified/TestAdaptorSimplified.hs b/test/simplified/TestAdaptorSimplified.hs index e505e54fb..80c7433c5 100644 --- a/test/simplified/TestAdaptorSimplified.hs +++ b/test/simplified/TestAdaptorSimplified.hs @@ -1584,7 +1584,7 @@ testBarReluMax3FwdFrom = assertEqualUpToEpsilon 1e-10 (Flip $ OS.fromList @'[2, 1, 2] [0.45309153191767404,0.9060427799711201,-2.8186426018387007,40.02498898648793]) (fwd @Double @'[2, 1, 2] - (sfromR . barReluMax . tfromS) + (sfromR . barReluMax . rfromS) (Flip $ OS.fromList @'[2, 1, 2] [1.1, 2, 3, 4.2]) (Flip $ OS.fromList @'[2, 1, 2] [0.1, 0.2, 0.3, 0.42]))