Skip to content

Commit

Permalink
Reword things in the light of Type :: Type
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 27, 2023
1 parent b489e66 commit 8275e46
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ varNameToAstVarId :: AstVarName f r y -> AstVarId
varNameToAstVarId (AstVarName varId) = varId

-- This can't be replaced by AstVarId. because in some places it's used
-- to record the kind, scalar and shape of arguments in a domain.
-- to record the type, scalar and shape of arguments in a domain.
--
-- A lot of the variables are existential, but there's no nesting,
-- so no special care about picking specializations at runtime is needed.
Expand Down
10 changes: 5 additions & 5 deletions src/HordeAd/Core/AstEnv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ extendEnvD (AstDynamicVarName @ty @r @sh varId, d) !env
DynamicRanked @r2 @n2 t -> case matchingRank @sh @n2 of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> extendEnvR (AstVarName varId) t env
_ -> error "extendEnvD: type mismatch"
_ -> error "extendEnvD: scalar mismatch"
_ -> error "extendEnvD: rank mismatch"
DynamicShaped{} -> error "extendEnvD: ranked from shaped"
DynamicRankedDummy @r2 @sh2 _ _ -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> withListShape (Sh.shapeT @sh2) $ \sh3 ->
extendEnvR @_ @_ @r (AstVarName varId) (rzero sh3) env
_ -> error "extendEnvD: type mismatch"
_ -> error "extendEnvD: scalar mismatch"
_ -> error "extendEnvD: rank mismatch"
DynamicShapedDummy{} -> error "extendEnvD: ranked from shaped"
extendEnvD (AstDynamicVarName @ty @r @sh varId, d) env
Expand All @@ -97,15 +97,15 @@ extendEnvD (AstDynamicVarName @ty @r @sh varId, d) env
DynamicShaped @r2 @sh2 t -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> extendEnvS (AstVarName varId) t env
_ -> error "extendEnvD: type mismatch"
_ -> error "extendEnvD: scalar mismatch"
_ -> error "extendEnvD: shape mismatch"
DynamicRankedDummy{} -> error "extendEnvD: shaped from ranked"
DynamicShapedDummy @r2 @sh2 _ _ -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> extendEnvS @_ @_ @r @sh (AstVarName varId) 0 env
_ -> error "extendEnvD: type mismatch"
_ -> error "extendEnvD: scalar mismatch"
_ -> error "extendEnvD: shape mismatch"
extendEnvD _ _ = error "extendEnvD: unexpected kind"
extendEnvD _ _ = error "extendEnvD: unexpected type"

extendEnvI :: ( RankedTensor ranked
, RankedOf (PrimalOf ranked) ~ PrimalOf ranked )
Expand Down
16 changes: 8 additions & 8 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,17 @@ interpretAst !env = \case
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> assert (rshape t == sh
`blame` (sh, rshape t, varId, t, env)) t
_ -> error "interpretAst: type mismatch"
_ -> error "interpretAst: scalar mismatch"
_ -> error "interpretAst: wrong shape in environment"
-- To impose such checks, we'd need to switch from OD tensors
-- to existential OR/OS tensors so that we can inspect
-- which it is and then seed Delta evaluation maps with that.
-- Just{} -> error "interpretAst: wrong tensor kind in environment"
-- Just{} -> error "interpretAst: wrong tensor type in environment"
Just (AstEnvElemS @sh2 @r2 t) -> case shapeToList sh == Sh.shapeT @sh2 of
True -> case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> rfromS @_ @r2 @sh2 t
_ -> error "interpretAst: type mismatch"
_ -> error "interpretAst: scalar mismatch"
_ -> error "interpretAst: wrong rank"
False -> error $ "interpretAst: wrong shape in environment"
`showFailure`
Expand Down Expand Up @@ -461,7 +461,7 @@ interpretAst !env = \case
, Just Refl <- sameShape @sh3 @sh2
, Just Refl <- testEquality (typeRep @r2) (typeRep @r3) ->
extendEnvS @ranked @shaped @r2 @sh2 (AstVarName varId) 0
_ -> error "interpretAst: impossible kind"
_ -> error "interpretAst: impossible type"
env2 lw = foldr f env (zip vars (V.toList lw))
in rletDomainsIn lt0 lt (\lw -> interpretAst (env2 lw) v)
AstSToR v -> rfromS $ interpretAstS env v
Expand Down Expand Up @@ -676,20 +676,20 @@ interpretAstS !env = \case
Just (AstEnvElemS @sh2 @r2 t) -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> t
_ -> error "interpretAstS: type mismatch"
_ -> error "interpretAstS: scalar mismatch"
Nothing -> error $ "interpretAstS: wrong shape in environment"
`showFailure`
(Sh.shapeT @sh, Sh.shapeT @sh2, varId, t, env)
-- To impose such checks, we'd need to switch from OD tensors
-- to existential OR/OS tensors so that we can inspect
-- which it is and then seed Delta evaluation maps with that.
-- Just{} -> error "interpretAstS: wrong tensor kind in environment"
-- Just{} -> error "interpretAstS: wrong tensor type in environment"
Just (AstEnvElemR @n2 @r2 t) -> case matchingRank @sh @n2 of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> assert (Sh.shapeT @sh == shapeToList (rshape t)
`blame` (Sh.shapeT @sh, rshape t, varId, t, env))
$ sfromR @_ @r2 @sh t
_ -> error "interpretAstS: type mismatch"
_ -> error "interpretAstS: scalar mismatch"
_ -> error "interpretAstS: wrong shape in environment"
Nothing -> error $ "interpretAstS: unknown variable " ++ show varId
AstLetS var u v ->
Expand Down Expand Up @@ -998,7 +998,7 @@ interpretAstS !env = \case
, Just Refl <- sameShape @sh3 @sh2
, Just Refl <- testEquality (typeRep @r2) (typeRep @r3) ->
extendEnvS @ranked @shaped @r2 @sh2 (AstVarName varId) 0
_ -> error "interpretAstS: impossible kind"
_ -> error "interpretAstS: impossible type"
env2 lw = foldr f env (zip vars (V.toList lw))
in sletDomainsIn lt0 lt (\lw -> interpretAstS (env2 lw) v)
AstRToS v -> sfromR $ interpretAst env v
Expand Down
16 changes: 8 additions & 8 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ astLet var u v@(Ast.AstVar _ var2) =
Just Refl -> case sameNat (Proxy @n) (Proxy @m) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> u
_ -> error "astLet: type mismatch"
_ -> error "astLet: scalar mismatch"
_ -> error "astLet: rank mismatch"
_ -> error "astLet: span mismatch"
else v
Expand All @@ -885,7 +885,7 @@ astLet var u v@(Ast.AstConstant (Ast.AstVar _ var2)) = -- a common noop
Just Refl -> case sameNat (Proxy @n) (Proxy @m) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> Ast.AstConstant u
_ -> error "astLet: type mismatch"
_ -> error "astLet: scalar mismatch"
_ -> error "astLet: rank mismatch"
_ -> error "astLet: span mismatch"
else v
Expand All @@ -895,7 +895,7 @@ astLet var u v@(Ast.AstPrimalPart (Ast.AstVar _ var2)) = -- a common noop
Just Refl -> case sameNat (Proxy @n) (Proxy @m) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> astPrimalPart @r2 u
_ -> error "astLet: type mismatch"
_ -> error "astLet: scalar mismatch"
_ -> error "astLet: rank mismatch"
_ -> error "astLet: span mismatch"
else v
Expand All @@ -905,7 +905,7 @@ astLet var u v@(Ast.AstDualPart (Ast.AstVar _ var2)) = -- a noop
Just Refl -> case sameNat (Proxy @n) (Proxy @m) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> astDualPart @r2 u
_ -> error "astLet: type mismatch"
_ -> error "astLet: scalar mismatch"
_ -> error "astLet: rank mismatch"
_ -> error "astLet: span mismatch"
else v
Expand Down Expand Up @@ -938,7 +938,7 @@ astLetS var u v@(Ast.AstVarS var2) =
Just Refl -> case sameShape @sh1 @sh2 of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> u
_ -> error "astLetS: type mismatch"
_ -> error "astLetS: scalar mismatch"
_ -> error "astLetS: shape mismatch"
_ -> error "astLetS: span mismatch"
else v
Expand All @@ -948,7 +948,7 @@ astLetS var u v@(Ast.AstConstantS (Ast.AstVarS var2)) = -- a common noop
Just Refl -> case sameShape @sh1 @sh2 of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> Ast.AstConstantS u
_ -> error "astLetS: type mismatch"
_ -> error "astLetS: scalar mismatch"
_ -> error "astLetS: shape mismatch"
_ -> error "astLetS: span mismatch"
else v
Expand All @@ -958,7 +958,7 @@ astLetS var u v@(Ast.AstPrimalPartS (Ast.AstVarS var2)) = -- a common noop
Just Refl -> case sameShape @sh1 @sh2 of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> astPrimalPartS @r2 u
_ -> error "astLetS: type mismatch"
_ -> error "astLetS: scalar mismatch"
_ -> error "astLetS: shape mismatch"
_ -> error "astLetS: span mismatch"
else v
Expand All @@ -968,7 +968,7 @@ astLetS var u v@(Ast.AstDualPartS (Ast.AstVarS var2)) = -- a noop
Just Refl -> case sameShape @sh1 @sh2 of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> astDualPartS @r2 u
_ -> error "astLetS: type mismatch"
_ -> error "astLetS: scalar mismatch"
_ -> error "astLetS: shape mismatch"
_ -> error "astLetS: span mismatch"
else v
Expand Down
12 changes: 6 additions & 6 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ buildFinMaps s0 deltaDt =
Just Refl -> case testEquality (typeRep @r1)
(typeRep @r2) of
Just Refl -> evalRRuntimeSpecialized s2 c d
_ -> error "buildFinMaps: type mismatch"
_ -> error "buildFinMaps: scalar mismatch"
_ -> error "buildFinMaps: rank mismatch"
DynamicShaped{} ->
error "evalFromnMap: DynamicShaped"
Expand All @@ -1008,7 +1008,7 @@ buildFinMaps s0 deltaDt =
Just Refl -> case testEquality (typeRep @r1)
(typeRep @r2) of
Just Refl -> evalSRuntimeSpecialized s2 c d
_ -> error "buildFinMaps: type mismatch"
_ -> error "buildFinMaps: scalar mismatch"
_ -> error "buildFinMaps: shape mismatch"
DynamicRankedDummy{} ->
error "evalFromnMap: DynamicRankedDummy"
Expand Down Expand Up @@ -1069,7 +1069,7 @@ buildDerivative dimR deltaDt params = do
DynamicRanked @r2 @n2 e -> case sameNat (Proxy @n2) (Proxy @n) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> return e
_ -> error "buildDerivative: type mismatch"
_ -> error "buildDerivative: scalar mismatch"
_ -> error "buildDerivative: rank mismatch"
DynamicShaped{} -> error "buildDerivative: DynamicShaped"
DynamicRankedDummy{} -> error "buildDerivative: DynamicRankedDummy"
Expand All @@ -1087,7 +1087,7 @@ buildDerivative dimR deltaDt params = do
(Proxy @n) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> return e
_ -> error "buildDerivative: type mismatch"
_ -> error "buildDerivative: scalar mismatch"
_ -> error "buildDerivative: rank mismatch"
DynamicShaped{} -> error "buildDerivative: DynamicShaped"
DynamicRankedDummy{} ->
Expand Down Expand Up @@ -1159,7 +1159,7 @@ buildDerivative dimR deltaDt params = do
DynamicShaped @r2 @sh2 e -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> return e
_ -> error "buildDerivative: type mismatch"
_ -> error "buildDerivative: scalar mismatch"
_ -> error "buildDerivative: shape mismatch"
DynamicRankedDummy{} -> error "buildDerivative: DynamicRankedDummy"
DynamicShapedDummy{} -> error "buildDerivative: DynamicShapedDummy"
Expand All @@ -1176,7 +1176,7 @@ buildDerivative dimR deltaDt params = do
DynamicShaped @r2 @sh2 e -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> return e
_ -> error "buildDerivative: type mismatch"
_ -> error "buildDerivative: scalar mismatch"
_ -> error "buildDerivative: shape mismatch"
DynamicRankedDummy{} ->
error "buildDerivative: DynamicRankedDummy"
Expand Down
22 changes: 11 additions & 11 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ raddDynamic r (DynamicRanked @r2 @n2 t) = case sameNat (Proxy @n2)
(Proxy @n) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> DynamicRanked @r $ r + t
_ -> error "raddDynamic: type mismatch"
_ -> error "raddDynamic: scalar mismatch"
_ -> error "raddDynamic: rank mismatch"
raddDynamic _ DynamicShaped{} = error "raddDynamic: DynamicShaped"
raddDynamic r (DynamicRankedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> DynamicRanked (r :: ranked r2 (Sh.Rank sh2))
_ -> error "raddDynamic: type mismatch"
_ -> error "raddDynamic: scalar mismatch"
_ -> error "raddDynamic: rank mismatch"
raddDynamic _ DynamicShapedDummy{} = error "raddDynamic: DynamicShapedDummy"

Expand All @@ -315,13 +315,13 @@ saddDynamic _ DynamicRanked{} = error "saddDynamic: DynamicRanked"
saddDynamic r (DynamicShaped @r2 @sh2 t) = case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> DynamicShaped @r $ r + t
_ -> error "saddDynamic: type mismatch"
_ -> error "saddDynamic: scalar mismatch"
_ -> error "saddDynamic: shape mismatch"
saddDynamic _ DynamicRankedDummy{} = error "saddDynamic: DynamicRankedDummy"
saddDynamic r (DynamicShapedDummy @r2 @sh2 _ _) = case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> DynamicShaped (r :: shaped r2 sh2)
_ -> error "saddDynamic: type mismatch"
_ -> error "saddDynamic: scalar mismatch"
_ -> error "saddDynamic: shape mismatch"


Expand Down Expand Up @@ -607,13 +607,13 @@ rfromD :: forall ranked r n.
rfromD (DynamicRanked @r2 @n2 t) = case sameNat (Proxy @n2) (Proxy @n) of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> t
_ -> error "rfromD: type mismatch"
_ -> error "rfromD: scalar mismatch"
_ -> error "rfromD: rank mismatch"
rfromD DynamicShaped{} = error "rfromD: unexpected DynamicShaped"
rfromD (DynamicRankedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> rfromS @_ @r2 @sh2 0
_ -> error "rfromD: type mismatch"
_ -> error "rfromD: scalar mismatch"
_ -> error "rfromD: rank mismatch"
rfromD DynamicShapedDummy{} = error "rfromD: unexpected DynamicShapedDummy"

Expand All @@ -626,7 +626,7 @@ sfromD DynamicRanked{} = error "sfromD: unexpected DynamicRanked"
sfromD (DynamicShaped @r2 @sh2 t) = case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> t
_ -> error "sfromD: type mismatch"
_ -> error "sfromD: scalar mismatch"
_ -> error "sfromD: shape mismatch"
sfromD DynamicRankedDummy{} = error "sfromD: unexpected DynamicRankedDummy"
sfromD DynamicShapedDummy{} = 0
Expand Down Expand Up @@ -813,7 +813,7 @@ fromDomainsR params = case V.uncons params of
(Proxy @n) of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> Just (t, rest)
_ -> error $ "fromDomainsR: type mismatch in "
_ -> error $ "fromDomainsR: scalar mismatch in "
++ show (typeRep @r2, typeRep @r)
_ -> error "fromDomainsR: rank mismatch"
Just (DynamicShaped{}, _) -> error "fromDomainsR: ranked from shaped"
Expand All @@ -822,7 +822,7 @@ fromDomainsR params = case V.uncons params of
Just Refl ->
let sh2 = listShapeToShape (Sh.shapeT @sh2)
in Just (rzero sh2 :: ranked r2 (Sh.Rank sh2), rest)
_ -> error "fromDomainsR: type mismatch"
_ -> error "fromDomainsR: scalar mismatch"
_ -> error "fromDomainsR: shape mismatch"
Just (DynamicShapedDummy{}, _) -> error "fromDomainsR: ranked from shaped"
Nothing -> Nothing
Expand All @@ -837,15 +837,15 @@ fromDomainsS params = case V.uncons params of
Just (DynamicShaped @r2 @sh2 t, rest) -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> Just (t, rest)
_ -> error "fromDomainsS: type mismatch"
_ -> error "fromDomainsS: scalar mismatch"
_ -> error "fromDomainsS: shape mismatch"
Just (DynamicRankedDummy{}, _) -> error "fromDomainsS: shaped from ranked"
Just (DynamicShapedDummy @r2 @sh2 _ _, rest) -> case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl ->
-- The dummy gets removed, so we verify its types before it does.
Just (0 :: shaped r2 sh2, rest)
_ -> error "fromDomainsS: type mismatch"
_ -> error "fromDomainsS: scalar mismatch"
_ -> error "fromDomainsS: shape mismatch"
Nothing -> Nothing

Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE QuantifiedConstraints, UndecidableInstances #-}
-- | Some fundamental kinds, type families and types.
-- | Some fundamental type families and types.
module HordeAd.Core.Types
( -- * Kinds of the functors that determine the structure of a tensor type
TensorType, RankedTensorType, ShapedTensorType, TensorToken(..)
Expand Down Expand Up @@ -44,7 +44,7 @@ import HordeAd.Internal.TensorFFI
import HordeAd.Util.ShapedList (ShapedList, ShapedNat)
import HordeAd.Util.SizedIndex

-- * Kinds of the functors that determine the structure of a tensor type
-- * Types of types of tensors

type TensorType ty = Type -> ty -> Type

Expand Down

0 comments on commit 8275e46

Please sign in to comment.