Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve generated code for derived instances #3189

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,20 @@ recursiveIdentsClosure tab =
chlds = fromJust $ HashMap.lookup sym graph

-- | Complement of recursiveIdentsClosure
nonRecursiveReachableIdents' :: InfoTable -> HashSet Symbol
nonRecursiveReachableIdents' tab =
HashSet.difference
(HashSet.fromList (HashMap.keys (tab ^. infoIdentifiers)))
(recursiveIdentsClosure tab)

nonRecursiveReachableIdents :: Module -> HashSet Symbol
nonRecursiveReachableIdents = nonRecursiveReachableIdents' . computeCombinedInfoTable

nonRecursiveIdents' :: InfoTable -> HashSet Symbol
nonRecursiveIdents' tab =
HashSet.difference
(HashSet.fromList (HashMap.keys (tab ^. infoIdentifiers)))
(recursiveIdentsClosure tab)
(recursiveIdents' tab)

nonRecursiveIdents :: Module -> HashSet Symbol
nonRecursiveIdents = nonRecursiveIdents' . computeCombinedInfoTable
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ constantFolding' opts nonRecSyms tab md =
-- zero-order. For example, `3 + 4` is evaluated to `7`, and `id 3` is evaluated
-- to `3`, but `id id` is not evaluated because the target type is not
-- zero-order (it's a function type). This optimization is only applied to
-- non-recursive symbols.
-- symbols from which no recursive symbols can be reached.
--
-- References:
-- - https://github.com/anoma/juvix/pull/2450
-- - https://github.com/anoma/juvix/issues/2154
constantFolding :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
constantFolding md = do
opts <- ask
return $ constantFolding' opts (nonRecursiveIdents' tab) tab md
return $ constantFolding' opts (nonRecursiveReachableIdents' tab) tab md
where
tab = computeCombinedInfoTable md
9 changes: 7 additions & 2 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ convertNode inlineDepth nonRecSyms md = dmapL go
node
Nothing
| HashSet.member _identSymbol nonRecSyms
&& isConstructorApp def
&& checkDepth md bl inlineDepth def ->
&& checkLambdaConstructorApp (length args) bl def ->
NCase cs {_caseValue = mkApps def args}
_ ->
node
Expand All @@ -93,6 +92,12 @@ convertNode inlineDepth nonRecSyms md = dmapL go
_ ->
node

checkLambdaConstructorApp :: Int -> BinderList Binder -> Node -> Bool
checkLambdaConstructorApp argsNum bl node =
argsNum >= lamsNum && isConstructorApp body && checkDepth md bl inlineDepth body
where
(lamsNum, body) = unfoldLambdas' node

inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md

Expand Down
11 changes: 7 additions & 4 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ optimize' opts@CoreOptions {..} md =
. compose
(6 * _optOptimizationLevel)
( doConstantFolding
. doSimplification 2
. doInlining
. doSimplification 1
. specializeArgs
. doSimplification 2
. doInlining
)
. doConstantFolding
. letFolding
Expand All @@ -36,13 +36,16 @@ optimize' opts@CoreOptions {..} md =
nonRecs :: HashSet Symbol
nonRecs = nonRecursiveIdents' tab

nonRecsReachable :: HashSet Symbol
nonRecsReachable = nonRecursiveReachableIdents' tab

doConstantFolding :: Module -> Module
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
where
tab' = computeCombinedInfoTable md'
nonRecs'
| _optOptimizationLevel > 1 = nonRecursiveIdents' tab'
| otherwise = nonRecs
| _optOptimizationLevel > 1 = nonRecursiveReachableIdents' tab'
| otherwise = nonRecsReachable

doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth nonRecs' md'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ convertNode = dmapLRM go
fun = reLambdas lams' body''
letitem =
mkLetItem
(ii ^. identifierName)
("spec_" <> ii ^. identifierName)
-- the type is not in the scope of the binder
(shift (-1) ty')
fun
Expand Down
66 changes: 52 additions & 14 deletions src/Juvix/Compiler/Internal/Translation/FromConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module Juvix.Compiler.Internal.Translation.FromConcrete
( module Juvix.Compiler.Internal.Translation.FromConcrete.Data.Context,
fromConcrete,
DefaultArgsStack,
goTopModule,
fromConcreteExpression,
fromConcreteImport,
)
Expand Down Expand Up @@ -553,13 +552,34 @@ deriveEq ::
DerivingArgs ->
Sem r Internal.FunctionDef
deriveEq DerivingArgs {..} = do
arg <- getArg
indInfo <- getIndInfo
let argty = getArgType indInfo
argsInfo <- goArgsInfo _derivingInstanceName
lam <- eqLambda arg
lamName <- Internal.freshFunVar (getLoc _derivingInstanceName) ("eq__" <> _derivingInstanceName ^. Internal.nameText)
let lam = Internal.ExpressionIden (Internal.IdenFunction lamName)
lamFun <- eqLambda lam indInfo argty
lamTy <- Internal.ExpressionHole <$> Internal.freshHole (getLoc _derivingInstanceName)
let lamDef =
Internal.FunctionDef
{ _funDefTerminating = False,
_funDefIsInstanceCoercion = Nothing,
_funDefPragmas = mempty,
_funDefArgsInfo = [],
_funDefDocComment = Nothing,
_funDefName = lamName,
_funDefType = lamTy,
_funDefBody = lamFun,
_funDefBuiltin = Nothing
}
mkEq <- getBuiltin (getLoc eqName) BuiltinMkEq
let body = mkEq Internal.@@ lam
ty = Internal.foldFunType _derivingParameters ret
pragmas' <- goPragmas _derivingPragmas
let body =
Internal.ExpressionLet
Internal.Let
{ _letClauses = pure (Internal.LetMutualBlock (Internal.MutualBlockLet (pure lamDef))),
_letExpression = mkEq Internal.@@ lam
}
ty = Internal.foldFunType _derivingParameters ret
return
Internal.FunctionDef
{ _funDefTerminating = False,
Expand All @@ -580,14 +600,23 @@ deriveEq DerivingArgs {..} = do
args :: [Internal.ApplicationArg]
(eqName, args) = _derivingReturnType

getArg :: Sem r Internal.InductiveInfo
getArg = runFailDefaultM (throwDerivingWrongForm ret) $ do
getIndInfo :: Sem r Internal.InductiveInfo
getIndInfo = runFailDefaultM (throwDerivingWrongForm ret) $ do
[Internal.ApplicationArg Explicit a] <- return args
Internal.ExpressionIden (Internal.IdenInductive ind) <- return (fst (Internal.unfoldExpressionApp a))
getDefinedInductive ind

eqLambda :: Internal.InductiveInfo -> Sem r Internal.Expression
eqLambda d = do
getArgType :: Internal.InductiveInfo -> Internal.Expression
getArgType indInfo =
Internal.foldApplication
(Internal.toExpression (indInfo ^. Internal.inductiveInfoName))
(map toAppArg (indInfo ^. Internal.inductiveInfoParameters))
where
toAppArg :: Internal.InductiveParameter -> Internal.ApplicationArg
toAppArg p = Internal.ApplicationArg Explicit (Internal.toExpression (p ^. Internal.inductiveParamName))

eqLambda :: Internal.Expression -> Internal.InductiveInfo -> Internal.Expression -> Sem r Internal.Expression
eqLambda lam d argty = do
let loc = getLoc eqName
band <- getBuiltin loc BuiltinBoolAnd
btrue <- getBuiltin loc BuiltinBoolTrue
Expand Down Expand Up @@ -627,6 +656,7 @@ deriveEq DerivingArgs {..} = do
Internal.ConstructorName ->
Sem r Internal.LambdaClause
lambdaClause band btrue bisEqual c = do
argsRecursive :: [Bool] <- getRecursiveArgs
numArgs :: [IsImplicit] <- getNumArgs
let loc = getLoc _derivingInstanceName
mkpat :: Sem r ([Internal.VarName], Internal.PatternArg)
Expand All @@ -641,22 +671,24 @@ deriveEq DerivingArgs {..} = do
return
Internal.LambdaClause
{ _lambdaPatterns = p1 :| [p2],
_lambdaBody = allEq (zipExact v1 v2)
_lambdaBody = allEq (zip3Exact v1 v2 argsRecursive)
}
where
allEq :: (Internal.IsExpression expr) => [(expr, expr)] -> Internal.Expression
allEq :: (Internal.IsExpression expr) => [(expr, expr, Bool)] -> Internal.Expression
allEq k = case nonEmpty k of
Nothing -> Internal.toExpression btrue
Just l -> mkAnds (fmap (uncurry mkEq) l)
Just l -> mkAnds (fmap (uncurry3 mkEq) l)

mkAnds :: (Internal.IsExpression expr) => NonEmpty expr -> Internal.Expression
mkAnds = foldl1 mkAnd . fmap Internal.toExpression

mkAnd :: (Internal.IsExpression expr) => expr -> expr -> Internal.Expression
mkAnd a b = band Internal.@@ a Internal.@@ b

mkEq :: (Internal.IsExpression expr) => expr -> expr -> Internal.Expression
mkEq a b = bisEqual Internal.@@ a Internal.@@ b
mkEq :: (Internal.IsExpression expr) => expr -> expr -> Bool -> Internal.Expression
mkEq a b isRec
| isRec = lam Internal.@@ a Internal.@@ b
| otherwise = bisEqual Internal.@@ a Internal.@@ b

getNumArgs :: Sem r [IsImplicit]
getNumArgs = do
Expand All @@ -668,6 +700,12 @@ deriveEq DerivingArgs {..} = do
. each
. Internal.paramImplicit

getRecursiveArgs :: Sem r [Bool]
getRecursiveArgs = do
def <- getDefinedConstructor c
let argTypes = map (^. Internal.paramType) $ Internal.constructorArgs (def ^. Internal.constructorInfoType)
return $ map (== argty) argTypes

goFunctionDef ::
forall r.
(Members '[Reader DefaultArgsStack, Reader Pragmas, Error ScoperError, NameIdGen, Reader S.InfoTable] r) =>
Expand Down
1 change: 1 addition & 0 deletions tests/Compilation/positive/out/test085.out
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ true
true
false
false
true
11 changes: 9 additions & 2 deletions tests/Compilation/positive/test085.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@ import Stdlib.System.IO open;

syntax alias isEqual := Eq.eq;

type NTree :=
| NLeaf
| NNode NTree Nat NTree;

deriving instance
ntreeEq : Eq NTree;

type Tree A :=
| Leaf
| Node (Tree A) A (Tree A);

{-# inline: case #-}
deriving instance
treeEqI {A} {{Eq A}} : Eq (Tree A);

Expand All @@ -29,4 +35,5 @@ main : IO :=
>>> printLn (not (isEqual (false, true) (false, false)))
>>> printLn (isEqual t1 t1)
>>> printLn (isEqual t1 t2)
>>> printLn (isEqual t1 t3);
>>> printLn (isEqual t1 t3)
>>> printLn (isEqual (NNode NLeaf 0 NLeaf) (NNode NLeaf 0 NLeaf));
Loading