Skip to content

Commit

Permalink
Move code around to more natural locations.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Nov 19, 2024
1 parent eea0560 commit f4afac4
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 88 deletions.
24 changes: 24 additions & 0 deletions src/Language/Futhark/Prop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ module Language.Futhark.Prop
valBindBound,
funType,
stripExp,
subExps,
similarExps,
sameExp,

-- * Queries on patterns and params
patIdents,
Expand Down Expand Up @@ -1371,6 +1373,20 @@ stripExp (Attr _ e _) = stripExp e `mplus` Just e
stripExp (Ascript e _ _) = stripExp e `mplus` Just e
stripExp _ = Nothing

-- | All non-trivial subexpressions (as by stripExp) of some
-- expression, not including the expression itself.
subExps :: Exp -> [Exp]
subExps e
| Just e' <- stripExp e = subExps e'
| otherwise = astMap mapper e `execState` mempty
where
mapOnExp e'
| Just e'' <- stripExp e' = mapOnExp e''
| otherwise = do
modify (e' :)
astMap mapper e'
mapper = identityMapper {mapOnExp}

similarSlices :: Slice -> Slice -> Maybe [(Exp, Exp)]
similarSlices slice1 slice2
| length slice1 == length slice2 = do
Expand Down Expand Up @@ -1454,6 +1470,14 @@ similarExps (IndexSection slice1 _ _) (IndexSection slice2 _ _) =
similarSlices slice1 slice2
similarExps _ _ = Nothing

-- | Are these the same expression as per recursively invoking
-- 'similarExps'?
sameExp :: Exp -> Exp -> Bool
sameExp e1 e2
| Just es <- similarExps e1 e2 =
all (uncurry sameExp) es
| otherwise = False

-- | An identifier with type- and aliasing information.
type Ident = IdentBase Info VName

Expand Down
89 changes: 1 addition & 88 deletions src/Language/Futhark/TypeChecker/Terms.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bifunctor
Expand All @@ -27,7 +26,7 @@ import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Util (mapAccumLM, nubOrd, topologicalSort)
import Futhark.Util (mapAccumLM, nubOrd)
import Futhark.Util.Pretty hiding (space)
import Language.Futhark
import Language.Futhark.Primitive (intByteSize)
Expand Down Expand Up @@ -230,92 +229,6 @@ checkCoerce loc te e = do
"a size coercion where the underlying expression size cannot be determined"
pure $ sizeFromName (qualName v) (srclocOf d)

sameExp :: Exp -> Exp -> Bool
sameExp e1 e2
| Just es <- similarExps e1 e2 =
all (uncurry sameExp) es
| otherwise = False

-- All non-trivial subexpressions (as by stripExp) of some expression,
-- not including the expression itself.
subExps :: Exp -> [Exp]
subExps e
| Just e' <- stripExp e = subExps e'
| otherwise = astMap mapper e `execState` mempty
where
mapOnExp e'
| Just e'' <- stripExp e' = mapOnExp e''
| otherwise = do
modify (e' :)
astMap mapper e'
mapper = identityMapper {mapOnExp}

-- Expressions witnessed by type, topologically sorted.
topWit :: TypeBase Exp u -> [Exp]
topWit = topologicalSort depends . witnessedExps
where
witnessedExps t = execState (traverseDims onDim t) mempty
where
onDim _ PosImmediate e = modify (e :)
onDim _ _ _ = pure ()
depends a b = any (sameExp b) $ subExps a

sizeFree ::
SrcLoc ->
(Exp -> Maybe VName) ->
TypeBase Size u ->
TermTypeM (TypeBase Size u, [VName])
sizeFree tloc expKiller orig_t = do
runReaderT (toBeReplaced orig_t $ onType orig_t) mempty `runStateT` mempty
where
lookReplacement e repl = snd <$> find (sameExp e . fst) repl
expReplace mapping e
| Just e' <- lookReplacement e mapping = e'
| otherwise = runIdentity $ astMap mapper e
where
mapper = identityMapper {mapOnExp = pure . expReplace mapping}

replacing e = do
e' <- asks (`expReplace` e)
case expKiller e' of
Nothing -> pure e'
Just cause -> do
vn <- lift $ lift $ newRigidDim tloc (RigidOutOfScope (locOf e) cause) "d"
modify (vn :)
pure $ sizeFromName (qualName vn) (srclocOf e)

toBeReplaced t m' = foldl f m' $ topWit t
where
f m e = do
e' <- replacing e
local ((e, e') :) m

onScalar (Record fs) =
Record <$> traverse onType fs
onScalar (Sum cs) =
Sum <$> (traverse . traverse) onType cs
onScalar (Arrow as pn d argT (RetType dims retT)) = do
argT' <- onType argT
old_bound <- get
retT' <- toBeReplaced retT $ onType retT
rl <- state $ partition (`notElem` old_bound)
let dims' = dims <> rl
pure $ Arrow as pn d argT' (RetType dims' retT')
onScalar (TypeVar u v args) =
TypeVar u v <$> mapM onTypeArg args
where
onTypeArg (TypeArgDim d) = TypeArgDim <$> replacing d
onTypeArg (TypeArgType ty) = TypeArgType <$> onType ty
onScalar (Prim pt) = pure $ Prim pt

onType ::
TypeBase Size u ->
ReaderT [(Exp, Exp)] (StateT [VName] TermTypeM) (TypeBase Size u)
onType (Array u shape scalar) =
Array u <$> traverse replacing shape <*> onScalar scalar
onType (Scalar ty) =
Scalar <$> onScalar ty

-- Used to remove unknown sizes from function body types before we
-- perform let-generalisation. This is because if a function is
-- inferred to return something of type '[x+y]t' where 'x' or 'y' are
Expand Down
73 changes: 73 additions & 0 deletions src/Language/Futhark/TypeChecker/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module Language.Futhark.TypeChecker.Unify
Rigidity (..),
RigidSource (..),
BreadCrumbs,
sizeFree,
noBreadCrumbs,
hasNoBreadCrumbs,
dimNotes,
Expand All @@ -30,14 +31,18 @@ where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.Util (topologicalSort)
import Futhark.Util.Pretty
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Types

Expand Down Expand Up @@ -616,6 +621,74 @@ scopeCheck usage bcs vn max_lvl tp = do
<+> dquotes (prettyName v)
<+> "is rigidly bound in a deeper scope."

-- Expressions witnessed by type, topologically sorted.
topWit :: TypeBase Exp u -> [Exp]
topWit = topologicalSort depends . witnessedExps
where
witnessedExps t = execState (traverseDims onDim t) mempty
where
onDim _ PosImmediate e = modify (e :)
onDim _ _ _ = pure ()
depends a b = any (sameExp b) $ subExps a

sizeFree ::
(MonadUnify m) =>
SrcLoc ->
(Exp -> Maybe VName) ->
TypeBase Size u ->
m (TypeBase Size u, [VName])
sizeFree tloc expKiller orig_t = do
runReaderT (toBeReplaced orig_t $ onType orig_t) mempty `runStateT` mempty
where
lookReplacement e repl = snd <$> L.find (sameExp e . fst) repl
expReplace mapping e
| Just e' <- lookReplacement e mapping = e'
| otherwise = runIdentity $ astMap mapper e
where
mapper = identityMapper {mapOnExp = pure . expReplace mapping}

replacing e = do
e' <- asks (`expReplace` e)
case expKiller e' of
Nothing -> pure e'
Just cause -> do
vn <- lift $ lift $ newRigidDim tloc (RigidOutOfScope (locOf e) cause) "d"
modify (vn :)
pure $ sizeFromName (qualName vn) (srclocOf e)

toBeReplaced t m' = foldl f m' $ topWit t
where
f m e = do
e' <- replacing e
local ((e, e') :) m

onScalar (Record fs) =
Record <$> traverse onType fs
onScalar (Sum cs) =
Sum <$> (traverse . traverse) onType cs
onScalar (Arrow as pn d argT (RetType dims retT)) = do
argT' <- onType argT
old_bound <- get
retT' <- toBeReplaced retT $ onType retT
rl <- state $ L.partition (`notElem` old_bound)
let dims' = dims <> rl
pure $ Arrow as pn d argT' (RetType dims' retT')
onScalar (TypeVar u v args) =
TypeVar u v <$> mapM onTypeArg args
where
onTypeArg (TypeArgDim d) = TypeArgDim <$> replacing d
onTypeArg (TypeArgType ty) = TypeArgType <$> onType ty
onScalar (Prim pt) = pure $ Prim pt

onType ::
(MonadUnify m) =>
TypeBase Size u ->
ReaderT [(Exp, Exp)] (StateT [VName] m) (TypeBase Size u)
onType (Array u shape scalar) =
Array u <$> traverse replacing shape <*> onScalar scalar
onType (Scalar ty) =
Scalar <$> onScalar ty

linkVarToType ::
(MonadUnify m) =>
UnifySizes m ->
Expand Down

0 comments on commit f4afac4

Please sign in to comment.