Skip to content

Commit

Permalink
Re-add type equalities GHC claims spurious but are needed for typing
Browse files Browse the repository at this point in the history
Mikolaj committed Feb 28, 2024
1 parent 1bd188c commit fc41217
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/HordeAd/Core/Engine.hs
Original file line number Diff line number Diff line change
@@ -235,19 +235,20 @@ fwdArtifactAdapt f vals =

-- * Old gradient adaptors, with constant and fixed inputs and dt

-- We are inlining these function, in part, because they take function
-- arguments and calling unknown functions is expensive and, in part,
-- because we can't easily specialize them due to
-- https://gitlab.haskell.org/ghc/ghc/-/issues/23798
-- However, they are called in many places, so we break the inline chain
-- at crevOnHVector to avoid binary blowup.
-- The equality `RankedOf f ~ Flip OR.Array` is needed for type-checking
-- later on, even though GHC 9.6.4 reports it as redundant.
--
-- We are inlining these functions because they take function arguments
-- and are not too large. However, becausethey are called in many places,
-- we break the inline chain at crevOnHVector, to avoid exe blowup.
-- | The old versions that use the fixed input and @dt@ to compute gradient
-- only at these values, both transposing and evaluating at the same time.
--
-- These work for @f@ both ranked and shaped.
crev
:: forall r y f vals advals.
( AdaptableHVector (ADVal (Flip OR.Array)) advals
( RankedOf f ~ Flip OR.Array
, AdaptableHVector (ADVal (Flip OR.Array)) advals
, AdaptableHVector (ADVal (Flip OR.Array)) (ADVal f r y)
, AdaptableHVector (Flip OR.Array) vals
, AdaptableHVector (Flip OR.Array) (f r y)
@@ -259,7 +260,8 @@ crev f vals = crevDtMaybe f vals Nothing
-- | This version additionally takes the sensitivity parameter.
crevDt
:: forall r y f vals advals.
( AdaptableHVector (ADVal (Flip OR.Array)) advals
( RankedOf f ~ Flip OR.Array
, AdaptableHVector (ADVal (Flip OR.Array)) advals
, AdaptableHVector (ADVal (Flip OR.Array)) (ADVal f r y)
, AdaptableHVector (Flip OR.Array) vals
, AdaptableHVector (Flip OR.Array) (f r y)
@@ -270,7 +272,8 @@ crevDt f vals dt = crevDtMaybe f vals (Just dt)

crevDtMaybe
:: forall r y f vals advals.
( AdaptableHVector (ADVal (Flip OR.Array)) advals
( RankedOf f ~ Flip OR.Array
, AdaptableHVector (ADVal (Flip OR.Array)) advals
, AdaptableHVector (ADVal (Flip OR.Array)) (ADVal f r y)
, AdaptableHVector (Flip OR.Array) vals
, AdaptableHVector (Flip OR.Array) (f r y)
@@ -302,7 +305,8 @@ crevDtMaybe f vals mdt =
-- | This takes the sensitivity parameter, by convention.
cfwd
:: forall r y f vals advals.
( AdaptableHVector (ADVal (Flip OR.Array)) advals
( RankedOf f ~ Flip OR.Array
, AdaptableHVector (ADVal (Flip OR.Array)) advals
, AdaptableHVector (ADVal (Flip OR.Array)) (ADVal f r y)
, AdaptableHVector (Flip OR.Array) vals
, AdaptableHVector (Flip OR.Array) (f r y)

0 comments on commit fc41217

Please sign in to comment.