Skip to content

Commit

Permalink
Add kfloor, kcast and kfromIntegral ops for scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 30, 2024
1 parent ddb9c3b commit 7228e66
Show file tree
Hide file tree
Showing 12 changed files with 152 additions and 34 deletions.
8 changes: 8 additions & 0 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,14 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
-> AstTensor ms s (TKScalar r)
AstSumOfList :: GoodScalar r
=> [AstTensor ms s (TKScalar r)] -> AstTensor ms s (TKScalar r)
AstFloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2)
=> AstTensor ms PrimalSpan (TKScalar r)
-> AstTensor ms PrimalSpan (TKScalar r2)
AstCast :: (GoodScalar r1, RealFrac r1, RealFrac r2, GoodScalar r2)
=> AstTensor ms s (TKScalar r1) -> AstTensor ms s (TKScalar r2)
AstFromIntegral :: (GoodScalar r1, Integral r1, GoodScalar r2)
=> AstTensor ms PrimalSpan (TKScalar r1)
-> AstTensor ms PrimalSpan (TKScalar r2)

-- Here starts the ranked part.
AstN1R :: (GoodScalar r, KnownNat n)
Expand Down
6 changes: 6 additions & 0 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ inlineAst memo v0 = case v0 of
Ast.AstSumOfList args ->
let (memo2, args2) = mapAccumR inlineAst memo args
in (memo2, Ast.AstSumOfList args2)
Ast.AstFloor a -> second Ast.AstFloor $ inlineAst memo a
Ast.AstCast a -> second Ast.AstCast $ inlineAst memo a
Ast.AstFromIntegral a -> second Ast.AstFromIntegral $ inlineAst memo a
Ast.AstN1R opCode u ->
let (memo2, u2) = inlineAst memo u
in (memo2, Ast.AstN1R opCode u2)
Expand Down Expand Up @@ -524,6 +527,9 @@ unshareAst memo = \case
Ast.AstSumOfList args ->
let (memo2, args2) = mapAccumR unshareAst memo args
in (memo2, Ast.AstSumOfList args2)
Ast.AstFloor a -> second Ast.AstFloor $ unshareAst memo a
Ast.AstCast v -> second Ast.AstCast $ unshareAst memo v
Ast.AstFromIntegral v -> second Ast.AstFromIntegral $ unshareAst memo v
Ast.AstN1R opCode u ->
let (memo2, u2) = unshareAst memo u
in (memo2, Ast.AstN1R opCode u2)
Expand Down
5 changes: 5 additions & 0 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ interpretAst !env = \case
AstSumOfList args ->
let args2 = interpretAst env <$> args
in foldr1 (+) args2 -- avoid @fromInteger 0@ in @sum@
AstFloor v ->
kfloor $ tfromPrimal (STKScalar typeRep) $ interpretAstPrimal env v
AstCast v -> kcast $ interpretAst env v
AstFromIntegral v ->
kfromIntegral $ tfromPrimal (STKScalar typeRep) $ interpretAstPrimal env v
{- TODO: revise when we handle GPUs. For now, this is done in TensorOps
instead and that's fine, because for one-element carriers,
reshape and replicate are very cheap. OTOH, this was introducing
Expand Down
6 changes: 6 additions & 0 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ printAstAux cfg d = \case
in showParen (d > 6)
$ printAst cfg 7 left
. foldr (.) id rs
AstFloor v ->
printPrefixOp printAst cfg d "kfloor" [v]
AstCast v ->
printPrefixOp printAst cfg d "kcast" [v]
AstFromIntegral v ->
printPrefixOp printAst cfg d "kfromIntegral" [v]
AstN1R opCode u -> printAstN1R printAst cfg d opCode u
AstN2R opCode u v -> printAstN2R printAst cfg d opCode u v
AstR1R opCode u -> printAstR1R printAst cfg d opCode u
Expand Down
84 changes: 57 additions & 27 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ module HordeAd.Core.AstSimplify
, astReplicate, astAppend, astAppendS, astSlice, astSliceS
, astReverse, astReverseS
, astTranspose, astTransposeS, astReshape, astReshapeS
, astCast, astCastS, astFromIntegral, astFromIntegralS
, astCast, astCastR, astCastS
, astFromIntegral, astFromIntegralR, astFromIntegralS
, astProject1, astProject2, astProjectR, astProjectS, astNestS, astUnNestS
, astRFromS, astSFromR, astSFromX, astXFromS
, astPrimalPart, astDualPart
Expand Down Expand Up @@ -310,6 +311,9 @@ astNonIndexStep t = case t of
case isTensorInt t of
Just Refl -> foldr1 contractAstPlusOp args
_ -> t
Ast.AstFloor{} -> t
Ast.AstCast v -> astCast v
Ast.AstFromIntegral v -> astFromIntegral v
AstN1R{} -> t
AstN2R{} -> t
Ast.AstR1R{} -> t
Expand All @@ -328,8 +332,8 @@ astNonIndexStep t = case t of
Ast.AstGather _ v0 (ZR, ix) -> Ast.AstIndex v0 ix
Ast.AstGather sh v0 (_, ZIR) -> astReplicateN sh v0
Ast.AstGather{} -> t -- this is "index" enough
Ast.AstCastR v -> astCast v
Ast.AstFromIntegralR v -> astFromIntegral v
Ast.AstCastR v -> astCastR v
Ast.AstFromIntegralR v -> astFromIntegralR v
Ast.AstProjectR l p -> astProjectR l p
Ast.AstLetHVectorIn vars u v -> astLetHVectorIn vars u v
Ast.AstRFromS v -> astRFromS v
Expand Down Expand Up @@ -479,7 +483,7 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIndex AstMethodLet m1)) =
Ast.AstMaxIndexR v -> Ast.AstMaxIndexR $ astIndexKnobsR knobs v ix
Ast.AstFloorR v -> Ast.AstFloorR $ astIndexKnobsR knobs v ix
Ast.AstIotaR | AstConcrete _ (RepN i) <- i1 -> case sameNat (Proxy @n) (Proxy @0) of
Just Refl -> astFromIntegral $ AstConcrete (FTKR ZSR FTKScalar) $ RepN $ Nested.rscalar i
Just Refl -> astFromIntegralR $ AstConcrete (FTKR ZSR FTKScalar) $ RepN $ Nested.rscalar i
_ -> error "astIndexKnobsR: rank not 0"
-- TODO: AstIndex AstIotaR (i :.: ZIR) ->
-- rfromIntegral . rfromPrimal . rfromScalar $ interpretAstPrimal env i
Expand Down Expand Up @@ -559,8 +563,8 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIndex AstMethodLet m1)) =
in astLet var2 i1 $ astIndex w rest1
Ast.AstGather{} ->
error "astIndex: AstGather: impossible pattern needlessly required"
Ast.AstCastR t -> astCast $ astIndexKnobsR knobs t ix
Ast.AstFromIntegralR v -> astFromIntegral $ astIndexKnobsR knobs v ix
Ast.AstCastR t -> astCastR $ astIndexKnobsR knobs t ix
Ast.AstFromIntegralR v -> astFromIntegralR $ astIndexKnobsR knobs v ix
AstConcrete _ t ->
let unConst :: AstInt AstMethodLet -> Maybe [Int64]
-> Maybe [Int64]
Expand Down Expand Up @@ -1032,7 +1036,7 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
Ast.AstFloorR
$ astGatherKnobsR knobs sh4 v (vars4, ix4)
Ast.AstIotaR | AstConcrete _ (RepN i) <- i4 -> case sameNat (Proxy @p') (Proxy @1) of
Just Refl -> astFromIntegral $ astReplicate0NT sh4 $ AstConcrete (FTKR ZSR FTKScalar) $ RepN $ Nested.rscalar i
Just Refl -> astFromIntegralR $ astReplicate0NT sh4 $ AstConcrete (FTKR ZSR FTKScalar) $ RepN $ Nested.rscalar i
_ -> error "astGather: AstIota: impossible pattern needlessly required"
{- TODO: is this beneficial?
AstGather sh AstIotaR (vars, i :.: ZIR) ->
Expand Down Expand Up @@ -1167,8 +1171,8 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
LTI -> composedGather
EQI -> assimilatedGather
GTI -> gcastWith (flipCompare @p' @m2) assimilatedGather
Ast.AstCastR v -> astCast $ astGather sh4 v (vars4, ix4)
Ast.AstFromIntegralR v -> astFromIntegral $ astGather sh4 v (vars4, ix4)
Ast.AstCastR v -> astCastR $ astGather sh4 v (vars4, ix4)
Ast.AstFromIntegralR v -> astFromIntegralR $ astGather sh4 v (vars4, ix4)
AstConcrete{} -> -- free variables possible, so can't compute the tensor
Ast.AstGather sh4 v4 (vars4, ix4)
Ast.AstProjectR{} -> -- TODO, but most likely reduced before it gets here
Expand Down Expand Up @@ -2057,13 +2061,21 @@ astReshapeS = \case
Just Refl -> v
_ -> Ast.AstReshapeS v

astCast :: (KnownNat n, GoodScalar r1, GoodScalar r2, RealFrac r1, RealFrac r2)
=> AstTensor AstMethodLet s (TKR n r1) -> AstTensor AstMethodLet s (TKR n r2)
astCast (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tcastR (unRepN t)
astCast :: (GoodScalar r1, GoodScalar r2, RealFrac r1, RealFrac r2)
=> AstTensor ms s (TKScalar r1) -> AstTensor ms s (TKScalar r2)
astCast (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ RepN $ realToFrac (unRepN t)
astCast (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCast v
astCast (Ast.AstCastR v) = astCast v
astCast (Ast.AstFromIntegralR v) = astFromIntegral v
astCast v = Ast.AstCastR v
astCast (Ast.AstCast v) = astCast v
astCast (Ast.AstFromIntegral v) = astFromIntegral v
astCast v = Ast.AstCast v

astCastR :: (KnownNat n, GoodScalar r1, GoodScalar r2, RealFrac r1, RealFrac r2)
=> AstTensor AstMethodLet s (TKR n r1) -> AstTensor AstMethodLet s (TKR n r2)
astCastR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tcastR (unRepN t)
astCastR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCastR v
astCastR (Ast.AstCastR v) = astCastR v
astCastR (Ast.AstFromIntegralR v) = astFromIntegralR v
astCastR v = Ast.AstCastR v

astCastS :: ( KnownShS sh, GoodScalar r1, GoodScalar r2, RealFrac r1
, RealFrac r2 )
Expand All @@ -2074,12 +2086,19 @@ astCastS (Ast.AstCastS v) = astCastS v
astCastS (Ast.AstFromIntegralS v) = astFromIntegralS v
astCastS v = Ast.AstCastS v

astFromIntegral :: (KnownNat n, GoodScalar r1, GoodScalar r2, Integral r1)
astFromIntegral :: (GoodScalar r1, GoodScalar r2, Integral r1)
=> AstTensor ms PrimalSpan (TKScalar r1)
-> AstTensor ms PrimalSpan (TKScalar r2)
astFromIntegral (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ RepN $ fromIntegral (unRepN t)
astFromIntegral (Ast.AstFromIntegral v) = astFromIntegral v
astFromIntegral v = Ast.AstFromIntegral v

astFromIntegralR :: (KnownNat n, GoodScalar r1, GoodScalar r2, Integral r1)
=> AstTensor AstMethodLet PrimalSpan (TKR n r1)
-> AstTensor AstMethodLet PrimalSpan (TKR n r2)
astFromIntegral (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tfromIntegralR (unRepN t)
astFromIntegral (Ast.AstFromIntegralR v) = astFromIntegral v
astFromIntegral v = Ast.AstFromIntegralR v
astFromIntegralR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tfromIntegralR (unRepN t)
astFromIntegralR (Ast.AstFromIntegralR v) = astFromIntegralR v
astFromIntegralR v = Ast.AstFromIntegralR v

astFromIntegralS :: (KnownShS sh, GoodScalar r1, GoodScalar r2, Integral r1)
=> AstTensor AstMethodLet PrimalSpan (TKS sh r1)
Expand Down Expand Up @@ -2240,6 +2259,7 @@ astPrimalPart t = case t of
Ast.AstR2 opCode u v -> Ast.AstR2 opCode (astPrimalPart u) (astPrimalPart v)
Ast.AstI2 opCode u v -> Ast.AstI2 opCode (astPrimalPart u) (astPrimalPart v)
AstSumOfList args -> astSumOfList (map astPrimalPart args)
Ast.AstCast v -> astCast $ astPrimalPart v
AstN1R opCode u -> AstN1R opCode (astPrimalPart u)
AstN2R opCode u v -> AstN2R opCode (astPrimalPart u) (astPrimalPart v)
Ast.AstR1R opCode u -> Ast.AstR1R opCode (astPrimalPart u)
Expand All @@ -2256,7 +2276,7 @@ astPrimalPart t = case t of
Ast.AstTranspose perm v -> astTranspose perm (astPrimalPart v)
Ast.AstReshape sh v -> astReshape sh (astPrimalPart v)
Ast.AstGather sh v (vars, ix) -> astGatherR sh (astPrimalPart v) (vars, ix)
Ast.AstCastR v -> astCast $ astPrimalPart v
Ast.AstCastR v -> astCastR $ astPrimalPart v
Ast.AstProjectR l p -> astProjectR (astPrimalPart l) p
Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars l (astPrimalPart v)
Ast.AstRFromS v -> astRFromS $ astPrimalPart v
Expand Down Expand Up @@ -2324,6 +2344,7 @@ astDualPart t = case t of
Ast.AstR2{} -> Ast.AstDualPart t
Ast.AstI2{} -> Ast.AstDualPart t
AstSumOfList args -> astSumOfList (map astDualPart args)
Ast.AstCast v -> astCast $ astDualPart v
AstN1R{} -> Ast.AstDualPart t -- stuck; the ops are not defined on dual part
AstN2R{} -> Ast.AstDualPart t -- stuck; the ops are not defined on dual part
Ast.AstR1R{} -> Ast.AstDualPart t
Expand All @@ -2340,7 +2361,7 @@ astDualPart t = case t of
Ast.AstTranspose perm v -> astTranspose perm (astDualPart v)
Ast.AstReshape sh v -> astReshape sh (astDualPart v)
Ast.AstGather sh v (vars, ix) -> astGatherR sh (astDualPart v) (vars, ix)
Ast.AstCastR v -> astCast $ astDualPart v
Ast.AstCastR v -> astCastR $ astDualPart v
Ast.AstProjectR l p -> astProjectR (astDualPart l) p
Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars l (astDualPart v)
Ast.AstRFromS v -> astRFromS $ astDualPart v
Expand Down Expand Up @@ -2601,6 +2622,9 @@ simplifyAst t = case t of
case isTensorInt t of
Just Refl -> foldr1 contractAstPlusOp (map simplifyAst args)
_ -> astSumOfList (map simplifyAst args)
Ast.AstFloor a -> Ast.AstFloor (simplifyAst a)
Ast.AstCast v -> astCast $ simplifyAst v
Ast.AstFromIntegral v -> astFromIntegral $ simplifyAst v
AstN1R opCode u -> AstN1R opCode (simplifyAst u)
AstN2R opCode u v -> AstN2R opCode (simplifyAst u) (simplifyAst v)
Ast.AstR1R opCode u -> Ast.AstR1R opCode (simplifyAst u)
Expand All @@ -2620,8 +2644,8 @@ simplifyAst t = case t of
Ast.AstReshape sh v -> astReshape sh (simplifyAst v)
Ast.AstGather sh v (vars, ix) ->
astGatherR sh (simplifyAst v) (vars, simplifyAstIndex ix)
Ast.AstCastR v -> astCast $ simplifyAst v
Ast.AstFromIntegralR v -> astFromIntegral $ simplifyAst v
Ast.AstCastR v -> astCastR $ simplifyAst v
Ast.AstFromIntegralR v -> astFromIntegralR $ simplifyAst v
Ast.AstProjectR l p -> astProjectR (simplifyAst l) p
Ast.AstLetHVectorIn vars l v ->
astLetHVectorIn vars (simplifyAst l) (simplifyAst v)
Expand Down Expand Up @@ -2779,6 +2803,9 @@ expandAst t = case t of
case isTensorInt t of
Just Refl -> foldr1 contractAstPlusOp (map expandAst args)
_ -> astSumOfList (map expandAst args)
Ast.AstFloor a -> Ast.AstFloor (expandAst a)
Ast.AstCast v -> astCast $ expandAst v
Ast.AstFromIntegral v -> astFromIntegral $ expandAst v
AstN1R opCode u -> AstN1R opCode (expandAst u)
AstN2R opCode u v -> AstN2R opCode (expandAst u) (expandAst v)
Ast.AstR1R opCode u -> Ast.AstR1R opCode (expandAst u)
Expand Down Expand Up @@ -2843,8 +2870,8 @@ expandAst t = case t of
Ast.AstGather sh v (vars, ix) ->
astGatherKnobsR (defaultKnobs {knobExpand = True})
sh (expandAst v) (vars, expandAstIndex ix)
Ast.AstCastR v -> astCast $ expandAst v
Ast.AstFromIntegralR v -> astFromIntegral $ expandAst v
Ast.AstCastR v -> astCastR $ expandAst v
Ast.AstFromIntegralR v -> astFromIntegralR $ expandAst v
Ast.AstProjectR l p -> astProjectR (expandAst l) p
Ast.AstLetHVectorIn vars l v ->
astLetHVectorIn vars (expandAst l) (expandAst v)
Expand Down Expand Up @@ -3310,6 +3337,9 @@ substitute1Ast i var v1 = case v1 of
Just Refl -> foldr1 contractAstPlusOp $ zipWith fromMaybe args margs
_ -> astSumOfList $ zipWith fromMaybe args margs
else Nothing
Ast.AstFloor a -> Ast.AstFloor <$> substitute1Ast i var a
Ast.AstCast v -> astCast <$> substitute1Ast i var v
Ast.AstFromIntegral v -> astFromIntegral <$> substitute1Ast i var v
Ast.AstN1R opCode u -> Ast.AstN1R opCode <$> substitute1Ast i var u
Ast.AstN2R opCode u v ->
let mu = substitute1Ast i var u
Expand Down Expand Up @@ -3363,8 +3393,8 @@ substitute1Ast i var v1 = case v1 of
(Nothing, Nothing) -> Nothing
(mv, mix) -> Just $ astGatherR sh (fromMaybe v mv)
(vars, fromMaybe ix mix)
Ast.AstCastR v -> astCast <$> substitute1Ast i var v
Ast.AstFromIntegralR v -> astFromIntegral <$> substitute1Ast i var v
Ast.AstCastR v -> astCastR <$> substitute1Ast i var v
Ast.AstFromIntegralR v -> astFromIntegralR <$> substitute1Ast i var v
Ast.AstConcrete{} -> Nothing
Ast.AstProjectR l p ->
case substitute1Ast i var l of
Expand Down
8 changes: 7 additions & 1 deletion src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ ftkAst t = case t of
AstR2{} -> FTKScalar
AstI2{} -> FTKScalar
AstSumOfList{} -> FTKScalar
AstFloor{} -> FTKScalar
AstCast{} -> FTKScalar
AstFromIntegral{} -> FTKScalar

AstMinIndexR a -> FTKR (initShape $ shapeAst a) FTKScalar
AstMaxIndexR a -> FTKR (initShape $ shapeAst a) FTKScalar
AstFloorR a -> FTKR (shapeAst a) FTKScalar
AstFloorR a -> FTKR (shapeAst a) FTKScalar
AstIotaR -> FTKR (singletonShape (maxBound :: Int)) FTKScalar -- ought to be enough
AstN1R _opCode v -> ftkAst v
AstN2R _opCode v _ -> ftkAst v
Expand Down Expand Up @@ -219,6 +222,9 @@ varInAst var = \case
AstR1 _ t -> varInAst var t
AstR2 _ t u -> varInAst var t || varInAst var u
AstI2 _ t u -> varInAst var t || varInAst var u
AstFloor a -> varInAst var a
AstCast t -> varInAst var t
AstFromIntegral t -> varInAst var t
AstSumOfList l -> any (varInAst var) l
AstN1R _ t -> varInAst var t
AstN2R _ t u -> varInAst var t || varInAst var u
Expand Down
10 changes: 8 additions & 2 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ build1V snat@SNat (var, v0) =
(build1VOccurenceUnknown snat (var, v))
Ast.AstSumOfList args -> traceRule $
astSumOfList $ map (\v -> build1VOccurenceUnknown snat (var, v)) args
Ast.AstFloor v -> traceRule $
Ast.AstFloor $ build1V snat (var, v)
Ast.AstCast v -> traceRule $
astCast $ build1V snat (var, v)
Ast.AstFromIntegral v -> traceRule $
astFromIntegral $ build1V snat (var, v)

Ast.AstN1R opCode u -> traceRule $
Ast.AstN1R opCode (build1V snat (var, u))
Expand Down Expand Up @@ -361,9 +367,9 @@ build1V snat@SNat (var, v0) =
(build1VOccurenceUnknown snat (var, v))
(varFresh ::: vars, astVarFresh :.: ix2)
Ast.AstCastR v -> traceRule $
astCast $ build1V snat (var, v)
astCastR $ build1V snat (var, v)
Ast.AstFromIntegralR v -> traceRule $
astFromIntegral $ build1V snat (var, v)
astFromIntegralR $ build1V snat (var, v)
Ast.AstConcrete{} ->
error "build1V: AstConcrete can't have free index variables"

Expand Down
11 changes: 11 additions & 0 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ deriving instance ( (forall y7. TensorKind y7 => Show (target y7))

type role Delta nominal nominal
data Delta :: Target -> TensorKindType -> Type where
Cast :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2)
=> Delta target (TKScalar r1) -> Delta target (TKScalar r2)
FromScalarG :: GoodScalar r
=> Delta target (TKScalar r) -> Delta target (TKS '[] r)
ToScalarG :: GoodScalar r
Expand Down Expand Up @@ -649,6 +651,7 @@ deriving instance ( TensorKind y
shapeDeltaFull :: forall target y. TensorKind y
=> Delta target y -> FullTensorKind y
shapeDeltaFull = \case
Cast{} -> FTKScalar
FromScalarG{} -> FTKS ZSS FTKScalar
ToScalarG{} -> FTKScalar
PairG t1 t2 -> FTKProduct (shapeDeltaFull t1) (shapeDeltaFull t2)
Expand Down Expand Up @@ -1098,6 +1101,9 @@ evalSame !s !c = \case
-- (and the InputG constructor and the vector space constructors)
-- can be handled here, where the extra
-- constraint makes it easier.
Cast @r1 d ->
evalR s (toADTensorKindShared (stensorKind @(TKScalar r1))
$ kcast c) d
FromScalarG d -> evalSame s (stoScalar c) d
ToScalarG d -> evalSame s (sfromScalar c) d
InputG _ftk i ->
Expand Down Expand Up @@ -1503,6 +1509,11 @@ fwdSame
=> IMap target -> ADMap target -> Delta target y
-> (ADMap target, target (ADTensorKind y))
fwdSame params s = \case
d0@(Cast @r1 d)
| Dict <- lemTensorKindOfAD (stensorKind @(TKScalar r1)) ->
case sameTensorKind @(TKScalar r1) @(ADTensorKind (TKScalar r1)) of
Just Refl -> second kcast $ fwdSame params s d
_ -> (s, repConstant 0 $ aDTensorKind $ shapeDeltaFull d0)
FromScalarG d -> let (s2, t) = fwdSame params s d
in (s2, sfromScalar t)
ToScalarG d -> let (s2, t) = fwdSame params s d
Expand Down
8 changes: 8 additions & 0 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,14 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
sD t d = dD t d
sScale k = ScaleG k

kfloor (D u _) =
let v = kfloor u
in fromPrimalADVal v
kcast (D u u') = dD (kcast u) (Cast u')
kfromIntegral (D u _) =
let v = kfromIntegral u
in fromPrimalADVal v

tpair (D u u') (D v v') = dDnotShared (tpair u v) (PairG u' v')
tproject1 (D u u') = dDnotShared (tproject1 u) (fst $ unPairGUnshared u')
tproject2 (D u u') = dDnotShared (tproject2 u) (snd $ unPairGUnshared u')
Expand Down
Loading

0 comments on commit 7228e66

Please sign in to comment.