Skip to content

Commit

Permalink
Switch indexes to be TKScalar (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 28, 2024
1 parent 42dd3f6 commit 8c76466
Show file tree
Hide file tree
Showing 20 changed files with 373 additions and 326 deletions.
9 changes: 4 additions & 5 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ import Data.Array.Nested
, ListR
, ListS (..)
, Rank
, ShS (..)
, type (++)
)
import Data.Array.Nested qualified as Nested
Expand Down Expand Up @@ -197,19 +196,19 @@ type instance PrimalOf (AstTensor ms s) = AstTensor ms PrimalSpan
-- | This is the (arbitrarily) chosen representation of terms denoting
-- integers in the indexes of tensor operations.
type AstInt ms = IntOf (AstTensor ms FullSpan)
-- ~ AstTensor ms PrimalSpan (TKS '[] Int64)
-- ~ AstTensor ms PrimalSpan (TKScalar Int64)

-- TODO: type IntVarNameF = AstVarName PrimalSpan Int64
type IntVarName = AstVarName PrimalSpan (TKS '[] Int64)
type IntVarName = AstVarName PrimalSpan (TKScalar Int64)

pattern AstIntVar :: IntVarName -> AstInt ms
pattern AstIntVar var = AstVar (FTKS ZSS FTKScalar) var
pattern AstIntVar var = AstVar FTKScalar var

isTensorInt :: forall s y ms. (AstSpan s, TensorKind y)
=> AstTensor ms s y
-> Maybe (AstTensor ms s y :~: AstInt ms)
isTensorInt _ = case ( sameAstSpan @s @PrimalSpan
, sameTensorKind @y @(TKS '[] Int64) ) of
, sameTensorKind @y @(TKScalar Int64) ) of
(Just Refl, Just Refl) -> Just Refl
_ -> Nothing

Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/AstEnv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ extendEnvD vd@(AstDynamicVarName @ty @r @sh varId, d) !env = case d of
extendEnvI :: BaseTensor target
=> IntVarName -> IntOf target -> AstEnv target
-> AstEnv target
extendEnvI var !i !env = extendEnv var (sfromPrimal i) env
extendEnvI var !i !env = extendEnv var (tfromPrimal (STKScalar typeRep) i) env

extendEnvVars :: forall target m. BaseTensor target
=> AstVarList m -> IxROf target m
Expand Down
14 changes: 6 additions & 8 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import Data.Vector.Generic qualified as V
import GHC.Exts (IsList (..))
import GHC.TypeLits (fromSNat)

import Data.Array.Nested (KnownShS (..))

import HordeAd.Core.Ast (AstBool, AstTensor)
import HordeAd.Core.Ast hiding (AstBool (..), AstTensor (..))
import HordeAd.Core.Ast qualified as Ast
Expand Down Expand Up @@ -430,9 +428,9 @@ unshareAstTensor tShare =
-- into more than one index element, with the share containing
-- the gather/scatter/build variables corresponding to the index.
unshareAstScoped
:: forall sh s r. (GoodScalar r, KnownShS sh, AstSpan s)
=> [IntVarName] -> AstBindings -> AstTensor AstMethodShare s (TKS sh r)
-> (AstBindings, AstTensor AstMethodLet s (TKS sh r))
:: forall s r. (GoodScalar r, AstSpan s)
=> [IntVarName] -> AstBindings -> AstTensor AstMethodShare s (TKScalar r)
-> (AstBindings, AstTensor AstMethodLet s (TKScalar r))
unshareAstScoped vars0 memo0 v0 =
let (memo1, v1) = unshareAst memo0 v0
memoDiff = DMap.difference memo1 memo0
Expand Down Expand Up @@ -486,11 +484,11 @@ unshareAst memo = \case
in (memo3, Ast.AstCond b1 t2 t3)
Ast.AstReplicate k v -> second (Ast.AstReplicate k) (unshareAst memo v)
Ast.AstBuild1 @y2 snat (var, v) -> case stensorKind @y2 of
STKScalar{} -> error "WIP"
STKR SNat STKScalar{} -> error "WIP"
STKS sh STKScalar{} -> withKnownShS sh $
STKScalar{} ->
let (memo1, v2) = unshareAstScoped [var] memo v
in (memo1, Ast.AstBuild1 snat (var, v2))
STKR SNat STKScalar{} -> error "WIP"
STKS sh STKScalar{} -> withKnownShS sh $ error "WIP"
STKX sh STKScalar{} -> withKnownShX sh $ error "WIP"
STKProduct{} -> error "WIP"
STKUntyped -> error "WIP"
Expand Down
12 changes: 8 additions & 4 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ interpretAst !env = \case
let args2 = interpretAst env <$> args
in foldr1 (+) args2 -- avoid @fromInteger 0@ in @sum@
AstIndex AstIota (i :.: ZIR) ->
rfromIntegral . rfromPrimal . rfromS $ interpretAstPrimal env i
rfromIntegral . rfromPrimal . runRepScalar $ interpretAstPrimal env i
AstIndex v ix ->
let v2 = interpretAst env v
ix3 = interpretAstPrimal env <$> ix
Expand Down Expand Up @@ -555,7 +555,7 @@ interpretAst !env = \case
AstReshape sh v -> rreshape sh (interpretAst env v)
AstGather sh AstIota (vars, i :.: ZIR) ->
rbuild sh (interpretLambdaIndex interpretAst env
(vars, fromPrimal @s $ AstFromIntegral $ AstRFromS i))
(vars, fromPrimal @s $ AstFromIntegral $ AstScalar i))
AstGather sh v (vars, ix) ->
let t1 = interpretAst env v
f2 = interpretLambdaIndexToIndex interpretAstPrimal env (vars, ix)
Expand Down Expand Up @@ -677,7 +677,7 @@ interpretAst !env = \case
in foldl1 (+) (srepl 0 : args2) -- backward compat vs @sum@
-- TODO: in foldr1 (+) args2 -- avoid @fromInteger 0@ in @sum@
AstIndexS AstIotaS (i :.$ ZIS) ->
sfromIntegral . sfromPrimal $ interpretAstPrimal env i
sfromIntegral . sfromPrimal . sfromR . runRepScalar $ interpretAstPrimal env i
AstIndexS @sh1 @_ @_ @r v ix ->
let v2 = interpretAst env v
ix3 = interpretAstPrimal env <$> ix
Expand Down Expand Up @@ -814,7 +814,7 @@ interpretAst !env = \case
$ sbuild @target @r @(Rank sh2)
(interpretLambdaIndexS
interpretAst env
(vars, fromPrimal @s $ AstFromIntegralS i))
(vars, fromPrimal @s $ AstFromIntegralS $ AstSFromR $ AstScalar i))
AstGatherS v (vars, ix) ->
let t1 = interpretAst env v
f2 = interpretLambdaIndexToIndexS interpretAstPrimal env (vars, ix)
Expand Down Expand Up @@ -960,6 +960,10 @@ interpretAstBool !env = \case
AstBoolConst a -> if a then true else false
AstRel @y3 opCodeRel arg1 arg2 ->
case stensorKind @y3 of
STKScalar{} ->
let r1 = interpretAstPrimal env arg1
r2 = interpretAstPrimal env arg2
in interpretAstRelOp opCodeRel r1 r2
STKR SNat STKScalar{} ->
let r1 = interpretAstPrimalRuntimeSpecialized env arg1
r2 = interpretAstPrimalRuntimeSpecialized env arg2
Expand Down
4 changes: 3 additions & 1 deletion src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ printAst cfgOld d t =
then case isTensorInt t of
Just Refl -> case t of
AstVar _ var -> printAstIntVar cfgOld var
AstConcrete _ i -> shows $ Nested.sunScalar $ unRepN i
AstConcrete _ i -> shows $ unRepN i
_ -> printAstAux cfgOld d t
_ -> let cfg = cfgOld {representsIntIndex = False}
in printAstAux cfg d t
Expand All @@ -151,6 +151,8 @@ printAst cfgOld d t =
printAstAux :: forall s y ms. (TensorKind y, AstSpan s)
=> PrintConfig -> Int -> AstTensor ms s y -> ShowS
printAstAux cfg d = \case
AstScalar t -> printAstAux cfg d t -- TODO
AstUnScalar t -> printAstAux cfg d t -- TODO
AstPair t1 t2 ->
showParen (d > 10)
$ showString "tpair (" -- TODO
Expand Down
Loading

0 comments on commit 8c76466

Please sign in to comment.