Skip to content

Commit

Permalink
Generalize FullTensorKind to nested arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 18, 2024
1 parent ced91eb commit 1f4a275
Show file tree
Hide file tree
Showing 23 changed files with 278 additions and 276 deletions.
2 changes: 1 addition & 1 deletion src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ type AstInt ms = AstTensor ms PrimalSpan (TKS '[] Int64)
type IntVarName = AstVarName PrimalSpan (TKS '[] Int64)

pattern AstIntVar :: IntVarName -> AstInt ms
pattern AstIntVar var = AstVar (FTKS ZSS) var
pattern AstIntVar var = AstVar (FTKS ZSS FTKScalar) var

isTensorInt :: forall s y ms. (AstSpan s, TensorKind y)
=> AstTensor ms s y
Expand Down
9 changes: 5 additions & 4 deletions src/HordeAd/Core/AstFreshId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ funToAstIO sh f = do
!x = f (AstVar FTKScalar varName)
dynVar = AstDynamicVarName @Nat @r @'[] freshId
in (varName, dynVar, x)
FTKR @_ @r shr ->
FTKR shr (FTKScalar @r) ->
return $! withShapeP (shapeToList shr) $ \(Proxy @p_sh) ->
let varName = mkAstVarName freshId
!x = f (AstVar sh varName)
dynVar = AstDynamicVarName @Nat @r @p_sh freshId
in (varName, dynVar, x)
FTKS @sh @r shs -> do
FTKS @sh shs (FTKScalar @r) -> do
let varName = mkAstVarName freshId
!x = f (AstVar sh varName)
dynVar = withKnownShS shs $ AstDynamicVarName @[Nat] @r @sh freshId
Expand All @@ -104,6 +104,7 @@ funToAstIO sh f = do
let varName = mkAstVarName freshId
!x = f (AstVar sh varName)
return (varName, undefined, x)
_ -> error "TODO"

funToAst :: TensorKind y
=> FullTensorKind y
Expand Down Expand Up @@ -149,14 +150,14 @@ dynamicToVar (DynamicRankedDummy @r2 @sh2 _ _) = do
return $! withListSh (Proxy @sh2) $ \sh4 ->
let !varE = AstDynamicVarName @Nat @r2 @sh2 freshId
dynE :: AstDynamic ms s
!dynE = DynamicRanked @r2 (AstVar (FTKR sh4) (mkAstVarName freshId))
!dynE = DynamicRanked @r2 (AstVar (FTKR sh4 FTKScalar) (mkAstVarName freshId))
in (varE, dynE)
dynamicToVar (DynamicShapedDummy @r2 @sh2 _ _) = do
freshId <- unsafeGetFreshAstVarId
return $!
let !varE = AstDynamicVarName @[Nat] @r2 @sh2 freshId
dynE :: AstDynamic ms s
!dynE = DynamicShaped @r2 @sh2 (AstVar (FTKS knownShS) (mkAstVarName freshId))
!dynE = DynamicShaped @r2 @sh2 (AstVar (FTKS knownShS FTKScalar) (mkAstVarName freshId))
in (varE, dynE)

funToAstRevIO :: forall x. FullTensorKind x
Expand Down
5 changes: 3 additions & 2 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ interpretAst !env = \case
-> target (BuildTensorKind n z)
emptyFromStk ftk = case ftk of
FTKScalar -> rfromList0N (0 :$: ZSR) []
FTKR sh | SNat <- shrRank sh -> rfromList0N (0 :$: sh) []
FTKS sh -> withKnownShS sh $ sfromList0N []
FTKR sh FTKScalar | SNat <- shrRank sh -> rfromList0N (0 :$: sh) []
FTKS sh FTKScalar -> withKnownShS sh $ sfromList0N []
FTKX{} -> error "TODO"
FTKProduct @z1 @z2 ftk1 ftk2
| Dict <- lemTensorKindOfF ftk1
Expand All @@ -256,6 +256,7 @@ interpretAst !env = \case
tpair (emptyFromStk ftk1) (emptyFromStk ftk2)
FTKUntyped ssh -> dmkHVector $ replicate1HVector @target (SNat @0)
$ V.map dynamicFromVoid ssh
_ -> error "TODO"
in emptyFromStk (ftkAst v)
-- The following can't be, in general, so partially evaluated, because v
-- may contain variables that the evironment sends to terms,
Expand Down
5 changes: 4 additions & 1 deletion src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2337,9 +2337,10 @@ astLetHVectorIn vars l v = case v of
STKScalar _ -> Ast.AstUnScalar $ astProjectR l i
STKR SNat STKScalar{} -> astProjectR l i
STKS sh STKScalar{} -> withKnownShS sh $ astProjectS l i
STKX sh STKScalar{}-> withKnownShX sh $ error "TODO"
STKX sh STKScalar{}-> withKnownShX sh $ error "TODO"
STKProduct{} -> error "astLetHVectorIn: STKProduct"
STKUntyped -> error "astLetHVectorIn: STKUntyped"
_ -> error "TODO"
_ -> v
Ast.AstPrimalPart (Ast.AstVar _ var2) ->
case elemIndex (varNameToAstVarId var2)
Expand All @@ -2351,6 +2352,7 @@ astLetHVectorIn vars l v = case v of
STKX sh STKScalar{} -> withKnownShX sh $ error "TODO"
STKProduct{} -> error "astLetHVectorIn: STKProduct"
STKUntyped -> error "astLetHVectorIn: STKUntyped"
_ -> error "TODO"
_ -> v
Ast.AstDualPart (Ast.AstVar _ var2) ->
case elemIndex (varNameToAstVarId var2)
Expand All @@ -2362,6 +2364,7 @@ astLetHVectorIn vars l v = case v of
STKX sh STKScalar{} -> withKnownShX sh $ error "TODO"
STKProduct{} -> error "astLetHVectorIn: STKProduct"
STKUntyped -> error "astLetHVectorIn: STKUntyped"
_ -> error "TODO"
_ -> v
_ -> case l of
Ast.AstMkHVector l3 ->
Expand Down
94 changes: 47 additions & 47 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import HordeAd.Util.SizedList
ftkAst :: forall s y ms. TensorKind y
=> AstTensor ms s y -> FullTensorKind y
ftkAst t = case t of
AstScalar{} -> FTKR ZSR
AstScalar{} -> FTKR ZSR FTKScalar
AstUnScalar{} -> FTKScalar
AstPair t1 t2 -> FTKProduct (ftkAst t1) (ftkAst t2)
AstProject1 v -> case ftkAst v of
Expand All @@ -62,10 +62,10 @@ ftkAst t = case t of
AstLet _ _ v -> ftkAst v
AstShare _ v -> ftkAst v
AstToShare v -> ftkAst v
AstMinIndex a -> FTKR $ initShape $ shapeAst a
AstMaxIndex a -> FTKR $ initShape $ shapeAst a
AstFloor a -> FTKR $ shapeAst a
AstIota -> FTKR $ singletonShape (maxBound :: Int) -- ought to be enough
AstMinIndex a -> FTKR (initShape $ shapeAst a) FTKScalar
AstMaxIndex a -> FTKR (initShape $ shapeAst a) FTKScalar
AstFloor a -> FTKR (shapeAst a) FTKScalar
AstIota -> FTKR (singletonShape (maxBound :: Int)) FTKScalar -- ought to be enough
AstN1 _opCode v -> ftkAst v
AstN2 _opCode v _ -> ftkAst v
AstR1 _opCode v -> ftkAst v
Expand All @@ -74,66 +74,66 @@ ftkAst t = case t of
AstSumOfList args -> case args of
[] -> error "ftkAst: AstSumOfList with no arguments"
v : _ -> ftkAst v
AstIndex v _is -> FTKR $ dropShape $ shapeAst v
AstSum v -> FTKR $ tailShape $ shapeAst v
AstScatter sh _ _ -> FTKR sh
AstIndex v _is -> FTKR (dropShape $ shapeAst v) FTKScalar
AstSum v -> FTKR (tailShape $ shapeAst v) FTKScalar
AstScatter sh _ _ -> FTKR sh FTKScalar
AstFromVector l -> case V.toList l of
[] -> case stensorKind @y of
STKR @n SNat _ -> case sameNat (Proxy @n) (Proxy @1) of
Just Refl -> FTKR $ singletonShape 0
Just Refl -> FTKR (0 :$: ZSR) FTKScalar
Nothing -> error "ftkAst: AstFromVector with no arguments"
v : _ -> FTKR $ V.length l :$: shapeAst v
v : _ -> FTKR (V.length l :$: shapeAst v) FTKScalar
AstAppend x y -> case shapeAst x of
ZSR -> error "ftkAst: impossible pattern needlessly required"
xi :$: xsh -> case shapeAst y of
ZSR -> error "ftkAst: impossible pattern needlessly required"
yi :$: _ -> FTKR $ xi + yi :$: xsh
AstSlice _i n v -> FTKR $ n :$: tailShape (shapeAst v)
yi :$: _ -> FTKR (xi + yi :$: xsh) FTKScalar
AstSlice _i n v -> FTKR (n :$: tailShape (shapeAst v)) FTKScalar
AstReverse v -> ftkAst v
AstTranspose perm v ->
FTKR $ Nested.Internal.Shape.shrPermutePrefix perm $ shapeAst v
AstReshape sh _v -> FTKR sh
AstGather sh _v (_vars, _ix) -> FTKR sh
AstCast v -> FTKR $ shapeAst v
AstFromIntegral a -> FTKR $ shapeAst a
AstConcrete a -> FTKR $ Nested.rshape a
FTKR (Nested.Internal.Shape.shrPermutePrefix perm $ shapeAst v) FTKScalar
AstReshape sh _v -> FTKR sh FTKScalar
AstGather sh _v (_vars, _ix) -> FTKR sh FTKScalar
AstCast v -> FTKR (shapeAst v) FTKScalar
AstFromIntegral a -> FTKR (shapeAst a) FTKScalar
AstConcrete a -> FTKR (Nested.rshape a) FTKScalar
AstProjectR l p -> case shapeAstHVector l V.! p of
DynamicRankedDummy @_ @sh _ _ -> FTKR $ listToShape $ shapeT @sh
DynamicRankedDummy @_ @sh _ _ -> FTKR (listToShape $ shapeT @sh) FTKScalar
DynamicShapedDummy{} -> error "ftkAst: DynamicShapedDummy"
AstLetHVectorIn _ _ v -> ftkAst v
AstRFromS @sh _ | Dict <- lemKnownNatRank (knownShS @sh) ->
FTKR $ listToShape $ shapeT @sh

AstMinIndexS{} -> FTKS knownShS
AstMaxIndexS{} -> FTKS knownShS
AstFloorS{} -> FTKS knownShS
AstIotaS{} -> FTKS knownShS
AstN1S{} -> FTKS knownShS
AstN2S{} -> FTKS knownShS
AstR1S{} -> FTKS knownShS
AstR2S{} -> FTKS knownShS
AstI2S{} -> FTKS knownShS
AstSumOfListS{} -> FTKS knownShS
AstIndexS{} -> FTKS knownShS
AstSumS{} -> FTKS knownShS
AstScatterS{} -> FTKS knownShS
AstFromVectorS{} -> FTKS knownShS
AstAppendS{} -> FTKS knownShS
AstSliceS{} -> FTKS knownShS
AstReverseS{} -> FTKS knownShS
FTKR (listToShape $ shapeT @sh) FTKScalar

AstMinIndexS{} -> FTKS knownShS FTKScalar
AstMaxIndexS{} -> FTKS knownShS FTKScalar
AstFloorS{} -> FTKS knownShS FTKScalar
AstIotaS{} -> FTKS knownShS FTKScalar
AstN1S{} -> FTKS knownShS FTKScalar
AstN2S{} -> FTKS knownShS FTKScalar
AstR1S{} -> FTKS knownShS FTKScalar
AstR2S{} -> FTKS knownShS FTKScalar
AstI2S{} -> FTKS knownShS FTKScalar
AstSumOfListS{} -> FTKS knownShS FTKScalar
AstIndexS{} -> FTKS knownShS FTKScalar
AstSumS{} -> FTKS knownShS FTKScalar
AstScatterS{} -> FTKS knownShS FTKScalar
AstFromVectorS{} -> FTKS knownShS FTKScalar
AstAppendS{} -> FTKS knownShS FTKScalar
AstSliceS{} -> FTKS knownShS FTKScalar
AstReverseS{} -> FTKS knownShS FTKScalar
AstTransposeS @perm @sh2 perm _v ->
withShapeP
(backpermutePrefixList (Permutation.permToList' perm)
(shapeT @sh2)) $ \(Proxy @sh2Perm) ->
gcastWith (unsafeCoerce Refl :: sh2Perm :~: Permutation.PermutePrefix perm sh2) $
FTKS knownShS
AstReshapeS{} -> FTKS knownShS
AstGatherS{} -> FTKS knownShS
AstCastS{} -> FTKS knownShS
AstFromIntegralS{} -> FTKS knownShS
AstConcreteS{} -> FTKS knownShS
AstProjectS{} -> FTKS knownShS
AstSFromR{} -> FTKS knownShS
FTKS knownShS FTKScalar
AstReshapeS{} -> FTKS knownShS FTKScalar
AstGatherS{} -> FTKS knownShS FTKScalar
AstCastS{} -> FTKS knownShS FTKScalar
AstFromIntegralS{} -> FTKS knownShS FTKScalar
AstConcreteS{} -> FTKS knownShS FTKScalar
AstProjectS{} -> FTKS knownShS FTKScalar
AstSFromR{} -> FTKS knownShS FTKScalar

AstMkHVector v ->
FTKUntyped
Expand All @@ -158,7 +158,7 @@ ftkAst t = case t of
shapeAst :: forall n s r ms. (KnownNat n, GoodScalar r)
=> AstTensor ms s (TKR n r) -> IShR n
shapeAst t = case ftkAst t of
FTKR sh -> sh
FTKR sh _ -> sh

-- Length of the outermost dimension.
lengthAst :: (KnownNat n, GoodScalar r) => AstTensor ms s (TKR (1 + n) r) -> Int
Expand Down
15 changes: 8 additions & 7 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -705,11 +705,11 @@ substProjRep snat@SNat var ftk2 var1 v
projection prVar = \case
FTKScalar ->
Ast.AstUnScalar $ Ast.AstIndex prVar (Ast.AstIntVar var :.: ZIR)
FTKR sh | SNat <- shrRank sh ->
FTKR sh FTKScalar | SNat <- shrRank sh ->
Ast.AstIndex prVar (Ast.AstIntVar var :.: ZIR)
FTKS sh -> withKnownShS sh
FTKS sh FTKScalar -> withKnownShS sh
$ Ast.AstIndexS prVar (Ast.AstIntVar var :.$ ZIS)
FTKX sh -> withKnownShX (ssxFromShape sh)
FTKX sh FTKScalar -> withKnownShX (ssxFromShape sh)
$ Ast.AstIndexX prVar (Ast.AstIntVar var :.% ZIX)
FTKProduct @z1 @z2 ftk41 ftk42
| Dict <- lemTensorKindOfF ftk41
Expand All @@ -729,16 +729,17 @@ substProjRep snat@SNat var ftk2 var1 v
(DynamicRankedDummy @_ @sh3 _ _)
| Just Refl <- matchingRank @(k ': sh3) @n2 =
withListSh (Proxy @sh3) $ \sh1 ->
DynamicRanked $ projection t (FTKR sh1)
DynamicRanked $ projection t (FTKR sh1 FTKScalar)
projDyn (DynamicShaped @_ @sh2 t)
(DynamicShapedDummy @_ @sh3 _ _)
| Just Refl <- sameShape @sh2 @(k ': sh3) =
DynamicShaped $ projection t (FTKS @sh3 knownShS)
DynamicShaped $ projection t (FTKS @sh3 knownShS FTKScalar)
projDyn _ _ = error "projDyn: impossible DynamicTensor cases"
in astLetHVectorIn
vars
prVar
(Ast.AstMkHVector $ V.zipWith projDyn asts shs0)
_ -> error "TODO"
v2 = substituteAst
(projection astVar3 ftk2)
var1 v
Expand All @@ -753,7 +754,7 @@ substProjRanked :: forall n1 r1 s1 s y.
substProjRanked k var sh1 var1 =
let var2 = mkAstVarName @s1 @(TKR (1 + n1) r1) (varNameToAstVarId var1) -- changed shape; TODO: shall we rename?
projection =
Ast.AstIndex (Ast.AstVar (FTKR $ k :$: sh1) var2)
Ast.AstIndex (Ast.AstVar (FTKR (k :$: sh1) FTKScalar) var2)
(Ast.AstIntVar var :.: ZIR)
in substituteAst
projection var1
Expand All @@ -769,7 +770,7 @@ substProjShaped :: forall k sh1 r1 s1 s y.
substProjShaped var var1 =
let var2 = mkAstVarName @s1 @(TKS (k : sh1) r1) (varNameToAstVarId var1)
projection =
Ast.AstIndexS (Ast.AstVar @(TKS (k ': sh1) r1) (FTKS knownShS) var2)
Ast.AstIndexS (Ast.AstVar @(TKS (k ': sh1) r1) (FTKS knownShS FTKScalar) var2)
(Ast.AstIntVar var :.$ ZIS)
in substituteAst
projection var1
Expand Down
19 changes: 10 additions & 9 deletions src/HordeAd/Core/CarriersADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ generateDeltaInputs =
let gen :: Int -> FullTensorKind y -> (Delta target y, Int)
gen j ftk = case ftk of
FTKScalar -> (InputG ftk (toInputId j), j + 1)
FTKR sh | SNat <- shrRank sh -> (InputG ftk (toInputId j), j + 1)
FTKS sh -> withKnownShS sh $ (InputG ftk (toInputId j), j + 1)
FTKX sh -> withKnownShX (ssxFromShape sh)
FTKR sh FTKScalar | SNat <- shrRank sh -> (InputG ftk (toInputId j), j + 1)
FTKS sh FTKScalar -> withKnownShS sh $ (InputG ftk (toInputId j), j + 1)
FTKX sh FTKScalar -> withKnownShX (ssxFromShape sh)
$ (InputG ftk (toInputId j), j + 1)
FTKProduct ftk1 ftk2 | Dict <- lemTensorKindOfF ftk1
, Dict <- lemTensorKindOfF ftk2 ->
Expand All @@ -225,11 +225,12 @@ generateDeltaInputs =
let f :: (Int, DynamicTensor VoidTensor) -> DynamicTensor (Delta target)
f (i, DynamicRankedDummy @r @sh _ _) =
withListSh (Proxy @sh) $ \sh ->
DynamicRanked $ InputG (FTKR @_ @r sh) (toInputId i)
DynamicRanked $ InputG (FTKR sh (FTKScalar @r)) (toInputId i)
f (i, DynamicShapedDummy @r @sh _ _) =
DynamicShaped $ InputG (FTKS @sh @r knownShS) (toInputId i)
DynamicShaped $ InputG (FTKS @sh knownShS (FTKScalar @r)) (toInputId i)
len = V.length shs
in (HToH $ V.map f $ V.zip (V.enumFromN j len) shs, j + len)
_ -> error "TODO"
in fst . gen 0
{- TODO: this causes a cyclic dependency:
{-# SPECIALIZE generateDeltaInputs
Expand Down Expand Up @@ -265,10 +266,10 @@ rFromH (HToH hv) i = case hv V.! i of
DynamicRankedDummy @r2 @sh _ _
| Just Refl <- matchingRank @sh @n
, Just Refl <- testEquality (typeRep @r) (typeRep @r2) ->
ZeroG (FTKR $ fromList $ shapeT @sh)
ZeroG $ FTKR (fromList $ shapeT @sh) FTKScalar
_ -> error "rFromH: impossible case"
rFromH (ZeroG (FTKUntyped shs)) i = case shs V.! i of
DynamicRankedDummy @_ @sh _ _ -> ZeroG (FTKR $ fromList $ shapeT @sh)
DynamicRankedDummy @_ @sh _ _ -> ZeroG $ FTKR (fromList $ shapeT @sh) FTKScalar
DynamicShapedDummy{} -> error "rFromH: DynamicShapedDummy"
rFromH d i = RFromH d i

Expand All @@ -281,11 +282,11 @@ sFromH (HToH hv) i = case hv V.! i of
DynamicShapedDummy @r2 @sh3 _ _
| Just Refl <- sameShape @sh @sh3
, Just Refl <- testEquality (typeRep @r) (typeRep @r2) ->
ZeroG (FTKS $ fromList $ shapeT @sh3)
ZeroG $ FTKS (fromList $ shapeT @sh3) FTKScalar
_ -> error "sFromH: impossible case"
sFromH (ZeroG (FTKUntyped shs)) i = case shs V.! i of
DynamicRankedDummy{} -> error "sFromH: DynamicRankedDummy"
DynamicShapedDummy @_ @sh3 _ _ -> ZeroG (FTKS $ fromList $ shapeT @sh3)
DynamicShapedDummy @_ @sh3 _ _ -> ZeroG $ FTKS (fromList $ shapeT @sh3) FTKScalar
sFromH d i = SFromH d i


Expand Down
Loading

0 comments on commit 1f4a275

Please sign in to comment.