From 37edef97d0d6a58b6238da11bf64737bb39e6d7f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 5 Feb 2024 22:28:17 +0100 Subject: [PATCH 001/258] WIP on rewriting type checker. --- futhark.cabal | 1 + src/Language/Futhark/TypeChecker.hs | 3 + src/Language/Futhark/TypeChecker/Terms2.hs | 1072 ++++++++++++++++++++ 3 files changed, 1076 insertions(+) create mode 100644 src/Language/Futhark/TypeChecker/Terms2.hs diff --git a/futhark.cabal b/futhark.cabal index 62190324f8..d9918a6554 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -414,6 +414,7 @@ library Language.Futhark.TypeChecker.Modules Language.Futhark.TypeChecker.Monad Language.Futhark.TypeChecker.Terms + Language.Futhark.TypeChecker.Terms2 Language.Futhark.TypeChecker.Terms.Loop Language.Futhark.TypeChecker.Terms.Monad Language.Futhark.TypeChecker.Terms.Pat diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index c82c0a70c8..2bded54dd9 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -34,6 +34,7 @@ import Language.Futhark.TypeChecker.Modules import Language.Futhark.TypeChecker.Monad import Language.Futhark.TypeChecker.Names import Language.Futhark.TypeChecker.Terms +import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 import Language.Futhark.TypeChecker.Types import Prelude hiding (abs, mod) @@ -695,6 +696,8 @@ checkValBind vb = do attrs' <- mapM checkAttr attrs + void $ Terms2.checkValDef (fname, maybe_tdecl, tparams, params, body, loc) + (tparams', params', maybe_tdecl', rettype, body') <- checkFunDef (fname, maybe_tdecl, tparams, params, body, loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs new file mode 100644 index 0000000000..6f12b8fcad --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -0,0 +1,1072 @@ +-- | A very WIP reimplementation of type checking of terms. +-- +-- The strategy is to split type checking into two (main) passes: +-- +-- 1) A size-agnostic pass that generates constraints (type Ct) which +-- are then solved offline to find a solution. This produces an AST +-- where most of the type annotations are just references to type +-- variables. Further, all the size-specific annotations (e.g. +-- existential sizes) just contain dummy values, such as empty lists. +-- The constraints use a type representation where all dimensions are +-- the same. However, we do try to take to store the sizes resulting +-- from explicit type ascriptions - these cannot refer to inferred +-- existentials, so it is safe to resolve them here. We don't do +-- anything with this information, however. +-- +-- 2) Pass (1) has given us a program where we know the types of +-- everything, but the sizes of nothing. Pass (2) then does +-- essentially size inference, much like the current/old type checker, +-- but of course with the massive benefit of already knowing the full +-- type of everything. This can be implemented using online constraint +-- solving (as before), or perhaps a completely syntax-driven +-- approach. +-- +-- As of this writing, only the constraint generation part of pass (1) +-- has been implemented, and it is very likely that some of the +-- constraints are actually wrong. Next step is to imlement the +-- solver. Currently all we do is dump the constraints to the +-- terminal. +-- +-- Also, no thought whatsoever has been put into quality of type +-- errors yet. However, I think an approach based on tacking source +-- information onto constraints should work well, as all constraints +-- ultimately originate from some bit of program syntax. +-- +-- Also no thought has been put into how to handle the liftedness +-- stuff. Since it does not really affect choices made during +-- inference, perhaps we can do it in a post-inference check. +module Language.Futhark.TypeChecker.Terms2 + ( checkValDef, + ) +where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor +import Data.Char (isAscii) +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Data.Text qualified as T +import Debug.Trace +import Futhark.FreshNames qualified as FreshNames +import Futhark.MonadFreshNames hiding (newName) +import Futhark.Util (mapAccumLM) +import Futhark.Util.Pretty +import Language.Futhark +import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) +import Language.Futhark.TypeChecker.Monad qualified as TypeM +import Language.Futhark.TypeChecker.Types +import Language.Futhark.TypeChecker.Unify (Level, mkUsage) +import Prelude hiding (mod) + +data Inferred t + = NoneInferred + | Ascribed t + +instance Functor Inferred where + fmap _ NoneInferred = NoneInferred + fmap f (Ascribed t) = Ascribed (f t) + +data ValBinding + = BoundV [TypeParam] StructType + | OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType) + | EqualityF + deriving (Show) + +type Type = TypeBase () NoUniqueness + +toType :: TypeBase d u -> Type +toType = bimap (const ()) (const NoUniqueness) + +expType :: Exp -> Type +expType = toType . typeOf + +data Ct + = CtEq Type Type + | CtOneOf Type [PrimType] + | CtHasConstr Type Name [Type] + | CtHasField Type Name Type + deriving (Show) + +instance Pretty Ct where + pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 + pretty (CtOneOf t1 ts) = pretty t1 <+> "∈" <+> pretty ts + pretty (CtHasConstr t1 k ts) = + pretty t1 <+> "~" <+> "... | " <+> hsep ("#" <> pretty k : map pretty ts) <+> " | ..." + pretty (CtHasField t1 k t) = + pretty t1 <+> "~" <+> braces ("..." <+> pretty k <> ":" <+> pretty t <+> "...") + +type Constraints = [Ct] + +-- | The substitution (or other information) known about a type +-- variable. +data TyVarSub + = -- | No substitution known yet; can be substituted with anything. + TyVarFree + | -- | This substitution has been found. + TyVarSub Type + deriving (Show) + +instance Pretty TyVarSub where + pretty TyVarFree = "free" + pretty (TyVarSub t) = "=" <> pretty t + +type TyVar = VName + +-- | If a VName is not in this map, it is assumed to be rigid. +type TyVars = M.Map TyVar TyVarSub + +data TermScope = TermScope + { scopeVtable :: M.Map VName ValBinding, + scopeTypeTable :: M.Map VName TypeBinding, + scopeModTable :: M.Map VName Mod + } + deriving (Show) + +instance Semigroup TermScope where + TermScope vt1 tt1 mt1 <> TermScope vt2 tt2 mt2 = + TermScope (vt2 `M.union` vt1) (tt2 `M.union` tt1) (mt1 `M.union` mt2) + +-- | Type checking happens with access to this environment. The +-- 'TermScope' will be extended during type-checking as bindings come into +-- scope. +data TermEnv = TermEnv + { termScope :: TermScope, + termLevel :: Level, + termOuterEnv :: Env, + termImportName :: ImportName + } + +-- | The state is a set of constraints and a counter for generating +-- type names. This is distinct from the usual counter we use for +-- generating unique names, as these will be user-visible. +data TermState = TermState + { termConstraints :: Constraints, + termTyVars :: TyVars, + termCounter :: !Int, + termWarnings :: Warnings, + termNameSource :: VNameSource + } + +newtype TermM a + = TermM + ( ReaderT + TermEnv + (StateT TermState (Except (Warnings, TypeError))) + a + ) + deriving + ( Monad, + Functor, + Applicative, + MonadReader TermEnv, + MonadState TermState + ) + +envToTermScope :: Env -> TermScope +envToTermScope env = + TermScope + { scopeVtable = vtable, + scopeTypeTable = envTypeTable env, + scopeModTable = envModTable env + } + where + vtable = M.map valBinding $ envVtable env + valBinding (TypeM.BoundV tps v) = BoundV tps v + +initialTermScope :: TermScope +initialTermScope = + TermScope + { scopeVtable = initialVtable, + scopeTypeTable = mempty, + scopeModTable = mempty + } + where + initialVtable = M.fromList $ mapMaybe addIntrinsicF $ M.toList intrinsics + + prim = Scalar . Prim + arrow x y = Scalar $ Arrow mempty Unnamed Observe x y + + addIntrinsicF (name, IntrinsicMonoFun pts t) = + Just (name, BoundV [] $ arrow pts' $ RetType [] $ prim t) + where + pts' = case pts of + [pt] -> prim pt + _ -> Scalar $ tupleRecord $ map prim pts + addIntrinsicF (name, IntrinsicOverloadedFun ts pts rts) = + Just (name, OverloadedF ts pts rts) + addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = + Just + ( name, + BoundV tvs $ foldFunType pts rt + ) + addIntrinsicF (name, IntrinsicEquality) = + Just (name, EqualityF) + addIntrinsicF _ = Nothing + +runTermM :: TermM a -> TypeM a +runTermM (TermM m) = do + initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv + name <- askImportName + outer_env <- askEnv + src <- gets stateNameSource + let initial_env = + TermEnv + { termScope = initial_scope, + termLevel = 0, + termImportName = name, + termOuterEnv = outer_env + } + initial_state = + TermState + { termConstraints = mempty, + termTyVars = mempty, + termWarnings = mempty, + termNameSource = src, + termCounter = 0 + } + case runExcept (runStateT (runReaderT m initial_env) initial_state) of + Left (ws, e) -> do + warnings ws + throwError e + Right (a, TermState {termNameSource, termWarnings}) -> do + warnings termWarnings + modify $ \s -> s {stateNameSource = termNameSource} + pure a + +incLevel :: TermM a -> TermM a +incLevel = local $ \env -> env {termLevel = termLevel env + 1} + +incCounter :: TermM Int +incCounter = do + s <- get + put s {termCounter = termCounter s + 1} + pure $ termCounter s + +tyVarType :: (Monoid u) => TyVar -> TypeBase dim u +tyVarType v = Scalar $ TypeVar mempty (qualName v) [] + +newTyVar :: a -> Name -> TermM TyVar +newTyVar loc desc = do + i <- incCounter + v <- newID $ mkTypeVarName desc i + modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} + pure v + +newType :: (Monoid u) => a -> Name -> TermM (TypeBase dim u) +newType loc desc = tyVarType <$> newTyVar loc desc + +addCt :: Ct -> TermM () +addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} + +ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () +ctEq t1 t2 = addCt $ CtEq (toType t1) (toType t2) + +ctHasConstr :: TypeBase d1 u1 -> Name -> [TypeBase d2 u2] -> TermM () +ctHasConstr t1 k t2 = addCt $ CtHasConstr (toType t1) k (map toType t2) + +ctHasField :: TypeBase d1 u1 -> Name -> TypeBase d2 u2 -> TermM () +ctHasField t1 k t = addCt $ CtHasField (toType t1) k (toType t) + +ctOneOf :: TypeBase d1 u1 -> [PrimType] -> TermM () +ctOneOf t ts = addCt $ CtOneOf (toType t) ts + +localScope :: (TermScope -> TermScope) -> TermM a -> TermM a +localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} + +withEnv :: TermEnv -> Env -> TermEnv +withEnv tenv env = tenv {termScope = termScope tenv <> envToTermScope env} + +lookupQualNameEnv :: QualName VName -> TermM TermScope +lookupQualNameEnv (QualName [q] _) + | baseTag q <= maxIntrinsicTag = asks termScope -- Magical intrinsic module. +lookupQualNameEnv qn@(QualName quals _) = do + scope <- asks termScope + descend scope quals + where + descend scope [] = pure scope + descend scope (q : qs) + | Just (ModEnv q_scope) <- M.lookup q $ scopeModTable scope = + descend (envToTermScope q_scope) qs + | otherwise = + error $ "lookupQualNameEnv " <> show qn + +instance MonadError TypeError TermM where + throwError e = TermM $ do + ws <- gets termWarnings + throwError (ws, e) + + catchError (TermM m) f = + TermM $ m `catchError` f' + where + f' (_, e) = let TermM m' = f e in m' + +instance MonadTypeChecker TermM where + checkExpForSize = checkExp + + warnings ws = modify $ \s -> s {termWarnings = termWarnings s <> ws} + + warn loc problem = warnings $ singleWarning (locOf loc) problem + + newName v = do + s <- get + let (v', src') = FreshNames.newName (termNameSource s) v + put $ s {termNameSource = src'} + pure v' + + newID s = newName $ VName s 0 + + newTypeName name = do + i <- incCounter + newID $ mkTypeVarName name i + + bindVal v (TypeM.BoundV tps t) = localScope $ \scope -> + scope {scopeVtable = M.insert v (BoundV tps t) $ scopeVtable scope} + + lookupType qn = do + outer_env <- asks termOuterEnv + scope <- lookupQualNameEnv qn + case M.lookup (qualLeaf qn) $ scopeTypeTable scope of + Nothing -> error $ "lookupType: " <> show qn + Just (TypeAbbr l ps (RetType dims def)) -> + pure + ( ps, + RetType dims $ qualifyTypeVars outer_env (map typeParamName ps) (qualQuals qn) def, + l + ) + + typeError loc notes s = + throwError $ TypeError (locOf loc) notes s + +--- All the general machinery goes above. + +require :: T.Text -> [PrimType] -> Exp -> TermM Exp +require why ts e = do + ctOneOf (typeOf e) ts + pure e + +-- | Create a new type name and insert it (unconstrained) in the set +-- of type variables. +instTypeParam :: + (Monoid as) => + QualName VName -> + SrcLoc -> + TypeParam -> + TermM (VName, Subst (RetTypeBase dim as)) +instTypeParam qn loc tparam = do + i <- incCounter + let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) + v <- newID $ mkTypeVarName name i + case tparam of + TypeParamType {} -> do + modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} + pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) + TypeParamDim {} -> + pure (v, ExpSubst $ sizeFromName (qualName v) loc) + +-- | Instantiate a type scheme with fresh type variables for its type +-- parameters. Returns the names of the fresh type variables, the +-- instance list, and the instantiated type. +instTypeScheme :: + QualName VName -> + SrcLoc -> + [TypeParam] -> + StructType -> + TermM ([VName], StructType) +instTypeScheme qn loc tparams t = do + (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do + case tparam of + TypeParamType x _ _ -> do + i <- incCounter + let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) + v <- newID $ mkTypeVarName name i + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) [])) + TypeParamDim {} -> + pure Nothing + let t' = applySubst (`lookup` substs) t + pure (names, t') + +lookupMod :: QualName VName -> TermM Mod +lookupMod qn@(QualName _ name) = do + scope <- lookupQualNameEnv qn + case M.lookup name $ scopeModTable scope of + Nothing -> error $ "lookupMod: " <> show qn + Just m -> pure m + +lookupVar :: SrcLoc -> QualName VName -> TermM StructType +lookupVar loc qn@(QualName qs name) = do + scope <- lookupQualNameEnv qn + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams t) -> do + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newType loc "t" + pure $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ + Scalar $ + Arrow mempty Unnamed Observe argtype $ + RetType [] $ + Scalar $ + Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newType loc "t" + ctOneOf argtype ts + let (pts', rt') = instOverloaded (argtype :: StructType) pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + where + instOverloaded argtype pts rt = + ( map (maybe argtype (Scalar . Prim)) pts, + maybe argtype (Scalar . Prim) rt + ) + +bind :: + [Ident StructType] -> + TermM a -> + TermM a +bind idents = localScope (`bindVars` idents) + where + bindVars = foldl bindVar + + bindVar scope (Ident name (Info tp) _) = + scope + { scopeVtable = M.insert name (BoundV [] tp) $ scopeVtable scope + } + +-- All this complexity is just so we can handle un-suffixed numeric +-- literals in patterns. +patLitMkType :: PatLit -> SrcLoc -> TermM ParamType +patLitMkType (PatLitInt _) loc = do + t <- newType loc "t" + ctOneOf t anyNumberType + pure t +patLitMkType (PatLitFloat _) loc = do + t <- newType loc "t" + ctOneOf t anyFloatType + pure t +patLitMkType (PatLitPrim v) _ = + pure $ Scalar $ Prim $ primValueType v + +checkPat' :: + PatBase NoInfo VName ParamType -> + Inferred ParamType -> + TermM (Pat ParamType) +checkPat' (PatParens p loc) t = + PatParens <$> checkPat' p t <*> pure loc +checkPat' (PatAttr attr p loc) t = + PatAttr <$> checkAttr attr <*> checkPat' p t <*> pure loc +checkPat' (Id name NoInfo loc) (Ascribed t) = + pure $ Id name (Info t) loc +checkPat' (Id name NoInfo loc) NoneInferred = do + t <- newType loc "t" + pure $ Id name (Info t) loc +checkPat' (Wildcard _ loc) (Ascribed t) = + pure $ Wildcard (Info t) loc +checkPat' (Wildcard NoInfo loc) NoneInferred = do + t <- newType loc "t" + pure $ Wildcard (Info t) loc +checkPat' p@(TuplePat ps loc) (Ascribed t) + | Just ts <- isTupleRecord t, + length ts == length ps = + TuplePat + <$> zipWithM checkPat' ps (map Ascribed ts) + <*> pure loc + | otherwise = do + ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") + ctEq (Scalar (tupleRecord ps_t)) t + checkPat' p $ Ascribed $ toParam Observe $ Scalar $ tupleRecord ps_t +checkPat' (TuplePat ps loc) NoneInferred = + TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc +checkPat' p@(RecordPat p_fs loc) (Ascribed t) + | Scalar (Record t_fs) <- t, + L.sort (map fst p_fs) == L.sort (M.keys t_fs) = + RecordPat . M.toList <$> check t_fs <*> pure loc + | otherwise = do + p_fs' <- traverse (const $ newType loc "t") $ M.fromList p_fs + ctEq (Scalar (Record p_fs') :: ParamType) t + checkPat' p $ Ascribed $ toParam Observe $ Scalar (Record p_fs') + where + check t_fs = + traverse (uncurry checkPat') $ + M.intersectionWith (,) (M.fromList p_fs) (fmap Ascribed t_fs) +checkPat' (RecordPat fs loc) NoneInferred = + RecordPat . M.toList + <$> traverse (`checkPat'` NoneInferred) (M.fromList fs) + <*> pure loc +checkPat' (PatAscription p t loc) maybe_outer_t = do + (t', _, RetType dims st, _) <- checkTypeExp t + + case maybe_outer_t of + Ascribed outer_t -> do + ctEq st outer_t + PatAscription + <$> checkPat' p (Ascribed (resToParam st)) + <*> pure t' + <*> pure loc + NoneInferred -> + PatAscription + <$> checkPat' p (Ascribed (resToParam st)) + <*> pure t' + <*> pure loc +checkPat' (PatLit l NoInfo loc) (Ascribed t) = do + t' <- patLitMkType l loc + addCt $ CtEq (toType t') (toType t) + pure $ PatLit l (Info t') loc +checkPat' (PatLit l NoInfo loc) NoneInferred = do + t' <- patLitMkType l loc + pure $ PatLit l (Info t') loc +checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) + | Just ts <- M.lookup n cs = do + when (length ps /= length ts) $ + typeError loc mempty $ + "Pattern #" + <> pretty n + <> " expects" + <+> pretty (length ps) + <+> "constructor arguments, but type provides" + <+> pretty (length ts) + <+> "arguments." + ps' <- zipWithM checkPat' ps $ map Ascribed ts + pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc +checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do + t' <- newType loc "t" + ps' <- forM ps $ \p -> do + p_t <- newType (srclocOf p) "t" + checkPat' p $ Ascribed p_t + ctHasConstr (t' :: ParamType) n $ map patternStructType ps' + pure $ PatConstr n (Info t) ps' loc +checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do + ps' <- mapM (`checkPat'` NoneInferred) ps + t <- newType loc "t" + ctHasConstr t n $ map patternStructType ps' + pure $ PatConstr n (Info t) ps' loc + +checkPat :: + PatBase NoInfo VName (TypeBase Size u) -> + Inferred StructType -> + (Pat ParamType -> TermM a) -> + TermM a +checkPat p t m = + m =<< checkPat' (fmap (toParam Observe) p) (fmap (toParam Observe) t) + +-- | Bind @let@-bound sizes. This is usually followed by 'bindletPat' +-- immediately afterwards. +bindSizes :: [SizeBinder VName] -> TermM a -> TermM a +bindSizes [] m = m -- Minor optimisation. +bindSizes sizes m = bind (map sizeWithType sizes) m + where + sizeWithType size = + Ident (sizeName size) (Info (Scalar (Prim (Signed Int64)))) (srclocOf size) + +bindLetPat :: + PatBase NoInfo VName (TypeBase Size u) -> + StructType -> + (Pat ParamType -> TermM a) -> + TermM a +bindLetPat p t m = do + checkPat p (Ascribed t) $ \p' -> + bind (patIdents (fmap toStruct p')) $ + m p' + +typeParamIdent :: TypeParam -> Maybe (Ident StructType) +typeParamIdent (TypeParamDim v loc) = + Just $ Ident v (Info $ Scalar $ Prim $ Signed Int64) loc +typeParamIdent _ = Nothing + +bindTypes :: + [(VName, TypeBinding)] -> + TermM a -> + TermM a +bindTypes tbinds = localScope extend + where + extend scope = + scope + { scopeTypeTable = M.fromList tbinds <> scopeTypeTable scope + } + +bindTypeParams :: [TypeParam] -> TermM a -> TermM a +bindTypeParams tparams = + bind (mapMaybe typeParamIdent tparams) + . bindTypes (mapMaybe typeParamType tparams) + where + typeParamType (TypeParamType l v _) = + Just (v, TypeAbbr l [] $ RetType [] $ Scalar (TypeVar mempty (qualName v) [])) + typeParamType TypeParamDim {} = + Nothing + +bindParams :: + [TypeParam] -> + [PatBase NoInfo VName ParamType] -> + ([Pat ParamType] -> TermM a) -> + TermM a +bindParams tps orig_ps m = bindTypeParams tps $ do + let descend ps' (p : ps) = + checkPat p NoneInferred $ \p' -> + bind (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps + descend ps' [] = m $ reverse ps' + + incLevel $ descend [] orig_ps + +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM TyVar +checkApply loc (fname, _) ftype arg = do + a <- newType loc "a" + b <- newTyVar loc "b" + ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) + ctEq a (expType arg) + pure b + +checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] +checkSlice = mapM checkDimIndex + where + checkDimIndex (DimFix i) = + DimFix <$> check i + checkDimIndex (DimSlice i j s) = + DimSlice <$> traverse check i <*> traverse check j <*> traverse check s + + check = require "use as index" anySignedType <=< checkExp + +isSlice :: DimIndexBase f vn -> Bool +isSlice DimSlice {} = True +isSlice DimFix {} = False + +-- Add constraints saying that the first type has a (potentially +-- nested) field containing the second type. +mustHaveFields :: + SrcLoc -> + TypeBase d1 u1 -> + [Name] -> + TypeBase d2 u2 -> + TermM () +mustHaveFields loc t [] ve_t = ctEq t ve_t +mustHaveFields loc t (f : fs) ve_t = do + f_t :: Type <- newType loc "ft" + ctHasField t f f_t + mustHaveFields loc f_t fs ve_t + +checkCase :: + StructType -> + CaseBase NoInfo VName -> + TermM (CaseBase Info VName, StructType) +checkCase mt (CasePat p e loc) = + bindLetPat p mt $ \p' -> do + e' <- checkExp e + pure (CasePat (fmap toStruct p') e' loc, typeOf e') + +checkCases :: + StructType -> + NE.NonEmpty (CaseBase NoInfo VName) -> + TermM (NE.NonEmpty (CaseBase Info VName), StructType) +checkCases mt rest_cs = + case NE.uncons rest_cs of + (c, Nothing) -> do + (c', t) <- checkCase mt c + pure (NE.singleton c', t) + (c, Just cs) -> do + (c', c_t) <- checkCase mt c + (cs', cs_t) <- checkCases mt cs + ctEq c_t cs_t + pure (NE.cons c' cs', c_t) + +-- | An unmatched pattern. Used in in the generation of +-- unmatched pattern warnings by the type checker. +data Unmatched p + = UnmatchedNum p [PatLit] + | UnmatchedBool p + | UnmatchedConstr p + | Unmatched p + deriving (Functor, Show) + +instance Pretty (Unmatched (Pat StructType)) where + pretty um = case um of + (UnmatchedNum p nums) -> pretty' p <+> "where p is not one of" <+> pretty nums + (UnmatchedBool p) -> pretty' p + (UnmatchedConstr p) -> pretty' p + (Unmatched p) -> pretty' p + where + pretty' (PatAscription p t _) = pretty p <> ":" <+> pretty t + pretty' (PatParens p _) = parens $ pretty' p + pretty' (PatAttr _ p _) = parens $ pretty' p + pretty' (Id v _ _) = prettyName v + pretty' (TuplePat pats _) = parens $ commasep $ map pretty' pats + pretty' (RecordPat fs _) = braces $ commasep $ map ppField fs + where + ppField (name, t) = pretty (nameToString name) <> equals <> pretty' t + pretty' Wildcard {} = "_" + pretty' (PatLit e _ _) = pretty e + pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) + +checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) +-- +checkExp (Var qn _ loc) = do + t <- lookupVar loc qn + pure $ Var qn (Info t) loc +checkExp (OpSection op _ loc) = do + ftype <- lookupVar loc op + pure $ OpSection op (Info ftype) loc +checkExp (Negate arg loc) = do + arg' <- require "numeric negation" anyNumberType =<< checkExp arg + pure $ Negate arg' loc +checkExp (Not arg loc) = do + arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg + pure $ Not arg' loc +checkExp (Hole NoInfo loc) = + Hole <$> (Info <$> newType loc "hole") <*> pure loc +checkExp (Parens e loc) = + Parens <$> checkExp e <*> pure loc +checkExp (TupLit es loc) = + TupLit <$> mapM checkExp es <*> pure loc +checkExp (QualParens (modname, modnameloc) e loc) = do + mod <- lookupMod modname + case mod of + ModEnv env -> local (`withEnv` env) $ do + e' <- checkExp e + pure $ QualParens (modname, modnameloc) e' loc + ModFun {} -> + typeError loc mempty . withIndexLink "module-is-parametric" $ + "Module" <+> pretty modname <+> " is a parametric module." +-- +checkExp (IntLit x NoInfo loc) = do + t <- newType loc "num" + ctOneOf t anyNumberType + pure $ IntLit x (Info t) loc +checkExp (FloatLit x NoInfo loc) = do + t <- newType loc "float" + ctOneOf t anyFloatType + pure $ FloatLit x (Info t) loc +checkExp (Literal v loc) = + pure $ Literal v loc +checkExp (StringLit vs loc) = + pure $ StringLit vs loc +checkExp (ArrayLit es _ loc) = do + -- TODO: this will produce an enormous number of constraints and + -- type variables for pathologically large arrays with + -- type-unsuffixed integers. Add some special case that handles that + -- more efficiently. + et <- newType loc "et" + es' <- forM es $ \e -> do + e' <- checkExp e + ctEq (typeOf e') et + pure e' + let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et + pure $ ArrayLit es' (Info arr_t) loc +checkExp (RecordLit fs loc) = + RecordLit <$> evalStateT (mapM checkField fs) mempty <*> pure loc + where + checkField (RecordFieldExplicit f e rloc) = do + errIfAlreadySet f rloc + modify $ M.insert f rloc + RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc + checkField (RecordFieldImplicit name NoInfo rloc) = do + errIfAlreadySet (baseName name) rloc + t <- lift $ lookupVar rloc $ qualName name + modify $ M.insert (baseName name) rloc + pure $ RecordFieldImplicit name (Info t) rloc + + errIfAlreadySet f rloc = do + maybe_sloc <- gets $ M.lookup f + case maybe_sloc of + Just sloc -> + lift . typeError rloc mempty $ + "Field" + <+> dquotes (pretty f) + <+> "previously defined at" + <+> pretty (locStrRel rloc sloc) + <> "." + Nothing -> pure () + +-- +checkExp (Attr info e loc) = + Attr <$> checkAttr info <*> checkExp e <*> pure loc +checkExp (Assert e1 e2 NoInfo loc) = do + e1' <- require "being asserted" [Bool] =<< checkExp e1 + e2' <- checkExp e2 + pure $ Assert e1' e2' (Info (prettyText e1)) loc +-- +checkExp (Constr name es NoInfo loc) = do + t <- newType loc "t" + es' <- mapM checkExp es + ctHasConstr t name $ map typeOf es' + pure $ Constr name es' (Info t) loc +-- +checkExp (AppExp (Apply fe args loc) NoInfo) = do + fe' <- checkExp fe + ((_, rt), args') <- mapAccumLM onArg (0, typeOf fe') args + + pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt [] + where + fname = + case fe of + Var v _ _ -> Just v + _ -> Nothing + + onArg (i, f_t) (_, arg) = do + arg' <- checkExp arg + rt <- checkApply loc (fname, i) (toType f_t) arg' + pure + ( (i + 1, tyVarType rt), + (Info Nothing, arg') + ) +-- +checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do + ftype <- lookupVar oploc op + e1' <- checkExp e1 + e2' <- checkExp e2 + + rt1 <- checkApply loc (Just op, 0) (toType ftype) e1' + rt2 <- checkApply loc (Just op, 1) (tyVarType rt1) e2' + + pure $ + AppExp + (BinOp (op, oploc) (Info ftype) (e1', Info Nothing) (e2', Info Nothing) loc) + (Info (AppRes (tyVarType rt2) [])) +-- +checkExp (OpSectionLeft op _ e _ _ loc) = do + optype <- lookupVar loc op + e' <- checkExp e + rt <- checkApply loc (Just op, 0) (toType optype) e' + pure $ + OpSectionLeft + op + (Info optype) + e' + -- Dummy types. + ( Info (Unnamed, Scalar $ Prim Bool, Nothing), + Info (Unnamed, Scalar $ Prim Bool) + ) + (Info (RetType [] (tyVarType rt)), Info []) + loc +-- +checkExp (OpSectionRight op _ e _ NoInfo loc) = do + optype <- lookupVar loc op + e' <- checkExp e + t1 <- newType loc "t" + rt <- newType loc "rt" + ctEq optype $ foldFunType [t1, toParam Observe $ typeOf e'] $ RetType [] rt + pure $ + OpSectionRight + op + (Info optype) + e' + -- Dummy types. + ( Info (Unnamed, Scalar $ Prim Bool), + Info (Unnamed, Scalar $ Prim Bool, Nothing) + ) + (Info $ RetType [] rt) + loc +-- +checkExp (ProjectSection fields NoInfo loc) = do + a <- newType loc "a" + b <- newType loc "b" + mustHaveFields loc a fields b + let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b + pure $ ProjectSection fields (Info ft) loc +-- +checkExp (Lambda params body rettype NoInfo loc) = do + bindParams [] params $ \params' -> do + body' <- checkExp body + rettype_te' <- case rettype of + Just rettype_te -> do + (rettype_te', _, RetType _ st, _) <- checkTypeExp rettype_te + ctEq (typeOf body') st + pure $ Just rettype_te' + Nothing -> pure Nothing + let ret = RetType [] $ toRes Nonunique $ typeOf body' + pure $ Lambda params' body' rettype_te' (Info ret) loc +-- +checkExp (AppExp (LetPat sizes pat e body loc) _) = do + e' <- checkExp e + + bindSizes sizes . incLevel . bindLetPat pat (typeOf e') $ \pat' -> do + body' <- incLevel $ checkExp body + pure $ + AppExp + (LetPat sizes (fmap toStruct pat') e' body' loc) + (Info $ AppRes (typeOf body') []) +-- +checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do + (tparams', params', maybe_retdecl', rettype, e') <- + bindParams tparams params $ \params' -> do + e' <- checkExp e + let ret = RetType [] $ toRes Nonunique $ typeOf e' + pure (tparams, params', undefined, ret, e') + + let entry = BoundV tparams' $ funType params' rettype + bindF scope = + scope + { scopeVtable = M.insert name entry $ scopeVtable scope + } + body' <- localScope bindF $ checkExp body + + pure $ + AppExp + ( LetFun + name + (tparams', params', maybe_retdecl', Info rettype, e') + body' + loc + ) + (Info $ AppRes (typeOf body') []) +-- +checkExp (AppExp (Range start maybe_step end loc) _) = do + start' <- checkExp' start + maybe_step' <- traverse checkExp' maybe_step + end' <- traverse checkExp' end + range_t <- newType loc "range" + ctEq range_t $ arrayOf (Shape [()]) (toType (typeOf start')) + pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] + where + checkExp' = require "use in range expression" anyIntType <=< checkExp +-- +checkExp (Project k e NoInfo loc) = do + e' <- checkExp e + kt <- newType loc "t" + ctHasField (typeOf e') k kt + pure $ Project k e' (Info kt) loc +-- +checkExp (RecordUpdate src fields ve NoInfo loc) = do + src' <- checkExp src + ve' <- checkExp ve + mustHaveFields loc (typeOf src') fields (typeOf ve') + pure $ RecordUpdate src' fields ve' (Info (typeOf src')) loc +-- +checkExp (IndexSection slice NoInfo loc) = do + slice' <- checkSlice slice + index_arg_t <- newType loc "index" + index_elem_t <- newType loc "index_elem" + index_res_t <- newType loc "index_res" + let num_slices = length $ filter isSlice slice + ctEq index_arg_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t + ctEq index_res_t $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ RetType [] index_res_t + pure $ IndexSection slice' (Info ft) loc +-- +checkExp (AppExp (Index e slice loc) _) = do + e' <- checkExp e + slice' <- checkSlice slice + index_t <- newType loc "index" + index_elem_t <- newType loc "index_elem" + let num_slices = length $ filter isSlice slice + ctEq index_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t + ctEq (typeOf e') $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) +-- +checkExp (Update src slice ve loc) = do + src' <- checkExp src + slice' <- checkSlice slice + ve' <- checkExp ve + let num_slices = length $ filter isSlice slice + update_elem_t <- newType loc "update_elem" + ctEq (typeOf src') $ arrayOf (Shape (replicate (length slice) ())) update_elem_t + ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + pure $ Update src' slice' ve' loc +-- +checkExp (AppExp (LetWith dest src slice ve body loc) _) = do + src_t <- lookupVar (srclocOf src) $ qualName $ identName src + let src' = src {identType = Info src_t} + dest' = dest {identType = Info src_t} + slice' <- checkSlice slice + ve' <- checkExp ve + let num_slices = length $ filter isSlice slice + update_elem_t <- newType loc "update_elem" + ctEq src_t $ arrayOf (Shape (replicate (length slice) ())) update_elem_t + ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + bind [dest'] $ do + body' <- checkExp body + pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) +-- +checkExp (AppExp (If e1 e2 e3 loc) _) = do + e1' <- checkExp e1 + e2' <- checkExp e2 + e3' <- checkExp e3 + + ctEq (typeOf e1') (Scalar (Prim Bool) :: Type) + ctEq (typeOf e2') (typeOf e3') + + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e1') []) +-- +checkExp (AppExp (Match e cs loc) _) = do + e' <- checkExp e + (cs', t) <- checkCases (typeOf e') cs + pure $ AppExp (Match e' cs' loc) (Info $ AppRes t []) +-- +checkExp (AppExp (Loop _ pat arg form body loc) _) = do + arg' <- checkExp arg + bindLetPat pat (typeOf arg') $ \pat' -> do + (form', body') <- + case form of + For (Ident i _ iloc) bound -> do + bound' <- require "loop bound" anyIntType =<< checkExp bound + let i' = Ident i (Info (typeOf bound')) iloc + bind [i'] $ do + body' <- checkExp body + ctEq (typeOf arg') (typeOf body') + pure (For i' bound', body') + While cond -> do + cond' <- checkExp cond + body' <- checkExp body + ctEq (typeOf arg') (typeOf body') + pure (While cond', body') + ForIn elemp arr -> do + arr' <- checkExp arr + elem_t <- newType elemp "elem" + ctEq (typeOf arr') $ arrayOf (Shape [()]) (toType elem_t) + bindLetPat elemp elem_t $ \elemp' -> do + body' <- checkExp body + pure (ForIn (toStruct <$> elemp') arr', body') + pure $ + AppExp + (Loop [] pat' arg' form' body' loc) + (Info (AppRes (patternStructType pat') [])) +-- +checkExp (Ascript e te loc) = do + e' <- checkExp e + (te', _, RetType _ st, _) <- checkTypeExp te + ctEq (typeOf e') st + pure $ Ascript e' te' loc +checkExp (Coerce e te NoInfo loc) = do + e' <- checkExp e + (te', _, RetType _ st, _) <- checkTypeExp te + ctEq (typeOf e') st + pure $ Coerce e' te' (Info (toStruct st)) loc + +-- +-- + +checkValDef :: + ( VName, + Maybe (TypeExp NoInfo VName), + [TypeParam], + [PatBase NoInfo VName ParamType], + ExpBase NoInfo VName, + SrcLoc + ) -> + TypeM + ( [TypeParam], + [Pat ParamType], + Maybe (TypeExp Info VName), + ResRetType, + Exp + ) +checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do + bindParams tparams params $ \params' -> do + body' <- checkExp body + cts <- gets termConstraints + tyvars <- gets termTyVars + traceM $ + unlines + [ "function " <> prettyNameString fname, + "constraints:", + prettyString cts, + "tyvars:", + prettyString $ map (first prettyNameString) $ M.toList tyvars + ] + pure (undefined, params', undefined, undefined, body') From ab529bb4031019053101eef8cf76cc628e529437 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 6 Feb 2024 11:26:11 +0100 Subject: [PATCH 002/258] Do not think of overloading as constraints. --- src/Language/Futhark/TypeChecker/Terms2.hs | 97 +++++++++++----------- 1 file changed, 50 insertions(+), 47 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6f12b8fcad..52106ad28d 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -256,24 +256,33 @@ newTyVar loc desc = do modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} pure v -newType :: (Monoid u) => a -> Name -> TermM (TypeBase dim u) +newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) newType loc desc = tyVarType <$> newTyVar loc desc +newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> TypeBase dim u -> TermM (TypeBase dim u) +newTypeWithField loc desc k t = do + rt <- newType loc desc + addCt $ CtHasField (toType rt) k (toType t) + pure rt + +newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [Type] -> TermM (TypeBase dim u) +newTypeWithConstr loc desc k ts = do + t <- newType loc desc + addCt $ CtHasConstr (toType t) k ts + pure t + +newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase dim u) +newTypeOverloaded loc name pts = do + t <- newType loc name + addCt $ CtOneOf (toType t) pts + pure t + addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () ctEq t1 t2 = addCt $ CtEq (toType t1) (toType t2) -ctHasConstr :: TypeBase d1 u1 -> Name -> [TypeBase d2 u2] -> TermM () -ctHasConstr t1 k t2 = addCt $ CtHasConstr (toType t1) k (map toType t2) - -ctHasField :: TypeBase d1 u1 -> Name -> TypeBase d2 u2 -> TermM () -ctHasField t1 k t = addCt $ CtHasField (toType t1) k (toType t) - -ctOneOf :: TypeBase d1 u1 -> [PrimType] -> TermM () -ctOneOf t ts = addCt $ CtOneOf (toType t) ts - localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -344,8 +353,9 @@ instance MonadTypeChecker TermM where --- All the general machinery goes above. require :: T.Text -> [PrimType] -> Exp -> TermM Exp -require why ts e = do - ctOneOf (typeOf e) ts +require why pts e = do + t :: Type <- newTypeOverloaded (srclocOf e) "t" pts + ctEq t $ expType e pure e -- | Create a new type name and insert it (unconstrained) in the set @@ -419,8 +429,7 @@ lookupVar loc qn@(QualName qs name) = do Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do - argtype <- newType loc "t" - ctOneOf argtype ts + argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded (argtype :: StructType) pts rt pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where @@ -445,14 +454,10 @@ bind idents = localScope (`bindVars` idents) -- All this complexity is just so we can handle un-suffixed numeric -- literals in patterns. patLitMkType :: PatLit -> SrcLoc -> TermM ParamType -patLitMkType (PatLitInt _) loc = do - t <- newType loc "t" - ctOneOf t anyNumberType - pure t -patLitMkType (PatLitFloat _) loc = do - t <- newType loc "t" - ctOneOf t anyFloatType - pure t +patLitMkType (PatLitInt _) loc = + newTypeOverloaded loc "t" anyNumberType +patLitMkType (PatLitFloat _) loc = + newTypeOverloaded loc "t" anyFloatType patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v @@ -538,16 +543,15 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) ps' <- zipWithM checkPat' ps $ map Ascribed ts pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do - t' <- newType loc "t" ps' <- forM ps $ \p -> do p_t <- newType (srclocOf p) "t" checkPat' p $ Ascribed p_t - ctHasConstr (t' :: ParamType) n $ map patternStructType ps' - pure $ PatConstr n (Info t) ps' loc + t' <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' + ctEq t' t + pure $ PatConstr n (Info t') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps - t <- newType loc "t" - ctHasConstr t n $ map patternStructType ps' + t <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' pure $ PatConstr n (Info t) ps' loc checkPat :: @@ -640,17 +644,18 @@ isSlice DimFix {} = False -- Add constraints saying that the first type has a (potentially -- nested) field containing the second type. -mustHaveFields :: - SrcLoc -> - TypeBase d1 u1 -> - [Name] -> - TypeBase d2 u2 -> - TermM () -mustHaveFields loc t [] ve_t = ctEq t ve_t +mustHaveFields :: SrcLoc -> Type -> [Name] -> Type -> TermM () +mustHaveFields _ t [] ve_t = + -- This case is probably never reached. + ctEq t ve_t +mustHaveFields loc t [f] ve_t = do + rt :: Type <- newTypeWithField loc "ft" f ve_t + ctEq t rt mustHaveFields loc t (f : fs) ve_t = do - f_t :: Type <- newType loc "ft" - ctHasField t f f_t - mustHaveFields loc f_t fs ve_t + ft :: Type <- newType loc "ft" + rt <- newTypeWithField loc "rt" f ft + mustHaveFields loc ft fs ve_t + ctEq t rt checkCase :: StructType -> @@ -735,12 +740,10 @@ checkExp (QualParens (modname, modnameloc) e loc) = do "Module" <+> pretty modname <+> " is a parametric module." -- checkExp (IntLit x NoInfo loc) = do - t <- newType loc "num" - ctOneOf t anyNumberType + t <- newTypeOverloaded loc "num" anyNumberType pure $ IntLit x (Info t) loc checkExp (FloatLit x NoInfo loc) = do - t <- newType loc "float" - ctOneOf t anyFloatType + t <- newTypeOverloaded loc "float" anyFloatType pure $ FloatLit x (Info t) loc checkExp (Literal v loc) = pure $ Literal v loc @@ -792,9 +795,8 @@ checkExp (Assert e1 e2 NoInfo loc) = do pure $ Assert e1' e2' (Info (prettyText e1)) loc -- checkExp (Constr name es NoInfo loc) = do - t <- newType loc "t" es' <- mapM checkExp es - ctHasConstr t name $ map typeOf es' + t <- newTypeWithConstr loc "t" name $ map expType es' pure $ Constr name es' (Info t) loc -- checkExp (AppExp (Apply fe args loc) NoInfo) = do @@ -866,7 +868,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do checkExp (ProjectSection fields NoInfo loc) = do a <- newType loc "a" b <- newType loc "b" - mustHaveFields loc a fields b + mustHaveFields loc (toType a) fields (toType b) let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b pure $ ProjectSection fields (Info ft) loc -- @@ -928,14 +930,15 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e - kt <- newType loc "t" - ctHasField (typeOf e') k kt + kt <- newType loc "kt" + t <- newTypeWithField loc "t" k kt + ctEq (typeOf e') t pure $ Project k e' (Info kt) loc -- checkExp (RecordUpdate src fields ve NoInfo loc) = do src' <- checkExp src ve' <- checkExp ve - mustHaveFields loc (typeOf src') fields (typeOf ve') + mustHaveFields loc (expType src') fields (expType ve') pure $ RecordUpdate src' fields ve' (Info (typeOf src')) loc -- checkExp (IndexSection slice NoInfo loc) = do From d36f15ede541972667ebcb56a5ad65b159fd4082 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 6 Feb 2024 11:37:38 +0100 Subject: [PATCH 003/258] Move overloading from constraints to tyvars. --- src/Language/Futhark/TypeChecker/Terms2.hs | 70 ++++++++++------------ 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 52106ad28d..036338bc1d 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -59,7 +59,7 @@ import Language.Futhark import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types -import Language.Futhark.TypeChecker.Unify (Level, mkUsage) +import Language.Futhark.TypeChecker.Unify (Level) import Prelude hiding (mod) data Inferred t @@ -84,40 +84,36 @@ toType = bimap (const ()) (const NoUniqueness) expType :: Exp -> Type expType = toType . typeOf -data Ct - = CtEq Type Type - | CtOneOf Type [PrimType] - | CtHasConstr Type Name [Type] - | CtHasField Type Name Type +data Ct = CtEq Type Type deriving (Show) instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtOneOf t1 ts) = pretty t1 <+> "∈" <+> pretty ts - pretty (CtHasConstr t1 k ts) = - pretty t1 <+> "~" <+> "... | " <+> hsep ("#" <> pretty k : map pretty ts) <+> " | ..." - pretty (CtHasField t1 k t) = - pretty t1 <+> "~" <+> braces ("..." <+> pretty k <> ":" <+> pretty t <+> "...") type Constraints = [Ct] --- | The substitution (or other information) known about a type --- variable. -data TyVarSub - = -- | No substitution known yet; can be substituted with anything. +-- | Information about a type variable. +data TyVarInfo + = -- | Can be substituted with anything. TyVarFree - | -- | This substitution has been found. - TyVarSub Type + | -- | Can only be substituted with these primitive types. + TyVarPrim [PrimType] + | -- | Must be a record with these fields. + TyVarRecord (M.Map Name Type) + | -- | Must be a sum type with these fields. + TyVarSum (M.Map Name [Type]) deriving (Show) -instance Pretty TyVarSub where +instance Pretty TyVarInfo where pretty TyVarFree = "free" - pretty (TyVarSub t) = "=" <> pretty t + pretty (TyVarPrim pts) = "∈" <+> pretty pts + pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs + pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs type TyVar = VName -- | If a VName is not in this map, it is assumed to be rigid. -type TyVars = M.Map TyVar TyVarSub +type TyVars = M.Map TyVar TyVarInfo data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding, @@ -249,33 +245,30 @@ incCounter = do tyVarType :: (Monoid u) => TyVar -> TypeBase dim u tyVarType v = Scalar $ TypeVar mempty (qualName v) [] -newTyVar :: a -> Name -> TermM TyVar -newTyVar loc desc = do +newTyVarWith :: a -> Name -> TyVarInfo -> TermM TyVar +newTyVarWith loc desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i - modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} + modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} pure v +newTyVar :: a -> Name -> TermM TyVar +newTyVar loc desc = newTyVarWith loc desc TyVarFree + newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) newType loc desc = tyVarType <$> newTyVar loc desc newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> TypeBase dim u -> TermM (TypeBase dim u) -newTypeWithField loc desc k t = do - rt <- newType loc desc - addCt $ CtHasField (toType rt) k (toType t) - pure rt +newTypeWithField loc desc k t = + tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k $ toType t) newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [Type] -> TermM (TypeBase dim u) -newTypeWithConstr loc desc k ts = do - t <- newType loc desc - addCt $ CtHasConstr (toType t) k ts - pure t +newTypeWithConstr loc desc k ts = + tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts) newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase dim u) -newTypeOverloaded loc name pts = do - t <- newType loc name - addCt $ CtOneOf (toType t) pts - pure t +newTypeOverloaded loc name pts = + tyVarType <$> newTyVarWith loc name (TyVarPrim pts) addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} @@ -1041,9 +1034,6 @@ checkExp (Coerce e te NoInfo loc) = do ctEq (typeOf e') st pure $ Coerce e' te' (Info (toStruct st)) loc --- --- - checkValDef :: ( VName, Maybe (TypeExp NoInfo VName), @@ -1068,8 +1058,8 @@ checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do unlines [ "function " <> prettyNameString fname, "constraints:", - prettyString cts, + unlines $ map prettyString cts, "tyvars:", - prettyString $ map (first prettyNameString) $ M.toList tyvars + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] pure (undefined, params', undefined, undefined, body') From c032bcff60489780acdf258d1c2a3049cc95d2f4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 6 Feb 2024 13:10:47 +0100 Subject: [PATCH 004/258] Make some room for a solver. --- futhark.cabal | 1 + .../Futhark/TypeChecker/Constraints.hs | 78 +++++++++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 83 ++++++++----------- 3 files changed, 114 insertions(+), 48 deletions(-) create mode 100644 src/Language/Futhark/TypeChecker/Constraints.hs diff --git a/futhark.cabal b/futhark.cabal index d9918a6554..c89eacb214 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -409,6 +409,7 @@ library Language.Futhark.Tuple Language.Futhark.TypeChecker Language.Futhark.TypeChecker.Consumption + Language.Futhark.TypeChecker.Constraints Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs new file mode 100644 index 0000000000..e33623704d --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -0,0 +1,78 @@ +module Language.Futhark.TypeChecker.Constraints + ( Type, + Ct (..), + Constraints, + TyVarInfo (..), + TyVar, + TyVars, + solve, + ) +where + +import Control.Monad.Except +import Control.Monad.State +import Data.Bifunctor +import Data.Map qualified as M +import Data.Text qualified as T +import Futhark.Util.Pretty +import Language.Futhark + +type Type = TypeBase () NoUniqueness + +data Ct = CtEq Type Type + deriving (Show) + +instance Pretty Ct where + pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 + +type Constraints = [Ct] + +-- | Information about a type variable. +data TyVarInfo + = -- | Can be substituted with anything. + TyVarFree + | -- | Can only be substituted with these primitive types. + TyVarPrim [PrimType] + | -- | Must be a record with these fields. + TyVarRecord (M.Map Name Type) + | -- | Must be a sum type with these fields. + TyVarSum (M.Map Name [Type]) + deriving (Show) + +instance Pretty TyVarInfo where + pretty TyVarFree = "free" + pretty (TyVarPrim pts) = "∈" <+> pretty pts + pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs + pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs + +type TyVar = VName + +-- | If a VName is not in this map, it is assumed to be rigid. +type TyVars = M.Map TyVar TyVarInfo + +data TyVarSol + = -- | Has been substituted with this. + TyVarSol Type + | -- | Not substituted yet; has this constraint. + TyVarUnsol TyVarInfo + deriving (Show) + +newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} + +initialState :: TyVars -> SolverState +initialState tyvars = SolverState $ M.map TyVarUnsol tyvars + +solution :: SolverState -> M.Map TyVar Type +solution = undefined + +newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} + deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) + +solve :: Constraints -> TyVars -> Either T.Text (M.Map TyVar Type) +solve constraints tyvars = + second solution + . runExcept + . flip execStateT (initialState tyvars) + . runSolveM + $ throwError "cannot solve" +{-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 036338bc1d..eddc209b1a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -56,6 +56,7 @@ import Futhark.MonadFreshNames hiding (newName) import Futhark.Util (mapAccumLM) import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types @@ -76,45 +77,12 @@ data ValBinding | EqualityF deriving (Show) -type Type = TypeBase () NoUniqueness - toType :: TypeBase d u -> Type toType = bimap (const ()) (const NoUniqueness) expType :: Exp -> Type expType = toType . typeOf -data Ct = CtEq Type Type - deriving (Show) - -instance Pretty Ct where - pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - -type Constraints = [Ct] - --- | Information about a type variable. -data TyVarInfo - = -- | Can be substituted with anything. - TyVarFree - | -- | Can only be substituted with these primitive types. - TyVarPrim [PrimType] - | -- | Must be a record with these fields. - TyVarRecord (M.Map Name Type) - | -- | Must be a sum type with these fields. - TyVarSum (M.Map Name [Type]) - deriving (Show) - -instance Pretty TyVarInfo where - pretty TyVarFree = "free" - pretty (TyVarPrim pts) = "∈" <+> pretty pts - pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs - pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs - -type TyVar = VName - --- | If a VName is not in this map, it is assumed to be rigid. -type TyVars = M.Map TyVar TyVarInfo - data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding, scopeTypeTable :: M.Map VName TypeBinding, @@ -270,6 +238,18 @@ newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBa newTypeOverloaded loc name pts = tyVarType <$> newTyVarWith loc name (TyVarPrim pts) +asStructType :: (Monoid u) => SrcLoc -> TypeBase d u -> TermM (TypeBase Size u) +asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt +asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] +asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do + t1' <- asStructType loc t1 + t2' <- asStructType loc t2 + pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' +asStructType loc t = do + t' <- newType loc "artificial" + ctEq t' t + pure t' + addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} @@ -613,13 +593,16 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM TyVar +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM Type +checkApply loc (fname, _) (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do + ctEq a $ expType arg + pure $ toType b checkApply loc (fname, _) ftype arg = do - a <- newType loc "a" - b <- newTyVar loc "b" + a <- newType loc "arg" + b <- newTyVar loc "res" ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) ctEq a (expType arg) - pure b + pure $ tyVarType b checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex @@ -794,9 +777,9 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - ((_, rt), args') <- mapAccumLM onArg (0, typeOf fe') args - - pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt [] + ((_, rt), args') <- mapAccumLM onArg (0, expType fe') args + rt' <- asStructType loc rt + pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] where fname = case fe of @@ -807,7 +790,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do arg' <- checkExp arg rt <- checkApply loc (fname, i) (toType f_t) arg' pure - ( (i + 1, tyVarType rt), + ( (i + 1, rt), (Info Nothing, arg') ) -- @@ -817,17 +800,19 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e2' <- checkExp e2 rt1 <- checkApply loc (Just op, 0) (toType ftype) e1' - rt2 <- checkApply loc (Just op, 1) (tyVarType rt1) e2' + rt2 <- checkApply loc (Just op, 1) rt1 e2' + rt2' <- asStructType loc rt2 pure $ AppExp (BinOp (op, oploc) (Info ftype) (e1', Info Nothing) (e2', Info Nothing) loc) - (Info (AppRes (tyVarType rt2) [])) + (Info (AppRes rt2' [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e rt <- checkApply loc (Just op, 0) (toType optype) e' + rt' <- asStructType loc rt pure $ OpSectionLeft op @@ -837,7 +822,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do ( Info (Unnamed, Scalar $ Prim Bool, Nothing), Info (Unnamed, Scalar $ Prim Bool) ) - (Info (RetType [] (tyVarType rt)), Info []) + (Info (RetType [] $ toRes Nonunique rt'), Info []) loc -- checkExp (OpSectionRight op _ e _ NoInfo loc) = do @@ -1056,10 +1041,12 @@ checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars traceM $ unlines - [ "function " <> prettyNameString fname, - "constraints:", + [ "# function " <> prettyNameString fname, + "## constraints:", unlines $ map prettyString cts, - "tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + "## solution:", + either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) $ solve cts tyvars ] pure (undefined, params', undefined, undefined, body') From 2b08f83959bd47f32f8452104a76cfa6dde137b6 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 7 Feb 2024 13:52:03 +0100 Subject: [PATCH 005/258] Comment for Robert. --- src/Language/Futhark/TypeChecker/Constraints.hs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index e33623704d..dfcb2d3241 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -17,7 +17,15 @@ import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark -type Type = TypeBase () NoUniqueness +-- | A shape component is currently just unit. The rank of an array is +-- then just the number of shape components it contains in its shape +-- list. When we add AUTOMAP, these components will also allow shape +-- variables. The list of components should then be understood as +-- concatenation of shapes (meaning you can't just take the length to +-- determine the rank of the array). +type SComp = () + +type Type = TypeBase SComp NoUniqueness data Ct = CtEq Type Type deriving (Show) From 0399a6d5ee9b7cbfc8e40481bb3473de43dbebe8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 7 Feb 2024 14:46:00 +0100 Subject: [PATCH 006/258] WIP in solving. --- .../Futhark/TypeChecker/Constraints.hs | 69 ++++++++++++++++++- src/Language/Futhark/TypeChecker/Types.hs | 2 +- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index dfcb2d3241..f50e570e8f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -61,6 +61,8 @@ type TyVars = M.Map TyVar TyVarInfo data TyVarSol = -- | Has been substituted with this. TyVarSol Type + | -- | Replaced by this other type variable. + TyVarLink VName | -- | Not substituted yet; has this constraint. TyVarUnsol TyVarInfo deriving (Show) @@ -70,17 +72,80 @@ newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} initialState :: TyVars -> SolverState initialState tyvars = SolverState $ M.map TyVarUnsol tyvars +substTyVars :: (Monoid u) => M.Map TyVar TyVarSol -> TypeBase SComp u -> TypeBase SComp u +substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = + case M.lookup v m of + Just (TyVarLink v') -> + substTyVars m $ Scalar $ TypeVar u (QualName qs v') args + Just (TyVarSol t') -> second (const mempty) t' + Just (TyVarUnsol _) -> t + Nothing -> t +substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt +substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs +substTyVars m (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars m) cs +substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = + Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 +substTyVars m (Array u shape elemt) = + arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt + solution :: SolverState -> M.Map TyVar Type -solution = undefined +solution s = M.mapMaybe f $ solverTyVars s + where + f (TyVarSol t) = Just $ substTyVars (solverTyVars s) t + f (TyVarLink v) = f =<< M.lookup v (solverTyVars s) + f (TyVarUnsol _) = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) +subTyVar :: VName -> Type -> SolveM () +subTyVar v t = + modify $ \s -> s {solverTyVars = M.insert v (TyVarSol t) $ solverTyVars s} + +linkTyVar :: VName -> VName -> SolveM () +linkTyVar v t = + modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} + +unify :: Type -> Type -> Maybe [(Type, Type)] +unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) + | pt1 == pt2 = Just [] +unify _ _ = Nothing + +solveCt :: Ct -> SolveM () +solveCt ct = do + let CtEq t1 t2 = ct + solveCt' (t1, t2) + where + bad = throwError $ "Unsolvable: " <> prettyText ct + solveCt' (t1, t2) = do + tyvars <- gets solverTyVars + let flexible v = case M.lookup v tyvars of + Just (TyVarLink v') -> flexible v' + Just (TyVarUnsol _) -> True + Just (TyVarSol _) -> False + Nothing -> False + case (t1, t2) of + ( Scalar (TypeVar _ (QualName [] v1) []), + Scalar (TypeVar _ (QualName [] v2) []) + ) -> + case (flexible v1, flexible v2) of + (False, False) -> bad + (True, False) -> subTyVar v1 t2 + (False, True) -> subTyVar v2 t1 + (True, True) -> linkTyVar v1 v2 + (Scalar (TypeVar _ (QualName [] v1) []), _) -> + if flexible v1 then subTyVar v1 t2 else bad + (_, Scalar (TypeVar _ (QualName [] v2) [])) -> + if flexible v2 then subTyVar v2 t1 else bad + _ -> case unify t1 t2 of + Nothing -> bad + Just eqs -> mapM_ solveCt' eqs + solve :: Constraints -> TyVars -> Either T.Text (M.Map TyVar Type) solve constraints tyvars = second solution . runExcept . flip execStateT (initialState tyvars) . runSolveM - $ throwError "cannot solve" + $ mapM solveCt constraints {-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 5c8673114b..5a364464b0 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -472,7 +472,7 @@ substTypesRet lookupSubst ot = onType (Array u shape et) = arrayOfWithAliases u (applySubst lookupSubst' shape) - <$> onType (second (const mempty) $ Scalar et) + <$> onType (Scalar et) onType (Scalar (Prim t)) = pure $ Scalar $ Prim t onType (Scalar (TypeVar u v targs)) = do targs' <- mapM subsTypeArg targs From c03d00d73d673a92f9db06c8f219ee8e08323d06 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 7 Feb 2024 16:55:56 +0100 Subject: [PATCH 007/258] Basically functional solver. --- src/Language/Futhark/Pretty.hs | 7 +- .../Futhark/TypeChecker/Constraints.hs | 59 +++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 72 ++++++++++--------- 3 files changed, 89 insertions(+), 49 deletions(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 4c91f277f7..c5134a4acd 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -7,6 +7,7 @@ module Language.Futhark.Pretty prettyTuple, leadingOperator, IsName (..), + prettyNameText, prettyNameString, Annot (..), ) @@ -55,9 +56,13 @@ instance IsName Name where prettyName = pretty toName = id +-- | Prettyprint name as text. +prettyNameText :: (IsName v) => v -> T.Text +prettyNameText = docText . prettyName + -- | Prettyprint name as string. Only use this for debugging. prettyNameString :: (IsName v) => v -> String -prettyNameString = T.unpack . docText . prettyName +prettyNameString = T.unpack . prettyNameText -- | Class for type constructors that represent annotations. Used in -- the prettyprinter to either print the original AST, or the computed diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index f50e570e8f..74a2dd2fd0 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,5 +1,6 @@ module Language.Futhark.TypeChecker.Constraints ( Type, + toType, Ct (..), Constraints, TyVarInfo (..), @@ -25,8 +26,13 @@ import Language.Futhark -- determine the rank of the array). type SComp = () +-- | The type representation used by the constraint solver. Agnostic +-- to sizes. type Type = TypeBase SComp NoUniqueness +toType :: TypeBase d u -> Type +toType = bimap (const ()) (const NoUniqueness) + data Ct = CtEq Type Type deriving (Show) @@ -106,9 +112,26 @@ linkTyVar :: VName -> VName -> SolveM () linkTyVar v t = modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} +-- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Just [] +unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = + Just [(t1a, t2a), (toType t1r, toType t2r)] +unify (Scalar (Record fs1)) (Scalar (Record fs2)) + | M.keys fs1 == M.keys fs2 = + Just $ M.elems $ M.intersectionWith (,) fs1 fs2 +unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) + | M.keys cs1 == M.keys cs2 = do + fmap concat + . forM (M.elems $ M.intersectionWith (,) cs1 cs2) + $ \(ts1, ts2) -> do + guard $ length ts1 == length ts2 + Just $ zip ts1 ts2 +unify t1 t2 + | Just t1' <- peelArray 1 t1, + Just t2' <- peelArray 1 t2 = + Just [(t1', t2')] unify _ _ = Nothing solveCt :: Ct -> SolveM () @@ -124,20 +147,28 @@ solveCt ct = do Just (TyVarUnsol _) -> True Just (TyVarSol _) -> False Nothing -> False - case (t1, t2) of - ( Scalar (TypeVar _ (QualName [] v1) []), - Scalar (TypeVar _ (QualName [] v2) []) - ) -> - case (flexible v1, flexible v2) of - (False, False) -> bad - (True, False) -> subTyVar v1 t2 - (False, True) -> subTyVar v2 t1 - (True, True) -> linkTyVar v1 v2 - (Scalar (TypeVar _ (QualName [] v1) []), _) -> - if flexible v1 then subTyVar v1 t2 else bad - (_, Scalar (TypeVar _ (QualName [] v2) [])) -> - if flexible v2 then subTyVar v2 t1 else bad - _ -> case unify t1 t2 of + sub t@(Scalar (TypeVar u (QualName [] v) [])) = + case M.lookup v tyvars of + Just (TyVarLink v') -> sub $ Scalar (TypeVar u (QualName [] v') []) + Just (TyVarSol t') -> sub t' + _ -> t + sub t = t + case (sub t1, sub t2) of + ( t1'@(Scalar (TypeVar _ (QualName [] v1) [])), + t2'@(Scalar (TypeVar _ (QualName [] v2) [])) + ) + | v1 == v2 -> pure () + | otherwise -> + case (flexible v1, flexible v2) of + (False, False) -> bad + (True, False) -> subTyVar v1 t2' + (False, True) -> subTyVar v2 t1' + (True, True) -> linkTyVar v1 v2 + (Scalar (TypeVar _ (QualName [] v1) []), t2') -> + if flexible v1 then subTyVar v1 t2' else bad + (t1', Scalar (TypeVar _ (QualName [] v2) [])) -> + if flexible v2 then subTyVar v2 t1' else bad + (t1', t2') -> case unify t1' t2' of Nothing -> bad Just eqs -> mapM_ solveCt' eqs diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index eddc209b1a..b1b83f6f2b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -77,9 +77,6 @@ data ValBinding | EqualityF deriving (Show) -toType :: TypeBase d u -> Type -toType = bimap (const ()) (const NoUniqueness) - expType :: Exp -> Type expType = toType . typeOf @@ -254,7 +251,13 @@ addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () -ctEq t1 t2 = addCt $ CtEq (toType t1) (toType t2) +ctEq t1 t2 = + -- As a minor optimisation, do not add constraint if the types are + -- equal. + unless (t1' == t2') $ addCt $ CtEq t1' t2' + where + t1' = toType t1 + t2' = toType t2 localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -287,7 +290,7 @@ instance MonadError TypeError TermM where f' (_, e) = let TermM m' = f e in m' instance MonadTypeChecker TermM where - checkExpForSize = checkExp + checkExpForSize = require "use as size" [Signed Int64] <=< checkExp warnings ws = modify $ \s -> s {termWarnings = termWarnings s <> ws} @@ -325,6 +328,10 @@ instance MonadTypeChecker TermM where --- All the general machinery goes above. +arrayOfRank :: Int -> Type -> Type +arrayOfRank 0 t = t +arrayOfRank n t = arrayOf (Shape $ replicate n ()) t + require :: T.Text -> [PrimType] -> Exp -> TermM Exp require why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts @@ -365,8 +372,8 @@ instTypeScheme qn loc tparams t = do TypeParamType x _ _ -> do i <- incCounter let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newID $ mkTypeVarName name i - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) [])) + v <- newTyVar loc name + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v)) TypeParamDim {} -> pure Nothing let t' = applySubst (`lookup` substs) t @@ -394,13 +401,7 @@ lookupVar loc qn@(QualName qs name) = do pure $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do argtype <- newType loc "t" - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded (argtype :: StructType) pts rt @@ -811,34 +812,37 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - rt <- checkApply loc (Just op, 0) (toType optype) e' - rt' <- asStructType loc rt + void $ checkApply loc (Just op, 0) (toType optype) e' + let t1 = typeOf e' + t2 <- newType loc "t2" + rt <- newType loc "rt" + ctEq optype $ foldFunType [toParam Observe t1, t2] $ RetType [] rt pure $ OpSectionLeft op (Info optype) e' - -- Dummy types. - ( Info (Unnamed, Scalar $ Prim Bool, Nothing), - Info (Unnamed, Scalar $ Prim Bool) + ( Info (Unnamed, toParam Observe t1, Nothing), + Info (Unnamed, t2) ) - (Info (RetType [] $ toRes Nonunique rt'), Info []) + (Info (RetType [] rt), Info []) loc -- checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e t1 <- newType loc "t" + let t2 = typeOf e' rt <- newType loc "rt" - ctEq optype $ foldFunType [t1, toParam Observe $ typeOf e'] $ RetType [] rt + ctEq optype $ foldFunType [t1, toParam Observe t2] $ RetType [] rt pure $ OpSectionRight op (Info optype) e' -- Dummy types. - ( Info (Unnamed, Scalar $ Prim Bool), - Info (Unnamed, Scalar $ Prim Bool, Nothing) + ( Info (Unnamed, toParam Observe t1), + Info (Unnamed, toParam Observe t2, Nothing) ) (Info $ RetType [] rt) loc @@ -901,7 +905,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do maybe_step' <- traverse checkExp' maybe_step end' <- traverse checkExp' end range_t <- newType loc "range" - ctEq range_t $ arrayOf (Shape [()]) (toType (typeOf start')) + ctEq range_t $ arrayOfRank 1 (toType (typeOf start')) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] where checkExp' = require "use in range expression" anyIntType <=< checkExp @@ -925,8 +929,8 @@ checkExp (IndexSection slice NoInfo loc) = do index_elem_t <- newType loc "index_elem" index_res_t <- newType loc "index_res" let num_slices = length $ filter isSlice slice - ctEq index_arg_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t - ctEq index_res_t $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + ctEq index_arg_t $ arrayOfRank num_slices index_elem_t + ctEq index_res_t $ arrayOfRank (length slice) index_elem_t let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ RetType [] index_res_t pure $ IndexSection slice' (Info ft) loc -- @@ -936,8 +940,8 @@ checkExp (AppExp (Index e slice loc) _) = do index_t <- newType loc "index" index_elem_t <- newType loc "index_elem" let num_slices = length $ filter isSlice slice - ctEq index_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t - ctEq (typeOf e') $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + ctEq index_t $ arrayOfRank num_slices index_elem_t + ctEq (typeOf e') $ arrayOfRank (length slice) index_elem_t pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) -- checkExp (Update src slice ve loc) = do @@ -946,8 +950,8 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq (typeOf src') $ arrayOf (Shape (replicate (length slice) ())) update_elem_t - ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + ctEq (typeOf src') $ arrayOfRank (length slice) update_elem_t + ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do @@ -958,8 +962,8 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq src_t $ arrayOf (Shape (replicate (length slice) ())) update_elem_t - ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + ctEq src_t $ arrayOfRank (length slice) update_elem_t + ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) @@ -972,7 +976,7 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do ctEq (typeOf e1') (Scalar (Prim Bool) :: Type) ctEq (typeOf e2') (typeOf e3') - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e1') []) + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) -- checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e @@ -999,7 +1003,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" - ctEq (typeOf arr') $ arrayOf (Shape [()]) (toType elem_t) + ctEq (typeOf arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') From 0cd1562bf8192500669df9f36b70a3f76acbcf4d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 15:49:09 +0100 Subject: [PATCH 008/258] Check return ascriptions. --- src/Language/Futhark/TypeChecker/Terms2.hs | 39 +++++++++++++--------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b1b83f6f2b..49e2b9ae99 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -370,7 +370,6 @@ instTypeScheme qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do case tparam of TypeParamType x _ _ -> do - i <- incCounter let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) v <- newTyVar loc name pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v)) @@ -686,6 +685,13 @@ instance Pretty (Unmatched (Pat StructType)) where pretty' (PatLit e _ _) = pretty e pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) +checkRetDecl :: Exp -> Maybe (TypeExp NoInfo VName) -> TermM (Maybe (TypeExp Info VName)) +checkRetDecl _ Nothing = pure Nothing +checkRetDecl body (Just te) = do + (te', _, RetType _ st, _) <- checkTypeExp te + ctEq (typeOf body) st + pure $ Just te' + checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- checkExp (Var qn _ loc) = do @@ -854,17 +860,12 @@ checkExp (ProjectSection fields NoInfo loc) = do let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b pure $ ProjectSection fields (Info ft) loc -- -checkExp (Lambda params body rettype NoInfo loc) = do +checkExp (Lambda params body retdecl NoInfo loc) = do bindParams [] params $ \params' -> do body' <- checkExp body - rettype_te' <- case rettype of - Just rettype_te -> do - (rettype_te', _, RetType _ st, _) <- checkTypeExp rettype_te - ctEq (typeOf body') st - pure $ Just rettype_te' - Nothing -> pure Nothing + retdecl' <- checkRetDecl body' retdecl let ret = RetType [] $ toRes Nonunique $ typeOf body' - pure $ Lambda params' body' rettype_te' (Info ret) loc + pure $ Lambda params' body' retdecl' (Info ret) loc -- checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e @@ -876,12 +877,13 @@ checkExp (AppExp (LetPat sizes pat e body loc) _) = do (LetPat sizes (fmap toStruct pat') e' body' loc) (Info $ AppRes (typeOf body') []) -- -checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do - (tparams', params', maybe_retdecl', rettype, e') <- +checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) = do + (tparams', params', retdecl', rettype, e') <- bindParams tparams params $ \params' -> do e' <- checkExp e let ret = RetType [] $ toRes Nonunique $ typeOf e' - pure (tparams, params', undefined, ret, e') + retdecl' <- checkRetDecl e' retdecl + pure (tparams, params', retdecl', ret, e') let entry = BoundV tparams' $ funType params' rettype bindF scope = @@ -894,7 +896,7 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l AppExp ( LetFun name - (tparams', params', maybe_retdecl', Info rettype, e') + (tparams', params', retdecl', Info rettype, e') body' loc ) @@ -1035,14 +1037,19 @@ checkValDef :: ( [TypeParam], [Pat ParamType], Maybe (TypeExp Info VName), - ResRetType, Exp ) -checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do +checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bindParams tparams params $ \params' -> do body' <- checkExp body + + retdecl' <- checkRetDecl body' retdecl + cts <- gets termConstraints tyvars <- gets termTyVars + + let solution = solve cts tyvars + traceM $ unlines [ "# function " <> prettyNameString fname, @@ -1053,4 +1060,4 @@ checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do "## solution:", either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) $ solve cts tyvars ] - pure (undefined, params', undefined, undefined, body') + pure (undefined, params', retdecl', body') From 7de2678e7c7bf8d3fdd1ee8ec76724db757ac828 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 15:53:28 +0100 Subject: [PATCH 009/258] Clean up things. --- src/Language/Futhark/TypeChecker/Terms2.hs | 25 +++------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 49e2b9ae99..ee1b04da01 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -217,7 +217,7 @@ newTyVarWith loc desc info = do modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} pure v -newTyVar :: a -> Name -> TermM TyVar +newTyVar :: (Located loc) => loc -> Name -> TermM TyVar newTyVar loc desc = newTyVarWith loc desc TyVarFree newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) @@ -338,25 +338,6 @@ require why pts e = do ctEq t $ expType e pure e --- | Create a new type name and insert it (unconstrained) in the set --- of type variables. -instTypeParam :: - (Monoid as) => - QualName VName -> - SrcLoc -> - TypeParam -> - TermM (VName, Subst (RetTypeBase dim as)) -instTypeParam qn loc tparam = do - i <- incCounter - let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newID $ mkTypeVarName name i - case tparam of - TypeParamType {} -> do - modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} - pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) - TypeParamDim {} -> - pure (v, ExpSubst $ sizeFromName (qualName v) loc) - -- | Instantiate a type scheme with fresh type variables for its type -- parameters. Returns the names of the fresh type variables, the -- instance list, and the instantiated type. @@ -481,7 +462,7 @@ checkPat' (RecordPat fs loc) NoneInferred = <$> traverse (`checkPat'` NoneInferred) (M.fromList fs) <*> pure loc checkPat' (PatAscription p t loc) maybe_outer_t = do - (t', _, RetType dims st, _) <- checkTypeExp t + (t', _, RetType _ st, _) <- checkTypeExp t case maybe_outer_t of Ascribed outer_t -> do @@ -1058,6 +1039,6 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, "## solution:", - either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) $ solve cts tyvars + either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) solution ] pure (undefined, params', retdecl', body') From 51eafc66331a76de9eac11831ccc367a9695524b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 15:56:07 +0100 Subject: [PATCH 010/258] Remove most warnings. --- src/Language/Futhark/TypeChecker/Terms2.hs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index ee1b04da01..ee62ed478f 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -210,8 +210,8 @@ incCounter = do tyVarType :: (Monoid u) => TyVar -> TypeBase dim u tyVarType v = Scalar $ TypeVar mempty (qualName v) [] -newTyVarWith :: a -> Name -> TyVarInfo -> TermM TyVar -newTyVarWith loc desc info = do +newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar +newTyVarWith _loc desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} @@ -333,7 +333,7 @@ arrayOfRank 0 t = t arrayOfRank n t = arrayOf (Shape $ replicate n ()) t require :: T.Text -> [PrimType] -> Exp -> TermM Exp -require why pts e = do +require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts ctEq t $ expType e pure e @@ -347,13 +347,12 @@ instTypeScheme :: [TypeParam] -> StructType -> TermM ([VName], StructType) -instTypeScheme qn loc tparams t = do +instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do case tparam of - TypeParamType x _ _ -> do - let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newTyVar loc name - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v)) + TypeParamType _ v _ -> do + v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) TypeParamDim {} -> pure Nothing let t' = applySubst (`lookup` substs) t @@ -575,10 +574,10 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM Type -checkApply loc (fname, _) (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do +checkApply _ _ (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do ctEq a $ expType arg pure $ toType b -checkApply loc (fname, _) ftype arg = do +checkApply loc _ ftype arg = do a <- newType loc "arg" b <- newTyVar loc "res" ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) @@ -1020,7 +1019,7 @@ checkValDef :: Maybe (TypeExp Info VName), Exp ) -checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do +checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do bindParams tparams params $ \params' -> do body' <- checkExp body From 7b78c864a0cae2a4798f37d9fdbf6181584e7e0b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 18:27:26 +0100 Subject: [PATCH 011/258] Defective type checker work. --- src/Language/Futhark/TypeChecker.hs | 3 - src/Language/Futhark/TypeChecker/Terms.hs | 266 ++++++++---------- .../Futhark/TypeChecker/Terms/Loop.hs | 26 +- .../Futhark/TypeChecker/Terms/Monad.hs | 24 +- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 85 +----- src/Language/Futhark/TypeChecker/Terms2.hs | 13 +- 6 files changed, 167 insertions(+), 250 deletions(-) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 23e02b79d7..9bc29be7d4 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -34,7 +34,6 @@ import Language.Futhark.TypeChecker.Modules import Language.Futhark.TypeChecker.Monad import Language.Futhark.TypeChecker.Names import Language.Futhark.TypeChecker.Terms -import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 import Language.Futhark.TypeChecker.Types import Prelude hiding (abs, mod) @@ -696,8 +695,6 @@ checkValBind vb = do attrs' <- mapM checkAttr attrs - void $ Terms2.checkValDef (fname, maybe_tdecl, tparams, params, body, loc) - (tparams', params', maybe_tdecl', rettype, body') <- checkFunDef (fname, maybe_tdecl, tparams, params, body, loc) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d933287ab3..3e6ecbff19 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -38,6 +38,7 @@ import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Terms.Loop import Language.Futhark.TypeChecker.Terms.Monad import Language.Futhark.TypeChecker.Terms.Pat +import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) @@ -182,8 +183,8 @@ sliceShape _ _ t = pure (t, []) checkAscript :: SrcLoc -> - TypeExp (ExpBase NoInfo VName) VName -> - ExpBase NoInfo VName -> + TypeExp Exp VName -> + Exp -> TermTypeM (TypeExp Exp VName, Exp) checkAscript loc te e = do (te', decl_t, _) <- checkTypeExpNonrigid te @@ -197,8 +198,8 @@ checkAscript loc te e = do checkCoerce :: SrcLoc -> - TypeExp (ExpBase NoInfo VName) VName -> - ExpBase NoInfo VName -> + TypeExp Exp VName -> + Exp -> TermTypeM (TypeExp Exp VName, StructType, Exp) checkCoerce loc te e = do (te', te_t, ext) <- checkTypeExpNonrigid te @@ -347,48 +348,26 @@ unscopeType :: unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp -checkExp :: ExpBase NoInfo VName -> TermTypeM Exp +checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc -checkExp (Hole _ loc) = do - t <- newTypeVar loc "t" - pure $ Hole (Info t) loc +checkExp (Hole info loc) = + pure $ Hole info loc checkExp (StringLit vs loc) = pure $ StringLit vs loc -checkExp (IntLit val NoInfo loc) = do - t <- newTypeVar loc "t" - mustBeOneOf anyNumberType (mkUsage loc "integer literal") t - pure $ IntLit val (Info t) loc -checkExp (FloatLit val NoInfo loc) = do - t <- newTypeVar loc "t" - mustBeOneOf anyFloatType (mkUsage loc "float literal") t - pure $ FloatLit val (Info t) loc +checkExp (IntLit val info loc) = + pure $ IntLit val info loc +checkExp (FloatLit val info loc) = + pure $ FloatLit val info loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc checkExp (RecordLit fs loc) = - RecordLit <$> evalStateT (mapM checkField fs) mempty <*> pure loc + RecordLit <$> mapM checkField fs <*> pure loc where - checkField (RecordFieldExplicit f e rloc) = do - errIfAlreadySet f rloc - modify $ M.insert f rloc - RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc - checkField (RecordFieldImplicit name NoInfo rloc) = do - errIfAlreadySet (baseName name) rloc - t <- lift $ lookupVar rloc $ qualName name - modify $ M.insert (baseName name) rloc - pure $ RecordFieldImplicit name (Info t) rloc - - errIfAlreadySet f rloc = do - maybe_sloc <- gets $ M.lookup f - case maybe_sloc of - Just sloc -> - lift . typeError rloc mempty $ - "Field" - <+> dquotes (pretty f) - <+> "previously defined at" - <+> pretty (locStrRel rloc sloc) - <> "." - Nothing -> pure () + checkField (RecordFieldExplicit f e rloc) = + RecordFieldExplicit f <$> checkExp e <*> pure rloc + checkField (RecordFieldImplicit name info rloc) = + pure $ RecordFieldImplicit name info rloc checkExp (ArrayLit all_es _ loc) = -- Construct the result type and unify all elements with it. We -- only create a type variable for empty arrays; otherwise we use @@ -484,12 +463,12 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do checkExp (Ascript e te loc) = do (te', e') <- checkAscript loc te e pure $ Ascript e' te' loc -checkExp (Coerce e te NoInfo loc) = do +checkExp (Coerce e te _ loc) = do (te', te_t, e') <- checkCoerce loc te e t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do +checkExp (AppExp (BinOp (op, oploc) _ (e1, _) (e2, _) loc) _) = do ftype <- lookupVar oploc op e1' <- checkExp e1 e2' <- checkExp e2 @@ -509,7 +488,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do loc ) (Info (AppRes rt' retext)) -checkExp (Project k e NoInfo loc) = do +checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' kt <- mustHaveField (mkUsage loc $ docText $ "projection of field " <> dquotes (pretty k)) k t @@ -543,7 +522,7 @@ checkExp (QualParens (modname, modnameloc) e loc) = do ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." -checkExp (Var qn NoInfo loc) = do +checkExp (Var qn _ loc) = do t <- lookupVar loc qn pure $ Var qn (Info t) loc checkExp (Negate arg loc) = do @@ -552,7 +531,7 @@ checkExp (Negate arg loc) = do checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc -checkExp (AppExp (Apply fe args loc) NoInfo) = do +checkExp (AppExp (Apply fe args loc) _) = do fe' <- checkExp fe args' <- mapM (checkExp . snd) args t <- expType fe' @@ -598,7 +577,7 @@ checkExp (AppExp (LetPat sizes pat e body loc) _) = do AppExp (LetPat sizes (fmap toStruct pat') e' body' loc) (Info $ AppRes body_t' retext) -checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do +checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, _, e) body loc) _) = do (tparams', params', maybe_retdecl', rettype, e') <- checkBinding (name, maybe_retdecl, tparams, params, e, loc) @@ -621,19 +600,18 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do - src' <- checkIdent src slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage src "type of source array") "src" $ sliceDims slice' - unify (mkUsage loc "type of target array") t $ unInfo $ identType src' + unify (mkUsage loc "type of target array") t $ unInfo $ identType src (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t ve' <- unifies "type of target array" elemt =<< checkExp ve - bindingIdent dest (unInfo (identType src')) $ \dest' -> do + bindingIdent dest $ do body' <- checkExp body - (body_t, ext) <- unscopeType loc [identName dest'] =<< expTypeFully body' - pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t ext) + (body_t, ext) <- unscopeType loc [identName dest] =<< expTypeFully body' + pure $ AppExp (LetWith dest src slice' ve' body' loc) (Info $ AppRes body_t ext) checkExp (Update src slice ve loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' src) "src" $ sliceDims slice' @@ -645,11 +623,9 @@ checkExp (Update src slice ve loc) = do -- Record updates are a bit hacky, because we do not have row typing -- (yet?). For now, we only permit record updates where we know the -- full type up to the field we are updating. -checkExp (RecordUpdate src fields ve NoInfo loc) = do +checkExp (RecordUpdate src fields ve _ loc) = do src' <- checkExp src ve' <- checkExp ve - a <- expTypeFully src' - foldM_ (flip $ mustHaveField usage) a fields ve_t <- expType ve' updated_t <- updateField fields ve_t =<< expTypeFully src' pure $ RecordUpdate src' fields ve' (Info updated_t) loc @@ -681,11 +657,11 @@ checkExp (AppExp (Index e slice loc) _) = do =<< expTypeFully e' pure $ AppExp (Index e' slice' loc) (Info $ AppRes t' retext) -checkExp (Assert e1 e2 NoInfo loc) = do +checkExp (Assert e1 e2 _ loc) = do e1' <- require "being asserted" [Bool] =<< checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc -checkExp (Lambda params body rettype_te NoInfo loc) = do +checkExp (Lambda params body rettype_te _ loc) = do (params', body', rettype', RetType dims ty) <- incLevel . bindingParams [] params $ \params' -> do rettype_checked <- traverse checkTypeExpNonrigid rettype_te @@ -757,7 +733,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (OpSectionRight op _ e _ NoInfo loc) = do +checkExp (OpSectionRight op _ e _ _ loc) = do ftype <- lookupVar loc op e' <- checkExp e case ftype of @@ -782,13 +758,13 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (ProjectSection fields NoInfo loc) = do +checkExp (ProjectSection fields _ loc) = do a <- newTypeVar loc "a" let usage = mkUsage loc "projection at" b <- foldM (flip $ mustHaveField usage) a fields let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ toRes Nonunique b pure $ ProjectSection fields (Info ft) loc -checkExp (IndexSection slice NoInfo loc) = do +checkExp (IndexSection slice _ loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' (t', retext) <- sliceShape Nothing slice' t @@ -801,11 +777,9 @@ checkExp (AppExp (Loop _ mergepat mergeexp form loopbody loc) _) = do AppExp (Loop sparams mergepat' mergeexp' form' loopbody' loc) (Info appres) -checkExp (Constr name es NoInfo loc) = do +checkExp (Constr name es _ loc) = do t <- newTypeVar loc "t" es' <- mapM checkExp es - ets <- mapM expType es' - mustHaveConstr (mkUsage loc "use of constructor") name t ets pure $ Constr name es' (Info t) loc checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e @@ -819,9 +793,20 @@ checkExp (AppExp (Match e cs loc) _) = do checkExp (Attr info e loc) = Attr <$> checkAttr info <*> checkExp e <*> pure loc +checkCase :: + StructType -> + CaseBase Info VName -> + TermTypeM (CaseBase Info VName, StructType, [VName]) +checkCase mt (CasePat p e loc) = + bindingPat [] p mt $ \p' -> do + e' <- checkExp e + e_t <- expTypeFully e' + (e_t', retext) <- unscopeType loc (patNames p') e_t + pure (CasePat (fmap toStruct p') e' loc, e_t', retext) + checkCases :: StructType -> - NE.NonEmpty (CaseBase NoInfo VName) -> + NE.NonEmpty (CaseBase Info VName) -> TermTypeM (NE.NonEmpty (CaseBase Info VName), StructType, [VName]) checkCases mt rest_cs = case NE.uncons rest_cs of @@ -834,17 +819,6 @@ checkCases mt rest_cs = (brancht, retext) <- unifyBranchTypes (srclocOf c) c_t cs_t pure (NE.cons c' cs', brancht, retext) -checkCase :: - StructType -> - CaseBase NoInfo VName -> - TermTypeM (CaseBase Info VName, StructType, [VName]) -checkCase mt (CasePat p e loc) = - bindingPat [] p mt $ \p' -> do - e' <- checkExp e - e_t <- expTypeFully e' - (e_t', retext) <- unscopeType loc (patNames p') e_t - pure (CasePat (fmap toStruct p') e' loc, e_t', retext) - -- | An unmatched pattern. Used in in the generation of -- unmatched pattern warnings by the type checker. data Unmatched p @@ -873,12 +847,7 @@ instance Pretty (Unmatched (Pat StructType)) where pretty' (PatLit e _ _) = pretty e pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) -checkIdent :: IdentBase NoInfo VName StructType -> TermTypeM (Ident StructType) -checkIdent (Ident name _ loc) = do - vt <- lookupVar loc $ qualName name - pure $ Ident name (Info vt) loc - -checkSlice :: SliceBase NoInfo VName -> TermTypeM [DimIndex] +checkSlice :: SliceBase Info VName -> TermTypeM [DimIndex] checkSlice = mapM checkDimIndex where checkDimIndex (DimFix i) = do @@ -1039,8 +1008,8 @@ checkApply loc (fname, prev_applied) ftype argexp = do -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM checkExp $ do - e' <- checkExp e +checkOneExp e = runTermTypeM (checkExp . undefined) $ do + e' <- checkExp $ undefined e let t = typeOf e' (tparams, _, _) <- letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t @@ -1053,8 +1022,8 @@ checkOneExp e = runTermTypeM checkExp $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM checkExp $ do - e' <- checkExp e +checkSizeExp e = runTermTypeM (checkExp . undefined) $ do + e' <- checkExp $ undefined e let t = typeOf e' when (hasBinding e') $ typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $ @@ -1286,61 +1255,6 @@ localChecks = void . check <> pretty ty <> "." --- | Type-check a top-level (or module-level) function definition. --- Despite the name, this is also used for checking constant --- definitions, by treating them as 0-ary functions. -checkFunDef :: - ( VName, - Maybe (TypeExp (ExpBase NoInfo VName) VName), - [TypeParam], - [PatBase NoInfo VName ParamType], - ExpBase NoInfo VName, - SrcLoc - ) -> - TypeM - ( [TypeParam], - [Pat ParamType], - Maybe (TypeExp Exp VName), - ResRetType, - Exp - ) -checkFunDef (fname, maybe_retdecl, tparams, params, body, loc) = - runTermTypeM checkExp $ do - (tparams', params', maybe_retdecl', RetType dims rettype', body') <- - checkBinding (fname, maybe_retdecl, tparams, params, body, loc) - - -- Since this is a top-level function, we also resolve overloaded - -- types, using either defaults or complaining about ambiguities. - fixOverloadedTypes $ - typeVars rettype' <> foldMap (typeVars . patternType) params' - - -- Then replace all inferred types in the body and parameters. - body'' <- updateTypes body' - params'' <- updateTypes params' - maybe_retdecl'' <- traverse updateTypes maybe_retdecl' - rettype'' <- normTypeFully rettype' - - -- Check if the function body can actually be evaluated. - causalityCheck body'' - - -- Check for various problems. - mapM_ (mustBeIrrefutable . fmap toStruct) params' - localChecks body'' - - let ((body''', updated_ret), errors) = - Consumption.checkValDef - ( fname, - params'', - body'', - RetType dims rettype'', - maybe_retdecl'', - loc - ) - - mapM_ throwError errors - - pure (tparams', params'', maybe_retdecl'', updated_ret, body''') - -- | This is "fixing" as in "setting them", not "correcting them". We -- only make very conservative fixing. fixOverloadedTypes :: Names -> TermTypeM () @@ -1418,10 +1332,10 @@ inferredReturnType loc params t = do checkBinding :: ( VName, - Maybe (TypeExp (ExpBase NoInfo VName) VName), + Maybe (TypeExp Exp VName), [TypeParam], - [PatBase NoInfo VName ParamType], - ExpBase NoInfo VName, + [PatBase Info VName ParamType], + ExpBase Info VName, SrcLoc ) -> TermTypeM @@ -1670,7 +1584,7 @@ letGeneralise defname defloc tparams params restype = checkFunBody :: [Pat ParamType] -> - ExpBase NoInfo VName -> + Exp -> Maybe ResType -> SrcLoc -> TermTypeM Exp @@ -1705,3 +1619,73 @@ arrayOfM :: arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t + +addInitialConstraints :: M.Map VName (TypeBase () NoUniqueness) -> TermTypeM () +addInitialConstraints = mapM_ f . M.toList + where + addConstraint v c = modifyConstraints $ M.insert v (0, c) + usage = mkUsage (mempty :: Loc) "trust me bro" + f (v, t) = do + (t', _) <- allDimsFreshInType usage Nonrigid "dv" t + addConstraint v $ Constraint (RetType [] t') usage + +-- | Type-check a top-level (or module-level) function definition. +-- Despite the name, this is also used for checking constant +-- definitions, by treating them as 0-ary functions. +checkFunDef :: + ( VName, + Maybe (TypeExp (ExpBase NoInfo VName) VName), + [TypeParam], + [PatBase NoInfo VName ParamType], + ExpBase NoInfo VName, + SrcLoc + ) -> + TypeM + ( [TypeParam], + [Pat ParamType], + Maybe (TypeExp Exp VName), + ResRetType, + Exp + ) +checkFunDef (fname, retdecl, tparams, params, body, loc) = do + (maybe_tysubsts, params', retdecl', body') <- + Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + case maybe_tysubsts of + Left err -> typeError loc mempty $ pretty err + Right tysubsts -> runTermTypeM (checkExp . undefined) $ do + addInitialConstraints tysubsts + + (tparams', params'', retdecl'', RetType dims rettype', body'') <- + checkBinding (fname, retdecl', tparams, params', body', loc) + + -- Since this is a top-level function, we also resolve overloaded + -- types, using either defaults or complaining about ambiguities. + fixOverloadedTypes $ + typeVars rettype' <> foldMap (typeVars . patternType) params'' + + -- Then replace all inferred types in the body and parameters. + body''' <- updateTypes body'' + params''' <- updateTypes params'' + retdecl''' <- traverse updateTypes retdecl'' + rettype'' <- normTypeFully rettype' + + -- Check if the function body can actually be evaluated. + causalityCheck body''' + + -- Check for various problems. + mapM_ (mustBeIrrefutable . fmap toStruct) params'' + localChecks body''' + + let ((body'''', updated_ret), errors) = + Consumption.checkValDef + ( fname, + params''', + body''', + RetType dims rettype'', + retdecl''', + loc + ) + + mapM_ throwError errors + + pure (tparams', params''', retdecl''', updated_ret, body'''') diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index c5e3619ac7..349b105823 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -102,7 +102,7 @@ wellTypedLoopArg src sparams pat arg = do -- | An un-checked loop. type UncheckedLoop = - (PatBase NoInfo VName ParamType, ExpBase NoInfo VName, LoopFormBase NoInfo VName, ExpBase NoInfo VName) + (Pat ParamType, Exp, LoopFormBase Info VName, Exp) -- | A loop that has been type-checked. type CheckedLoop = @@ -111,7 +111,7 @@ type CheckedLoop = -- | Type-check a @loop@ expression, passing in a function for -- type-checking subexpressions. checkLoop :: - (ExpBase NoInfo VName -> TermTypeM Exp) -> + (Exp -> TermTypeM Exp) -> UncheckedLoop -> SrcLoc -> TermTypeM (CheckedLoop, AppRes) @@ -223,18 +223,16 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do uboundexp' <- require "being the bound in a 'for' loop" anySignedType =<< checkExp uboundexp - bound_t <- expTypeFully uboundexp' - bindingIdent i bound_t $ \i' -> - bindingPat [] mergepat merge_t $ - \mergepat' -> incLevel $ do - loopbody' <- checkExp loopbody - (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' - pure - ( sparams, - mergepat'', - For i' uboundexp', - loopbody' - ) + bindingIdent i . bindingPat [] mergepat merge_t $ + \mergepat' -> incLevel $ do + loopbody' <- checkExp loopbody + (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' + pure + ( sparams, + mergepat'', + For i uboundexp', + loopbody' + ) ForIn xpat e -> do (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1 e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index efa052fc7d..de3b03472d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -94,7 +94,7 @@ data Checking | CheckingAscription StructType StructType | CheckingLetGeneralise Name | CheckingParams (Maybe Name) - | CheckingPat (PatBase NoInfo VName StructType) (Inferred StructType) + | CheckingPat (PatBase Info VName StructType) (Inferred StructType) | CheckingLoopBody StructType StructType | CheckingLoopInitial StructType StructType | CheckingRecordUpdate [Name] StructType StructType @@ -544,8 +544,8 @@ allDimsFreshInType :: Usage -> Rigidity -> Name -> - TypeBase Size als -> - TermTypeM (TypeBase Size als, M.Map VName Size) + TypeBase d als -> + TermTypeM (TypeBase Size als, M.Map VName d) allDimsFreshInType usage r desc t = runStateT (bitraverse onDim pure t) mempty where @@ -581,25 +581,15 @@ require why ts e = do mustBeOneOf ts (mkUsage (srclocOf e) why) . toStruct =<< expType e pure e -termCheckTypeExp :: - TypeExp (ExpBase NoInfo VName) VName -> - TermTypeM (TypeExp Exp VName, [VName], ResRetType) -termCheckTypeExp te = do - (te', svars, rettype, _l) <- checkTypeExp te +checkTypeExpNonrigid :: TypeExp Exp VName -> TermTypeM (TypeExp Exp VName, ResType, [VName]) +checkTypeExpNonrigid te = do + (te', svars, rettype, _l) <- checkTypeExp $ undefined te -- No guarantee that the locally bound sizes in rettype are globally -- unique, but we want to turn them into size variables, so let's - -- give them some unique names. Maybe this should be done below, - -- where we actually turn these into size variables? + -- give them some unique names. RetType dims st <- renameRetType rettype - pure (te', svars, RetType dims st) - -checkTypeExpNonrigid :: - TypeExp (ExpBase NoInfo VName) VName -> - TermTypeM (TypeExp Exp VName, ResType, [VName]) -checkTypeExpNonrigid te = do - (te', svars, RetType dims st) <- termCheckTypeExp te forM_ (svars ++ dims) $ \v -> constrain v $ Size Nothing $ mkUsage (srclocOf te) "anonymous size in type expression" pure (te', st, svars ++ dims) diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 980f278326..ad4ea0aa0f 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -110,47 +110,21 @@ bindingSizes sizes m = binding (map sizeWithType sizes) m Ident (sizeName size) (Info (Scalar (Prim (Signed Int64)))) (srclocOf size) -- | Bind a single term-level identifier. -bindingIdent :: - IdentBase NoInfo VName StructType -> - StructType -> - (Ident StructType -> TermTypeM a) -> - TermTypeM a -bindingIdent (Ident v NoInfo vloc) t m = do - let ident = Ident v (Info t) vloc - binding [ident] $ m ident - --- All this complexity is just so we can handle un-suffixed numeric --- literals in patterns. -patLitMkType :: PatLit -> SrcLoc -> TermTypeM ParamType -patLitMkType (PatLitInt _) loc = do - t <- newTypeVar loc "t" - mustBeOneOf anyNumberType (mkUsage loc "integer literal") (toStruct t) - pure t -patLitMkType (PatLitFloat _) loc = do - t <- newTypeVar loc "t" - mustBeOneOf anyFloatType (mkUsage loc "float literal") (toStruct t) - pure t -patLitMkType (PatLitPrim v) _ = - pure $ Scalar $ Prim $ primValueType v +bindingIdent :: Ident StructType -> TermTypeM a -> TermTypeM a +bindingIdent ident = binding [ident] checkPat' :: [(SizeBinder VName, QualName VName)] -> - PatBase NoInfo VName ParamType -> + Pat ParamType -> Inferred ParamType -> TermTypeM (Pat ParamType) checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc -checkPat' _ (Id name NoInfo loc) (Ascribed t) = - pure $ Id name (Info t) loc -checkPat' _ (Id name NoInfo loc) NoneInferred = do - t <- newTypeVar loc "t" +checkPat' _ (Id name (Info t) loc) _ = pure $ Id name (Info t) loc -checkPat' _ (Wildcard _ loc) (Ascribed t) = - pure $ Wildcard (Info t) loc -checkPat' _ (Wildcard NoInfo loc) NoneInferred = do - t <- newTypeVar loc "t" +checkPat' _ (Wildcard (Info t) loc) _ = pure $ Wildcard (Info t) loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, @@ -192,6 +166,9 @@ checkPat' sizes (RecordPat fs loc) NoneInferred = RecordPat . M.toList <$> traverse (\p -> checkPat' sizes p NoneInferred) (M.fromList fs) <*> pure loc +checkPat' sizes (PatAscription p t loc) _ = + -- FIXME + PatAscription <$> checkPat' sizes p NoneInferred <*> pure t <*> pure loc checkPat' sizes (PatAscription p t loc) maybe_outer_t = do (t', st, _) <- checkTypeExpNonrigid t @@ -209,47 +186,15 @@ checkPat' sizes (PatAscription p t loc) maybe_outer_t = do <$> checkPat' sizes p (Ascribed (resToParam st)) <*> pure t' <*> pure loc -checkPat' _ (PatLit l NoInfo loc) (Ascribed t) = do - t' <- patLitMkType l loc - unify (mkUsage loc "matching against literal") (toStruct t') (toStruct t) - pure $ PatLit l (Info t') loc -checkPat' _ (PatLit l NoInfo loc) NoneInferred = do - t' <- patLitMkType l loc - pure $ PatLit l (Info t') loc -checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) - | Just ts <- M.lookup n cs = do - when (length ps /= length ts) $ - typeError loc mempty $ - "Pattern #" - <> pretty n - <> " expects" - <+> pretty (length ps) - <+> "constructor arguments, but type provides" - <+> pretty (length ts) - <+> "arguments." - ps' <- zipWithM (checkPat' sizes) ps $ map Ascribed ts - pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc -checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed t) = do - t' <- newTypeVar loc "t" - ps' <- forM ps $ \p -> do - p_t <- newTypeVar (srclocOf p) "t" - checkPat' sizes p $ Ascribed p_t - mustHaveConstr usage n (toStruct t') (patternStructType <$> ps') - unify usage t' (toStruct t) - pure $ PatConstr n (Info t) ps' loc - where - usage = mkUsage loc "matching against constructor" -checkPat' sizes (PatConstr n NoInfo ps loc) NoneInferred = do +checkPat' _ (PatLit l info loc) _ = + pure $ PatLit l info loc +checkPat' sizes (PatConstr n info ps loc) _ = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps - t <- newTypeVar loc "t" - mustHaveConstr usage n (toStruct t) (patternStructType <$> ps') - pure $ PatConstr n (Info t) ps' loc - where - usage = mkUsage loc "matching against constructor" + pure $ PatConstr n info ps' loc checkPat :: [(SizeBinder VName, QualName VName)] -> - PatBase NoInfo VName (TypeBase Size u) -> + Pat (TypeBase Size u) -> Inferred StructType -> (Pat ParamType -> TermTypeM a) -> TermTypeM a @@ -272,7 +217,7 @@ checkPat sizes p t m = do -- | Check and bind a @let@-pattern. bindingPat :: [SizeBinder VName] -> - PatBase NoInfo VName (TypeBase Size u) -> + Pat (TypeBase Size u) -> StructType -> (Pat ParamType -> TermTypeM a) -> TermTypeM a @@ -292,7 +237,7 @@ bindingPat sizes p t m = do -- | Check and bind type and value parameters. bindingParams :: [TypeParam] -> - [PatBase NoInfo VName ParamType] -> + [Pat ParamType] -> ([Pat ParamType] -> TermTypeM a) -> TermTypeM a bindingParams tps orig_ps m = bindingTypeParams tps $ do diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index ee62ed478f..3258e5db7b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -665,7 +665,10 @@ instance Pretty (Unmatched (Pat StructType)) where pretty' (PatLit e _ _) = pretty e pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) -checkRetDecl :: Exp -> Maybe (TypeExp NoInfo VName) -> TermM (Maybe (TypeExp Info VName)) +checkRetDecl :: + Exp -> + Maybe (TypeExp (ExpBase NoInfo VName) VName) -> + TermM (Maybe (TypeExp Exp VName)) checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp te @@ -1007,16 +1010,16 @@ checkExp (Coerce e te NoInfo loc) = do checkValDef :: ( VName, - Maybe (TypeExp NoInfo VName), + Maybe (TypeExp (ExpBase NoInfo VName) VName), [TypeParam], [PatBase NoInfo VName ParamType], ExpBase NoInfo VName, SrcLoc ) -> TypeM - ( [TypeParam], + ( Either T.Text (M.Map TyVar Type), [Pat ParamType], - Maybe (TypeExp Info VName), + Maybe (TypeExp Exp VName), Exp ) checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do @@ -1040,4 +1043,4 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## solution:", either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) solution ] - pure (undefined, params', retdecl', body') + pure (solution, params', retdecl', body') From 8639b4a66fa1e6de507e6bfbc0bb239e093e3e06 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 18:53:36 +0100 Subject: [PATCH 012/258] No more undefined. --- src/Language/Futhark/TypeChecker/Terms.hs | 6 +++--- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 12 ++++++------ src/Language/Futhark/TypeChecker/Terms/Pat.hs | 3 --- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 3e6ecbff19..c27ec06f5f 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1008,7 +1008,7 @@ checkApply loc (fname, prev_applied) ftype argexp = do -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM (checkExp . undefined) $ do +checkOneExp e = runTermTypeM checkExp $ do e' <- checkExp $ undefined e let t = typeOf e' (tparams, _, _) <- @@ -1022,7 +1022,7 @@ checkOneExp e = runTermTypeM (checkExp . undefined) $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM (checkExp . undefined) $ do +checkSizeExp e = runTermTypeM checkExp $ do e' <- checkExp $ undefined e let t = typeOf e' when (hasBinding e') $ @@ -1652,7 +1652,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM (checkExp . undefined) $ do + Right tysubsts -> runTermTypeM checkExp $ do addInitialConstraints tysubsts (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 1b06cdcdd7..49c5226d29 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -196,7 +196,7 @@ data TermEnv = TermEnv { termScope :: TermScope, termChecking :: Maybe Checking, termLevel :: Level, - termChecker :: ExpBase NoInfo VName -> TermTypeM Exp, + termCheckExp :: ExpBase Info VName -> TermTypeM Exp, termOuterEnv :: Env, termImportName :: ImportName } @@ -574,9 +574,9 @@ require why ts e = do mustBeOneOf ts (mkUsage (srclocOf e) why) . toStruct =<< expType e pure e -checkExpForSize :: ExpBase NoInfo VName -> TermTypeM Exp +checkExpForSize :: ExpBase Info VName -> TermTypeM Exp checkExpForSize e = do - checker <- asks termChecker + checker <- asks termCheckExp e' <- checker e let t = toStruct $ typeOf e' unify (mkUsage (locOf e') "Size expression") t (Scalar (Prim (Signed Int64))) @@ -584,7 +584,7 @@ checkExpForSize e = do checkTypeExpNonrigid :: TypeExp Exp VName -> TermTypeM (TypeExp Exp VName, ResType, [VName]) checkTypeExpNonrigid te = do - (te', svars, rettype, _l) <- checkTypeExp checkExpForSize $ undefined te + (te', svars, rettype, _l) <- checkTypeExp checkExpForSize te -- No guarantee that the locally bound sizes in rettype are globally -- unique, but we want to turn them into size variables, so let's @@ -636,7 +636,7 @@ initialTermScope = Just (name, EqualityF) addIntrinsicF _ = Nothing -runTermTypeM :: (ExpBase NoInfo VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a +runTermTypeM :: (ExpBase Info VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a runTermTypeM checker (TermTypeM m) = do initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv name <- askImportName @@ -647,7 +647,7 @@ runTermTypeM checker (TermTypeM m) = do { termScope = initial_scope, termChecking = Nothing, termLevel = 0, - termChecker = checker, + termCheckExp = checker, termImportName = name, termOuterEnv = outer_env } diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index ad4ea0aa0f..16ad00f710 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -166,9 +166,6 @@ checkPat' sizes (RecordPat fs loc) NoneInferred = RecordPat . M.toList <$> traverse (\p -> checkPat' sizes p NoneInferred) (M.fromList fs) <*> pure loc -checkPat' sizes (PatAscription p t loc) _ = - -- FIXME - PatAscription <$> checkPat' sizes p NoneInferred <*> pure t <*> pure loc checkPat' sizes (PatAscription p t loc) maybe_outer_t = do (t', st, _) <- checkTypeExpNonrigid t From 3f09449594aca2f077b59a151357cde64f1558f1 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 9 Feb 2024 16:05:57 -0800 Subject: [PATCH 013/258] Add ILP/LP solving stuff. --- futhark.cabal | 4 + src/Futhark/Solve/BranchAndBound.hs | 75 ++++++ src/Futhark/Solve/LP.hs | 306 +++++++++++++++++++++++++ src/Futhark/Solve/Matrix.hs | 341 ++++++++++++++++++++++++++++ src/Futhark/Solve/Simplex.hs | 238 +++++++++++++++++++ 5 files changed, 964 insertions(+) create mode 100644 src/Futhark/Solve/BranchAndBound.hs create mode 100644 src/Futhark/Solve/LP.hs create mode 100644 src/Futhark/Solve/Matrix.hs create mode 100644 src/Futhark/Solve/Simplex.hs diff --git a/futhark.cabal b/futhark.cabal index c89eacb214..1c0531a83b 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -370,6 +370,10 @@ library Futhark.Pkg.Types Futhark.Profile Futhark.Script + Futhark.Solve.LP + Futhark.Solve.Matrix + Futhark.Solve.Simplex + Futhark.Solve.BranchAndBound Futhark.Test Futhark.Test.Spec Futhark.Test.Values diff --git a/src/Futhark/Solve/BranchAndBound.hs b/src/Futhark/Solve/BranchAndBound.hs new file mode 100644 index 0000000000..846ae4a59a --- /dev/null +++ b/src/Futhark/Solve/BranchAndBound.hs @@ -0,0 +1,75 @@ +module Futhark.Solve.BranchAndBound (branchAndBound) where + +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.LP (LP (..)) +import Futhark.Solve.Matrix +import Futhark.Solve.Simplex + +newtype Bound a = Bound (Maybe a, Maybe a) + deriving (Eq, Ord, Show) + +instance (Ord a) => Semigroup (Bound a) where + Bound (mlb1, mub1) <> Bound (mlb2, mub2) = + Bound (combine max mlb1 mlb2, combine min mub1 mub2) + where + combine _ Nothing b2 = b2 + combine _ b1 Nothing = b1 + combine c (Just b1) (Just b2) = Just $ c b1 b2 + +-- | Solves an LP with the additional constraint that all solutions +-- must be integral. Returns 'Nothing' if infeasible or unbounded. +branchAndBound :: + (Read a, Unbox a, RealFrac a, Show a) => + LP a -> + Maybe (a, Vector Int) +branchAndBound prob@(LP _ a d) = (zopt,) <$> mopt + where + (zopt, mopt) = step (S.singleton mempty) (negate $ read "Infinity") Nothing + step todo zlow opt + | S.null todo = (zlow, opt) + | otherwise = + let (next, rest) = S.deleteFindMin todo + in case simplexLP (mkProblem next) of + Nothing -> step rest zlow opt + Just (z, sol) + | z <= zlow -> step rest zlow opt + | V.all isInt sol -> + step rest z (Just $ V.map round sol) + | otherwise -> + let (idx, frac) = + V.head $ V.filter (not . isInt . snd) $ V.zip (V.generate (V.length sol) id) sol + new_todo = + S.fromList $ + filter + (/= next) + [ M.insertWith (<>) idx (Bound (Nothing, Just $ fromInteger $ floor frac)) next, + M.insertWith (<>) idx (Bound (Just $ fromInteger $ ceiling frac, Nothing)) next + ] + in step (new_todo <> rest) zlow opt + + -- TODO: use isInt x = x == round x + -- requires a better 'rowEchelon' implementation for matrices + isInt x = (abs (fromIntegral (round x) - x)) <= 10 ^^ (-10) + mkProblem = + M.foldrWithKey + ( \idx bound acc -> addBound acc idx bound + ) + prob + + addBound lp idx (Bound (mlb, mub)) = + lp + { lpA = a `addRows` new_rows, + lpd = d V.++ V.fromList new_ds + } + where + (new_rows, new_ds) = + unzip $ + catMaybes + [ (V.generate (ncols a) (\i -> if i == idx then (-1) else 0),) <$> (negate <$> mlb), + (V.generate (ncols a) (\i -> if i == idx then 1 else 0),) <$> mub + ] diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs new file mode 100644 index 0000000000..11e943a1b1 --- /dev/null +++ b/src/Futhark/Solve/LP.hs @@ -0,0 +1,306 @@ +module Futhark.Solve.LP + ( LP (..), + LPE (..), + convert, + normalize, + var, + constant, + cval, + bin, + or, + oneIsZero, + (~+~), + (~-~), + (~*~), + (!), + neg, + linearProgToLP, + linearProgToLPE, + LSum (..), + LinearProg (..), + OptType (..), + Constraint (..), + (==), + (<=), + (>=), + rowEchelonLPE, + ) +where + +import Data.List qualified as L +import Data.Map (Map) +import Data.Map qualified as Map +import Data.Maybe +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.Matrix (Matrix (..)) +import Futhark.Solve.Matrix qualified as M +import Prelude hiding (or, (<=), (==), (>=)) +import Prelude qualified + +-- | A linear program. 'LP c a d' represents the program +-- +-- > maximize c^T * a +-- > subject to a * x <= d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LP a = LP + { lpc :: Vector a, + lpA :: Matrix a, + lpd :: Vector a + } + deriving (Eq, Show) + +-- | Equational form of a linear program. 'LPE c a d' represents the +-- program +-- +-- > maximize c^T * a +-- > subject to a * x = d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LPE a = LPE + { pc :: Vector a, + pA :: Matrix a, + pd :: Vector a + } + deriving (Eq, Show) + +rowEchelonLPE :: (Show a, Unbox a, Fractional a, Ord a) => LPE a -> LPE a +rowEchelonLPE (LPE c a d) = + LPE c (M.sliceCols (V.generate (ncols a) id) ad) (M.getCol (ncols a) ad) + where + ad = + M.filterRows (V.any (Prelude./= 0)) $ + (M.rowEchelon $ a M.<|> M.fromColVector d) + +-- | Converts an 'LP' into an equivalent 'LPE' by introducing slack +-- variables. +convert :: (Show a, Num a, Unbox a) => LP a -> LPE a +convert (LP c a d) = LPE c' a' d + where + a' = a M.<|> M.diagonal (V.replicate (M.nrows a) 1) + c' = c V.++ V.replicate (M.nrows a) 0 + +-- | Linear sum of variables. +newtype LSum v a = LSum {lsum :: (Map (Maybe v) a)} + deriving (Eq) + +instance (Show v, Show a) => Show (LSum v a) where + show (LSum m) = + L.intercalate + " + " + $ map + ( \(k, a) -> + case k of + Nothing -> show a + Just k' -> show a <> "*" <> show k' + ) + $ Map.toList m + +instance Functor (LSum v) where + fmap f (LSum m) = LSum $ fmap f m + +-- | Type of constraint +data CType = Equal | LessEq + deriving (Eq) + +instance Show CType where + show (Equal) = "=" + show (LessEq) = "<=" + +-- | A constraint for a linear program. +data Constraint v a + = Constraint CType (LSum v a) (LSum v a) + deriving (Eq) + +instance (Show a, Show v) => Show (Constraint v a) where + show (Constraint t l r) = + show l <> " " <> show t <> " " <> show r + +data OptType = Maximize | Minimize + deriving (Show, Eq) + +-- | A linear program. +data LinearProg v a = LinearProg + { optType :: OptType, + objective :: LSum v a, + constraints :: [Constraint v a] + } + deriving (Eq) + +instance (Show v, Show a) => Show (LinearProg v a) where + show (LinearProg opt obj cs) = + unlines $ + [ show opt, + show obj, + "subject to:" + ] + ++ map show cs + +bigM :: (Num a) => a +bigM = 10 ^ 3 + +oneIsZero :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +oneIsZero b1 b2 x1 x2 = + mkC b1 x1 + <> mkC b2 x2 + <> [(var b1 ~+~ var b2) <= constant 1] + where + mkC b x = + [ var x <= bigM ~*~ var b + ] + +or :: (Eq a, Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] +or b1 b2 c1 c2 = + mkC b1 c1 + <> mkC b2 c2 + <> [var b1 ~+~ var b2 <= constant 1] + where + mkC b (Constraint Equal l r) = + [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b), + l >= r ~-~ bigM ~*~ (constant 1 ~-~ var b) + ] + mkC b (Constraint LessEq l r) = + [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +bin :: (Num a, Ord v) => v -> Constraint v a +bin v = Constraint LessEq (var v) (constant 1) + +(==) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l == r = Constraint Equal l r + +infix 4 == + +(<=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l <= r = Constraint LessEq l r + +infix 4 <= + +(>=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l >= r = Constraint LessEq (neg l) (neg r) + +infix 4 >= + +normalize :: (Eq a, Num a) => LSum v a -> LSum v a +normalize = LSum . Map.filter (/= 0) . lsum + +var :: (Num a) => v -> LSum v a +var v = LSum $ Map.singleton (Just v) (fromInteger 1) + +constant :: a -> LSum v a +constant = LSum . Map.singleton Nothing + +cval :: (Num a, Ord v) => LSum v a -> a +cval = (! Nothing) + +(~+~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a +(LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y + +infixl 6 ~+~ + +(~-~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a +x ~-~ y = x ~+~ (neg y) + +infixl 6 ~-~ + +(~*~) :: (Eq a, Num a, Ord v) => a -> LSum v a -> LSum v a +a ~*~ s = normalize $ fmap (a *) s + +infixl 7 ~*~ + +(!) :: (Num a, Ord v) => LSum v a -> Maybe v -> a +(LSum m) ! v = + case m Map.!? v of + Nothing -> 0 + Just a -> a + +neg :: (Num a, Ord v) => LSum v a -> LSum v a +neg (LSum x) = LSum $ fmap negate x + +-- | Converts a linear program given with a list of constraints +-- into the standard form. +linearProgToLP :: + forall v a. + (Unbox a, Num a, Ord v, Eq a) => + LinearProg v a -> + (LP a, Map Int v) +linearProgToLP (LinearProg otype obj cs) = + (LP c a d, idxMap) + where + cs' = foldMap (convertEqCType . splitConstraint) cs + idxMap = + Map.fromList $ + zip [0 ..] $ + catMaybes $ + Map.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) + c = mkRow $ convertObj otype obj + a = M.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] + convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] + convertEqCType (LessEq, s, a) = [(s, a)] + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s + +-- | Converts a linear program given with a list of constraints +-- into the equational form. Assumes no <= constraints. +linearProgToLPE :: + forall v a. + (Unbox a, Num a, Ord v, Eq a) => + LinearProg v a -> + (LPE a, Map Int v) +linearProgToLPE (LinearProg otype obj cs) = + (LPE c a d, idxMap) + where + cs' = map (checkOnlyEqType . splitConstraint) cs + idxMap = + Map.fromList $ + zip [0 ..] $ + catMaybes $ + Map.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) + c = mkRow $ convertObj otype obj + a = M.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + checkOnlyEqType :: (CType, LSum v a, a) -> (LSum v a, a) + checkOnlyEqType (Equal, s, a) = (s, a) + checkOnlyEqType (ctype, _, _) = error $ show ctype + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s + +test1 :: LPE Double +test1 = + LPE + { pc = V.fromList [5.5, 2.1], + pA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + pd = V.fromList [2, 17] + } diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs new file mode 100644 index 0000000000..90e1a3e126 --- /dev/null +++ b/src/Futhark/Solve/Matrix.hs @@ -0,0 +1,341 @@ +module Futhark.Solve.Matrix + ( Matrix (..), + toList, + toLists, + fromRowVector, + fromColVector, + fromVectors, + fromLists, + (@), + (!), + sliceCols, + getColM, + getCol, + setCol, + sliceRows, + getRowM, + getRow, + (<|>), + (<->), + addRow, + addRows, + imap, + generate, + identity, + diagonal, + (<.>), + (.*), + (*.), + (.+.), + (.-.), + rowEchelon, + filterRows, + deleteRow, + deleteCol, + ) +where + +import Data.List qualified as L +import Data.Map (Map) +import Data.Map qualified as M +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V + +-- A matrix represented as a 1D 'Vector'. +data Matrix a = Matrix + { elems :: Vector a, + nrows :: Int, + ncols :: Int + } + deriving (Eq) + +instance (Show a, Unbox a) => Show (Matrix a) where + show = + unlines . map show . toLists + +toList :: (Unbox a) => Matrix a -> [Vector a] +toList m = + map (\r -> V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +toLists :: (Unbox a) => Matrix a -> [[a]] +toLists m = + map (\r -> V.toList $ V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +fromRowVector :: (Unbox a) => Vector a -> Matrix a +fromRowVector v = + Matrix + { elems = v, + nrows = 1, + ncols = V.length v + } + +fromColVector :: (Unbox a) => Vector a -> Matrix a +fromColVector v = + Matrix + { elems = v, + nrows = V.length v, + ncols = 1 + } + +empty :: (Unbox a) => Matrix a +empty = Matrix mempty 0 0 + +fromVectors :: (Unbox a) => [Vector a] -> Matrix a +fromVectors [] = empty +fromVectors vs = + Matrix + { elems = V.concat $ vs, + nrows = length vs, + ncols = V.length $ head vs + } + +fromLists :: (Unbox a) => [[a]] -> Matrix a +fromLists xss = + Matrix + { elems = V.concat $ map V.fromList xss, + nrows = length xss, + ncols = length $ head xss + } + +class SelectCols a where + select :: Vector Int -> a -> a + (@) :: a -> Vector Int -> a + (@) = flip select + +infix 9 @ + +instance (Unbox a) => SelectCols (Vector a) where + select s v = V.map (v V.!) s + +instance (Unbox a) => SelectCols (Matrix a) where + select = sliceCols + +(!) :: (Unbox a) => Matrix a -> (Int, Int) -> a +m ! (r, c) = elems m V.! (ncols m * r + c) + +sliceCols :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceCols cols m = + Matrix + { elems = + V.generate (nrows m * V.length cols) $ \i -> + let col = cols V.! (i `rem` V.length cols) + row = i `div` V.length cols + in m ! (row, col), + nrows = nrows m, + ncols = V.length cols + } + +getColM :: (Unbox a) => Int -> Matrix a -> Matrix a +getColM col = sliceCols $ V.singleton col + +getCol :: (Unbox a) => Int -> Matrix a -> Vector a +getCol col = elems . getColM col + +setCol :: (Unbox a) => Int -> Vector a -> Matrix a -> Matrix a +setCol c col m = + m + { elems = + V.update_ (elems m) indices col + } + where + indices = V.generate (nrows m) $ + \r -> r * ncols m + c + +sliceRows :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceRows rows m = + Matrix + { elems = + V.generate (ncols m * V.length rows) $ \i -> + let row = rows V.! (i `rem` V.length rows) + col = i `div` V.length rows + in m ! (row, col), + nrows = V.length rows, + ncols = ncols m + } + +getRowM :: (Unbox a) => Int -> Matrix a -> Matrix a +getRowM row = sliceRows $ V.singleton row + +getRow :: (Unbox a) => Int -> Matrix a -> Vector a +getRow row = elems . getRowM row + +(<|>) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <|> m2 = + generate f (nrows m1) (ncols m1 + ncols m2) + where + f r c + | c < ncols m1 = m1 ! (r, c) + | otherwise = m2 ! (r, c - ncols m1) + +(<->) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <-> m2 = + generate f (nrows m1 + nrows m2) (ncols m1) + where + f r c + | r < nrows m1 = m1 ! (r, c) + | otherwise = m2 ! (r - nrows m1, c) + +addRow :: (Unbox a) => Matrix a -> Vector a -> Matrix a +addRow m v = + m + { elems = elems m V.++ v, + nrows = nrows m + 1 + } + +addRows :: (Unbox a) => Matrix a -> [Vector a] -> Matrix a +addRows = foldl addRow + +imap :: (Unbox a) => (Int -> Int -> a -> a) -> Matrix a -> Matrix a +imap f m = + m + { elems = V.imap g $ elems m + } + where + g i = + let r = i `div` ncols m + c = i `rem` nrows m + in f r c + +generate :: (Unbox a) => (Int -> Int -> a) -> Int -> Int -> Matrix a +generate f rows cols = + Matrix + { elems = + V.generate (rows * cols) $ \i -> + let r = i `div` cols + c = i `rem` cols + in f r c, + nrows = rows, + ncols = cols + } + +identity :: (Unbox a, Num a) => Int -> Matrix a +identity n = generate (\r c -> if r == c then 1 else 0) n n + +diagonal :: (Unbox a, Num a) => Vector a -> Matrix a +diagonal d = generate (\r c -> if r == c then d V.! r else 0) (V.length d) (V.length d) + +(<.>) :: (Unbox a, Num a) => Vector a -> Vector a -> a +v1 <.> v2 = V.sum $ V.zipWith (*) v1 v2 + +infixl 7 <.> + +(*.) :: (Unbox a, Num a) => Matrix a -> Vector a -> Vector a +m *. v = + V.generate (nrows m) $ \r -> + getRow r m <.> v + +infixl 7 *. + +(.*) :: (Unbox a, Num a) => Vector a -> Matrix a -> Vector a +v .* m = + V.generate (ncols m) $ \c -> + v <.> getCol c m + +infixl 7 .* + +(.-.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.-.) = V.zipWith (-) + +infixl 6 .-. + +(.+.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.+.) = V.zipWith (+) + +infixl 6 .+. + +swapRows :: (Unbox a) => Int -> Int -> Matrix a -> Matrix a +swapRows r1 r2 m = + m + { elems = + elems m `V.update` new + } + where + start1 = ncols m * r1 + start2 = ncols m * r2 + row1 = getRow r1 m + row2 = getRow r2 m + new = + V.imap (\i a -> (i + start1, a)) row2 + V.++ V.imap (\i a -> (i + start2, a)) row1 + +-- todo: fix +update :: (Unbox a) => Matrix a -> Vector ((Int, Int), a) -> Matrix a +update m upds = + generate + ( \i j -> + case (M.fromList $ V.toList upds) M.!? (i, j) of + Nothing -> m ! (i, j) + Just x -> x + ) + (nrows m) + (ncols m) + +update_ :: (Unbox a) => Matrix a -> Map (Int, Int) a -> Matrix a +update_ m upds = + generate + ( \i j -> + case upds M.!? (i, j) of + Nothing -> m ! (i, j) + Just x -> x + ) + (nrows m) + (ncols m) + +-- TODO: maintain integrality of entries in the matrix +-- rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a +-- rowEchelon = rowEchelon' 0 0 +-- where +-- rowEchelon' h k m@(Matrix _ nr nc) +-- | h < nr && k < nc = +-- if m ! (pivot_row, k) == 0 +-- then rowEchelon' h (k + 1) m +-- else rowEchelon' (h + 1) (k + 1) clear_rows_below +-- | otherwise = m +-- where +-- pivot_row = +-- fst $ +-- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ +-- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] +-- m' = swapRows h pivot_row m +-- clear_rows_below = +-- update m' $ +-- V.fromList $ +-- [((i, k), 0) | i <- [h + 1 .. nr - 1]] +-- ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) +-- | i <- [h + 1 .. nr - 1], +-- let f = m' ! (i, k) / m' ! (h, k), +-- j <- [k + 1 .. nc - 1] +-- ] + +rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +rowEchelon = rowEchelon' 0 0 + where + rowEchelon' h k m@(Matrix _ nr nc) + | h < nr && k < nc = + if m ! (pivot_row, k) == 0 + then rowEchelon' h (k + 1) m + else rowEchelon' (h + 1) (k + 1) clear_rows_below + | otherwise = m + where + pivot_row = + fst $ + L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ + [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] + m' = swapRows h pivot_row m + clear_rows_below = + update m' $ + V.fromList $ + [((i, k), 0) | i <- [h + 1 .. nr - 1]] + ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) + | i <- [h + 1 .. nr - 1], + j <- [k + 1 .. nc - 1] + ] + +filterRows :: (Unbox a) => (Vector a -> Bool) -> Matrix a -> Matrix a +filterRows p = fromVectors . filter p . toList + +deleteRow :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteRow n m = sliceRows (V.generate (nrows m - 1) (\r -> if r < n then r else r + 1)) m + +deleteCol :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteCol n m = sliceCols (V.generate (ncols m - 1) (\c -> if c < n then c else c + 1)) m diff --git a/src/Futhark/Solve/Simplex.hs b/src/Futhark/Solve/Simplex.hs new file mode 100644 index 0000000000..e01c7ce566 --- /dev/null +++ b/src/Futhark/Solve/Simplex.hs @@ -0,0 +1,238 @@ +module Futhark.Solve.Simplex + ( simplex, + simplexLP, + simplexProg, + findBasis, + ) +where + +import Data.List qualified as L +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as M +import Data.Maybe +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.LP (LP (..), LPE (..), LinearProg (..), convert, linearProgToLPE, rowEchelonLPE) +import Futhark.Solve.Matrix + +-- | A tableau of an equational linear program @a * x = d@ is +-- +-- > x @ b = p + q * x @ n +-- > --------------------- +-- > z = z' + r^T * x @ n +-- +-- where @z = c^T * x@ and @b@ (@n@) is a vector containing the +-- indices of basic (nonbasic) variables. +-- +-- The basic feasible solution corresponding to the above tableau is +-- given by @x \@ b = p@, @x \@n = 0@ with the value of the objective +-- equal to @z'@. + +-- | Computes @r@ as given in the tableau above. +comp_r :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + Vector Int -> + Vector a +comp_r (LPE c a _) invA_B b n = + c @ n .-. c @ b .* invA_B .* a @ n + +-- | @comp_q_enter prob invA_B b n enter@ computes the @enter@th +-- column of @q@. +comp_q_enter :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Int -> + Vector a +comp_q_enter (LPE _ a _) invA_B enter = + V.map negate $ invA_B *. getCol enter a + +-- | Computes the objective given an inversion of @a@ and a basis. +comp_z :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + a +comp_z (LPE c _ d) invA_B b = + c @ b .* invA_B <.> d + +-- | Constructs an auxiliary equational linear program to compute the +-- initial feasible basis; returns the program along with a feasible +-- basis. +mkAux :: (Ord a, Unbox a, Num a) => LPE a -> (LPE a, Vector Int, Vector Int) +mkAux (LPE _ a d) = (LPE c_aux a_aux d_aux, b_aux, n_aux) + where + c_aux = V.replicate (ncols a) 0 V.++ V.replicate (nrows a) (-1) + d_aux = V.map abs d + a_aux = + imap (\r _ e -> if (d V.! r) < 0 then negate e else e) a + <|> identity (nrows a) + b_aux = V.generate (nrows a) (+ ncols a) + n_aux = V.generate (ncols a) id + +-- | Finds an initial feasible basis for an equational linear program. +-- Returns 'Nothing' if the LP has no solution. Inverts some +-- equations by multiplying by -1 so it also returns a modified (but +-- equivalent) equational linear program. +findBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +findBasis prob = do + (invA_B, p, b, n) <- step p_aux (invA_B_aux, d_aux, b_aux, n_aux) + if comp_z p_aux invA_B b == 0 + then Just $ fixDegenerateBasis prob (ncols $ pA prob) p_aux (invA_B, p, b, n) + else Nothing + where + (p_aux@(LPE _ _ d_aux), b_aux, n_aux) = mkAux prob + invA_B_aux = identity $ V.length b_aux + + fixDegenerateBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Int -> + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (LPE a, Matrix a, Vector a, Vector Int, Vector Int) + fixDegenerateBasis og_prob col prob (invA_B, p, b, n) + | Just exit_idx <- mexit_idx, + V.null (elim_row exit_idx) = + let prob' = + prob + { pA = deleteRow exit_idx (pA prob), + pd = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) $ + pd prob + } + invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B + p' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) p + b' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) b + in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) + | Just exit_idx <- mexit_idx, + (enter, _) <- V.head (elim_row exit_idx) = + let enter_idx = fromJust $ V.findIndex (== enter) n + exit = b V.! exit_idx + in fixDegenerateBasis og_prob col prob $ + pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = + let prob' = + prob + { pc = pc og_prob, + pA = sliceCols (V.generate col id) $ pA prob, + pd = V.map abs $ pd og_prob + } + in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) + where + mexit_idx = + fst <$> ((V.filter ((>= col) . snd) (V.imap (curry id) b)) V.!? 0) + elim_row exit_idx = + V.filter ((/= 0) . snd) $ + V.map (\j -> (j, comp_q_enter prob invA_B j V.! exit_idx)) $ + V.generate col id + +-- | Solves an equational linear program. Returns 'Nothing' if the +-- program is infeasible or unbounded. Otherwise returns the optimal +-- value and the solution. +simplex :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (a, Vector a) +simplex lpe = do + let ech_lpe = rowEchelonLPE lpe + res@(lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe + (invA_B', p', b', n') <- step lpe' (invA_B, p, b, n) + let z = comp_z lpe' invA_B' b' + sol = + V.map snd $ + V.fromList $ + L.sortOn fst $ + V.toList $ + V.zip (b' V.++ n') (p' V.++ V.replicate (V.length n') 0) + pure (z, sol) + +-- | Solves a linear program. +simplexLP :: + (Unbox a, Ord a, Fractional a, Show a) => + LP a -> + Maybe (a, Vector a) +simplexLP lp = do + (opt, sol) <- simplex lpe + pure (opt, V.take (ncols $ lpA lp) sol) + where + lpe = convert lp + +simplexProg :: + (Unbox a, Ord a, Ord v, Fractional a, Show a) => + LinearProg v a -> + Maybe (a, Map v a) +simplexProg prog = do + (z, sol) <- simplex lpe + pure $ (z, M.fromList $ map (\(i, x) -> (idxMap M.! i, x)) $ zip [0 ..] $ V.toList sol) + where + (lpe, idxMap) = linearProgToLPE prog + +pivot :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (Int, Int) -> + (Int, Int) -> + (Matrix a, Vector a, Vector Int, Vector Int) +pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) = + (invA_B', p', b', n') + where + q_enter = comp_q_enter prob invA_B enter + b' = b V.// [(exit_idx, enter)] + n' = n V.// [(enter_idx, exit)] + e_inv_vec = + V.map + (/ abs (q_enter V.! exit_idx)) + (q_enter V.// [(exit_idx, 1)]) + genF row col = + (if row == exit_idx then 0 else invA_B ! (row, col)) + + (e_inv_vec V.! row) * invA_B ! (exit_idx, col) + invA_B' = generate genF (nrows invA_B) (ncols invA_B) + p' = p V.// [(exit_idx, 0)] .+. V.map (* (p V.! exit_idx)) e_inv_vec + +-- | One step of the simplex algorithm. +step :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + Maybe (Matrix a, Vector a, Vector Int, Vector Int) +step prob (invA_B, p, b, n) + | Just enter_idx <- menter_idx = + let enter = n V.! enter_idx + q_enter = comp_q_enter prob invA_B enter + pq = + V.map (\(i, p_', q_) -> (i, -(p_' / q_))) $ + V.filter (\(_, _, q_) -> q_ < 0) $ + V.zip3 (V.generate (V.length q_enter) id) p q_enter + in if V.null pq + then Nothing + else + let exit_val = snd $ V.minimumOn snd pq + exit_cands = + V.map fst $ V.filter ((exit_val ==) . snd) pq + (exit_idx, exit) = + V.minimumOn snd $ + V.map (\i -> (i, b V.! i)) exit_cands + in step prob $ pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = Just (invA_B, p, b, n) + where + r = comp_r prob invA_B b n + menter_idx = V.findIndex (> 0) r + b_zero = V.filter (\(v, i) -> v == 0 && (not $ V.null (V.filter (< i) n))) $ V.zip p b From c98bbe406ace88272d1dd0e472901dc71cf52f00 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 9 Feb 2024 17:37:37 -0800 Subject: [PATCH 014/258] Add AM AST annotation. --- src/Futhark/Internalise/Defunctionalise.hs | 14 +++++----- src/Futhark/Internalise/Exps.hs | 2 +- src/Futhark/Internalise/FullNormalise.hs | 6 ++-- src/Futhark/Internalise/LiftLambdas.hs | 2 +- src/Futhark/Internalise/Monomorphise.hs | 16 +++++------ src/Language/Futhark/Interpreter.hs | 2 +- src/Language/Futhark/Syntax.hs | 28 ++++++++++++++++--- .../Futhark/TypeChecker/Consumption.hs | 5 ++-- src/Language/Futhark/TypeChecker/Terms.hs | 4 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 10 files changed, 51 insertions(+), 30 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 8ad2e15948..d6b03a368b 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -752,7 +752,7 @@ etaExpand e_t e = do M.fromList . zip (retDims ret) $ map (ExpSubst . flip sizeFromName mempty . qualName) ext' ret' = applySubst (`M.lookup` extsubst) ret - e' = mkApply e (map (Nothing,) vars) $ AppRes (toStruct $ retType ret') ext' + e' = mkApply e (map (\v -> (Nothing, mempty, v)) vars) $ AppRes (toStruct $ retType ret') ext' pure (params, e', ret) where getType (RetType _ (Scalar (Arrow _ p d t1 t2))) = @@ -910,9 +910,9 @@ liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> - ((Maybe VName, Exp), [ParamType]) -> + (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) -defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) ((argext, arg), _) = do +defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg let env' = alwaysMatchPatSV pat arg_sv dims = mempty @@ -963,18 +963,18 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) ((argext, ar callret <- unRetType lifted_rettype pure - ( mkApply fname' [(Nothing, f'), (argext, arg')] callret, + ( mkApply fname' [(Nothing, mempty, f'), (argext, mempty, arg')] callret, sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. -defuncApplyArg _ (f', DynamicFun _ sv) ((argext, arg), argtypes) = do +defuncApplyArg _ (f', DynamicFun _ sv) (((argext, _), arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] - apply_e = mkApply f' [(argext, arg')] callret + apply_e = mkApply f' [(argext, mempty, arg')] callret pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = @@ -991,7 +991,7 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -defuncApply :: Exp -> NE.NonEmpty (Maybe VName, Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index b2a8f37247..ec4adece0b 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1484,7 +1484,7 @@ findFuncall (E.Apply f args _) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info argext, e) = (e, argext) + onArg (Info (argext, _), e) = (e, argext) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index f797557776..f566f43e64 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -202,7 +202,7 @@ getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (I let y = Var (qualName yn) (Info $ toStruct yty) mempty ret' = applySubst (pSubst x y) ret body = - mkApply (Var op ty mempty) [(xext, x), (Nothing, y)] $ + mkApply (Var op ty mempty) [(xext, mempty, x), (Nothing, mempty, y)] $ AppRes (toStruct ret') exts nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc where @@ -215,7 +215,7 @@ getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext)) ( y <- getOrdering False e let x = Var (qualName xn) (Info $ toStruct xty) mempty ret' = applySubst (pSubst x y) ret - body = mkApply (Var op ty mempty) [(Nothing, x), (yext, y)] $ AppRes (toStruct ret') [] + body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, mempty, y)] $ AppRes (toStruct ret') [] nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc where pSubst x y vn @@ -304,7 +304,7 @@ getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) lo (False, False) -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er - pure $ mkApply (Var op opT oloc) [(elp, el'), (erp, er')] resT + pure $ mkApply (Var op opT oloc) [(elp, mempty, el'), (erp, mempty, er')] resT nameExp final expr' where isOr = baseName (qualLeaf op) == "||" diff --git a/src/Futhark/Internalise/LiftLambdas.hs b/src/Futhark/Internalise/LiftLambdas.hs index 0c9aead794..f515083170 100644 --- a/src/Futhark/Internalise/LiftLambdas.hs +++ b/src/Futhark/Internalise/LiftLambdas.hs @@ -138,7 +138,7 @@ liftFunction fname tparams params (RetType dims ret) funbody = do apply f [] = f apply f (p : rem_ps) = let inner_ret = AppRes (augType rem_ps) mempty - inner = mkApply f [(Nothing, freeVar p)] inner_ret + inner = mkApply f [(Nothing, mempty, freeVar p)] inner_ret in apply inner rem_ps transformSubExps :: ASTMapper LiftM diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index e26cc1cb28..935627bccd 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -121,10 +121,10 @@ entryAssert (x : xs) body = andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty eqop = Var (qualName (intrinsicVar "==")) (Info opt) mempty logAnd x' y = - mkApply andop [(Nothing, x'), (Nothing, y)] $ + mkApply andop [(Nothing, mempty, x'), (Nothing, mempty, y)] $ AppRes bool [] cmpExp (ReplacedExp x', y) = - mkApply eqop [(Nothing, x'), (Nothing, y')] $ + mkApply eqop [(Nothing, mempty, x'), (Nothing, mempty, y')] $ AppRes bool [] where y' = Var (qualName y) (Info i64) mempty @@ -415,7 +415,7 @@ transformFName loc fname t = do ( i - 1, mkApply f - [(Nothing, size_arg)] + [(Nothing, mempty, size_arg)] (AppRes (foldFunType (replicate i i64) (RetType [] t')) []) ) @@ -539,7 +539,7 @@ transformAppExp (Apply fe args _) res = <*> mapM onArg (NE.toList args) <*> transformAppRes res where - onArg (Info ext, e) = (ext,) <$> transformExp e + onArg (Info (ext, am), e) = (ext,am,) <$> transformExp e transformAppExp (Loop sparams pat e1 form body loc) res = do e1' <- transformExp e1 @@ -603,8 +603,8 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do where applyOp ret ext fname' x y = mkApply - (mkApply fname' [(unInfo d1, x)] (AppRes ret mempty)) - [(unInfo d2, y)] + (mkApply fname' [(unInfo d1, mempty, x)] (AppRes ret mempty)) + [(unInfo d2, mempty, y)] (AppRes ret ext) makeVarParam arg = do @@ -790,7 +790,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( let apply_left = mkApply op - [(xext, e1)] + [(xext, mempty, e1)] (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) []) onDim (Var d typ _) | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc @@ -799,7 +799,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( rettype' = first onDim rettype body <- scoping (S.fromList [v1, v2]) $ - mkApply apply_left [(yext, e2)] + mkApply apply_left [(yext, mempty, e2)] <$> transformAppRes (AppRes (toStruct rettype') retext) rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype' pure . wrap_left . wrap_right $ diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 5a9030aa65..ac76cf6645 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -855,7 +855,7 @@ evalAppExp env (Apply f args loc) = do f' <- eval env f foldM (apply loc env) f' args' where - evalArg' (Info ext, x) = evalArg env x ext + evalArg' (Info (ext, _), x) = evalArg env x ext evalAppExp env (Index e is loc) = do is' <- mapM (evalDimIndex env) is arr <- eval env e diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index e47d7d19f3..e85009ad4a 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -23,6 +23,7 @@ module Language.Futhark.Syntax Shape (..), shapeRank, stripDims, + AutoMap (..), TypeBase (..), TypeArg (..), SizeExp (..), @@ -230,7 +231,10 @@ sizeFromInteger x = IntLit x (Info <$> Scalar $ Prim $ Signed Int64) -- | The size of an array type is a list of its dimension sizes. If -- 'Nothing', that dimension is of a (statically) unknown size. -newtype Shape dim = Shape {shapeDims :: [dim]} +data Shape dim + = Shape {shapeDims :: [dim]} + | SVar VName + | SConcat (Shape dim) (Shape dim) deriving (Eq, Ord, Show) instance Foldable Shape where @@ -244,6 +248,9 @@ instance Functor Shape where instance Semigroup (Shape dim) where Shape l1 <> Shape l2 = Shape $ l1 ++ l2 + Shape [] <> s = s + s <> Shape [] = s + s1 <> s2 = s1 `SConcat` s2 instance Monoid (Shape dim) where mempty = Shape [] @@ -260,6 +267,19 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing +data AutoMap = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size + } + deriving (Eq, Show, Ord) + +instance Semigroup AutoMap where + (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) + +instance Monoid AutoMap where + mempty = AutoMap mempty mempty mempty + -- | The name (if any) of a function parameter. The 'Eq' and 'Ord' -- instances always compare values of this type equal. data PName = Named VName | Unnamed @@ -630,7 +650,7 @@ data AppExpBase f vn -- identical). Apply (ExpBase f vn) - (NE.NonEmpty (f (Maybe VName), ExpBase f vn)) + (NE.NonEmpty (f (Maybe VName, AutoMap), ExpBase f vn)) SrcLoc | Range (ExpBase f vn) @@ -1258,7 +1278,7 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) -- | Construct an 'Apply' node, with type information. -mkApply :: ExpBase Info vn -> [(Maybe VName, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap, ExpBase Info vn)] -> AppRes -> ExpBase Info vn mkApply f args (AppRes t ext) | Just args' <- NE.nonEmpty $ map onArg args = case f of @@ -1270,7 +1290,7 @@ mkApply f args (AppRes t ext) AppExp (Apply f args' (srcspan f $ snd $ NE.last args')) (Info (AppRes t ext)) | otherwise = f where - onArg (v, x) = (Info v, x) + onArg (v, am, x) = (Info (v, am), x) -- | Construct an 'Apply' node, without type information. mkApplyUT :: ExpBase NoInfo vn -> ExpBase NoInfo vn -> ExpBase NoInfo vn diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index d971f48def..5c1198537f 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -710,12 +710,13 @@ checkExp (AppExp (Apply f args loc) appres) = do res_als <- checkFuncall loc (fname f) f_als args_als pure (AppExp (Apply f' args' loc) appres, res_als) where + -- neUnzip3 xs = ((\(x, _, _) -> x) <$> xs, (\(_, y, _) -> y) <$> xs, (\(_, _, z) -> z) <$> xs) fname (Var v _ _) = Just v fname (AppExp (Apply e _ _) _) = fname e fname _ = Nothing - checkArg' prev d (Info p, e) = do + checkArg' prev d (Info (p, am), e) = do (e', e_als) <- checkArg prev (second (const d) (typeOf e)) e - pure ((Info p, e'), e_als) + pure ((Info (p, am), e'), e_als) checkArgs (Scalar (Arrow _ _ d _ (RetType _ rt))) (x NE.:| args') = do -- Note Futhark uses right-to-left evaluation of applications. diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index c27ec06f5f..59fc6d5ba3 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -547,7 +547,7 @@ checkExp (AppExp (Apply fe args loc) _) = do (_, rt, argext, exts) <- checkApply loc (fname, i) t arg' pure ( (i + 1, all_exts <> exts, rt), - (Info argext, arg') + (Info (argext, mempty), arg') ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e @@ -1099,7 +1099,7 @@ causalityCheck binding_body = do seqArgs known' [] = do void $ onExp known' f modify (S.fromList (appResExt res) <>) - seqArgs known' ((Info p, x) : xs) = do + seqArgs known' ((Info (p, _), x) : xs) = do new_known <- collectingNewKnown $ onExp known' x void $ seqArgs (new_known <> known') xs modify ((new_known <> S.fromList (maybeToList p)) <>) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b822636581..85ade3c1f7 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -782,7 +782,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do rt <- checkApply loc (fname, i) (toType f_t) arg' pure ( (i + 1, rt), - (Info Nothing, arg') + (Info (Nothing, mempty), arg') ) -- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do From db8db0cf02a80c8083422f26a020de14fc41d72b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 10 Feb 2024 11:34:34 +0100 Subject: [PATCH 015/258] More work on using the type information. --- .../Futhark/TypeChecker/Constraints.hs | 32 ++++++++++++--- src/Language/Futhark/TypeChecker/Terms.hs | 19 +++++---- .../Futhark/TypeChecker/Terms/Monad.hs | 32 +++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 40 ++++++++++--------- 4 files changed, 74 insertions(+), 49 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 74a2dd2fd0..bd7c12f7bb 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -6,6 +6,7 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, + Solution, solve, ) where @@ -94,12 +95,31 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt -solution :: SolverState -> M.Map TyVar Type -solution s = M.mapMaybe f $ solverTyVars s +-- | A solution maps types to the set of type variables that must be +-- substituted with this type. This slightly odd representation is +-- needed to encode when two type variables are actually the same +-- type. This matters when we start instanting the sizes of the type. +type Solution = M.Map Type [TyVar] + +solution :: SolverState -> Solution +solution s = + M.fromList $ + map adjust $ + M.toList $ + foldl addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ + M.toList $ + solverTyVars s where - f (TyVarSol t) = Just $ substTyVars (solverTyVars s) t - f (TyVarLink v) = f =<< M.lookup v (solverTyVars s) - f (TyVarUnsol _) = Nothing + mkSubst (TyVarSol t) = Just (t, []) + mkSubst _ = Nothing + addLinks m (v1, TyVarLink v2) = + case M.lookup v2 $ solverTyVars s of + Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) + _ -> case M.lookup v2 m of + Nothing -> m + Just (t, vs) -> M.insert v2 (t, v1 : vs) m + addLinks m _ = m + adjust (v, (t, vs)) = (t, v : vs) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) @@ -172,7 +192,7 @@ solveCt ct = do Nothing -> bad Just eqs -> mapM_ solveCt' eqs -solve :: Constraints -> TyVars -> Either T.Text (M.Map TyVar Type) +solve :: Constraints -> TyVars -> Either T.Text Solution solve constraints tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 59fc6d5ba3..b55dca1ad8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -27,6 +27,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Debug.Trace import Futhark.Util (mapAccumLM, nubOrd, topologicalSort) import Futhark.Util.Pretty hiding (space) import Language.Futhark @@ -522,9 +523,10 @@ checkExp (QualParens (modname, modnameloc) e loc) = do ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." -checkExp (Var qn _ loc) = do - t <- lookupVar loc qn - pure $ Var qn (Info t) loc +checkExp (Var qn (Info t) loc) = do + t' <- lookupVar loc qn + unify (mkUsage loc "inferred rank type") t t' + pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg pure $ Negate arg' loc @@ -1620,14 +1622,15 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t -addInitialConstraints :: M.Map VName (TypeBase () NoUniqueness) -> TermTypeM () +addInitialConstraints :: M.Map (TypeBase () NoUniqueness) [VName] -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where addConstraint v c = modifyConstraints $ M.insert v (0, c) - usage = mkUsage (mempty :: Loc) "trust me bro" - f (v, t) = do - (t', _) <- allDimsFreshInType usage Nonrigid "dv" t - addConstraint v $ Constraint (RetType [] t') usage + usage = mkUsage (mempty :: Loc) + f (t, vs) = do + (t', _) <- allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t + forM_ vs $ \v -> + addConstraint v $ Constraint (RetType [] t') $ usage $ prettyNameText v -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 49c5226d29..e8525e0dfd 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -347,22 +347,6 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." --- | Instantiate a type scheme with fresh type variables for its type --- parameters. Returns the names of the fresh type variables, the --- instance list, and the instantiated type. -instantiateTypeScheme :: - QualName VName -> - SrcLoc -> - [TypeParam] -> - StructType -> - TermTypeM ([VName], StructType) -instantiateTypeScheme qn loc tparams t = do - let tnames = map typeParamName tparams - (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams - let substs = M.fromList $ zip tnames tparam_substs - t' = applySubst (`M.lookup` substs) t - pure (tparam_names, t') - -- | Create a new type name and insert it (unconstrained) in the -- substitution map. instantiateTypeParam :: @@ -385,6 +369,22 @@ instantiateTypeParam qn loc tparam = do "instantiated size parameter of " <> dquotes (pretty qn) pure (v, ExpSubst $ sizeFromName (qualName v) loc) +-- | Instantiate a type scheme with fresh type variables for its type +-- parameters. Returns the names of the fresh type variables, the +-- instance list, and the instantiated type. +instantiateTypeScheme :: + QualName VName -> + SrcLoc -> + [TypeParam] -> + StructType -> + TermTypeM ([VName], StructType) +instantiateTypeScheme qn loc tparams t = do + let tnames = map typeParamName tparams + (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams + let substs = M.fromList $ zip tnames tparam_substs + t' = applySubst (`M.lookup` substs) t + pure (tparam_names, t') + lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) | baseTag q <= maxIntrinsicTag = asks termScope -- Magical intrinsic module. diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 85ade3c1f7..b8c17c57c9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -366,23 +366,24 @@ lookupMod qn@(QualName _ name) = do lookupVar :: SrcLoc -> QualName VName -> TermM StructType lookupVar loc qn@(QualName qs name) = do scope <- lookupQualNameEnv qn - case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams t) -> do - if null tparams && null qs - then pure t - else do - (tnames, t') <- instTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newType loc "t" - pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeOverloaded loc "t" ts - let (pts', rt') = instOverloaded (argtype :: StructType) pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + asStructType loc + =<< case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams t) -> do + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newType loc "t" + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeOverloaded loc "t" ts + let (pts', rt') = instOverloaded (argtype :: StructType) pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -1018,7 +1019,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text (M.Map TyVar Type), + ( Either T.Text Solution, [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1042,6 +1043,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, "## solution:", - either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) solution + let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') From 2bd5fe96e372b70d66cd2ace5bdc65cb4a4d27fd Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 10:14:55 -0800 Subject: [PATCH 016/258] Start adding AUTOMAP machinery. --- src/Futhark/Internalise/Defunctionalise.hs | 4 +- src/Language/Futhark/Syntax.hs | 24 +++----- .../Futhark/TypeChecker/Constraints.hs | 36 ++++++++---- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 58 ++++++++++++------- 5 files changed, 76 insertions(+), 48 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index d6b03a368b..6cd66fbce6 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -910,7 +910,7 @@ liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> - (((Maybe VName, AutoMap), Exp), [ParamType]) -> + (((Maybe VName, AutoMap Size), Exp), [ParamType]) -> DefM (Exp, StaticVal) defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg @@ -991,7 +991,7 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap Size), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index b4b56aecf6..d1e98a165f 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -231,10 +231,7 @@ sizeFromInteger x = IntLit x (Info <$> Scalar $ Prim $ Signed Int64) -- | The size of an array type is a list of its dimension sizes. If -- 'Nothing', that dimension is of a (statically) unknown size. -data Shape dim - = Shape {shapeDims :: [dim]} - | SVar VName - | SConcat (Shape dim) (Shape dim) +newtype Shape dim = Shape {shapeDims :: [dim]} deriving (Eq, Ord, Show) instance Foldable Shape where @@ -248,9 +245,6 @@ instance Functor Shape where instance Semigroup (Shape dim) where Shape l1 <> Shape l2 = Shape $ l1 ++ l2 - Shape [] <> s = s - s <> Shape [] = s - s1 <> s2 = s1 `SConcat` s2 instance Monoid (Shape dim) where mempty = Shape [] @@ -267,17 +261,17 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap = AutoMap - { autoRep :: Shape Size, - autoMap :: Shape Size, - autoFrame :: Shape Size +data AutoMap u = AutoMap + { autoRep :: Shape u, + autoMap :: Shape u, + autoFrame :: Shape u } deriving (Eq, Show, Ord) -instance Semigroup AutoMap where +instance Semigroup (AutoMap u) where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) -instance Monoid AutoMap where +instance Monoid (AutoMap u) where mempty = AutoMap mempty mempty mempty -- | The name (if any) of a function parameter. The 'Eq' and 'Ord' @@ -716,7 +710,7 @@ data AppExpBase f vn -- identical). Apply (ExpBase f vn) - (NE.NonEmpty (f (Maybe VName, AutoMap), ExpBase f vn)) + (NE.NonEmpty (f (Maybe VName, AutoMap Size), ExpBase f vn)) SrcLoc | Range (ExpBase f vn) @@ -1344,7 +1338,7 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) -- | Construct an 'Apply' node, with type information. -mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap Size, ExpBase Info vn)] -> AppRes -> ExpBase Info vn mkApply f args (AppRes t ext) | Just args' <- NE.nonEmpty $ map onArg args = case f of diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index bd7c12f7bb..b6b5507283 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,5 +1,7 @@ module Language.Futhark.TypeChecker.Constraints - ( Type, + ( SVar, + SComp (..), + Type, toType, Ct (..), Constraints, @@ -16,29 +18,43 @@ import Control.Monad.State import Data.Bifunctor import Data.Map qualified as M import Data.Text qualified as T +import Futhark.IR.Pretty import Futhark.Util.Pretty import Language.Futhark --- | A shape component is currently just unit. The rank of an array is --- then just the number of shape components it contains in its shape --- list. When we add AUTOMAP, these components will also allow shape --- variables. The list of components should then be understood as --- concatenation of shapes (meaning you can't just take the length to --- determine the rank of the array). -type SComp = () +type SVar = VName + +-- | A shape component. `SDim` is a single dimension of unspecified +-- size, `SVar` is a shape variable. A list of shape components should +-- then be understood as concatenation of shapes (meaning you can't +-- just take the length to determine the rank of the array). +data SComp + = SDim + | SVar SVar + deriving (Eq, Ord, Show) + +instance Pretty SComp where + pretty (SDim) = "[]" + pretty (SVar x) = pretty x + +instance Pretty (Shape SComp) where + pretty = mconcat . map (brackets . pretty) . shapeDims -- | The type representation used by the constraint solver. Agnostic -- to sizes. type Type = TypeBase SComp NoUniqueness toType :: TypeBase d u -> Type -toType = bimap (const ()) (const NoUniqueness) +toType = bimap (const SDim) (const NoUniqueness) -data Ct = CtEq Type Type +data Ct + = CtEq Type Type + | CtAM SVar SVar deriving (Show) instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 + pretty (CtAM r m) = pretty r <+> "=" <+> "•" <+> "∨" <+> pretty m <+> "=" <+> "•" type Constraints = [Ct] diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index b55dca1ad8..2a7bbbd6da 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1656,7 +1656,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right tysubsts -> runTermTypeM checkExp $ do - addInitialConstraints tysubsts + addInitialConstraints $ M.mapKeys (first $ const ()) tysubsts (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b8c17c57c9..a037e0f2ec 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -235,6 +235,12 @@ newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBa newTypeOverloaded loc name pts = tyVarType <$> newTyVarWith loc name (TyVarPrim pts) +newSVar :: (Located loc) => loc -> Name -> TermM SVar +newSVar _loc desc = do + i <- incCounter + v <- newID $ mkTypeVarName desc i + pure v + asStructType :: (Monoid u) => SrcLoc -> TypeBase d u -> TermM (TypeBase Size u) asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] @@ -259,6 +265,9 @@ ctEq t1 t2 = t1' = toType t1 t2' = toType t2 +ctAM :: SVar -> SVar -> TermM () +ctAM r m = addCt $ CtAM r m + localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -328,7 +337,7 @@ instance MonadTypeChecker TermM where arrayOfRank :: Int -> Type -> Type arrayOfRank 0 t = t -arrayOfRank n t = arrayOf (Shape $ replicate n ()) t +arrayOfRank n t = arrayOf (Shape $ replicate n SDim) t require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why pts e = do @@ -346,13 +355,14 @@ instTypeScheme :: StructType -> TermM ([VName], StructType) instTypeScheme _qn loc tparams t = do - (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do - case tparam of - TypeParamType _ v _ -> do - v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) - TypeParamDim {} -> - pure Nothing + (names, substs) <- fmap (unzip . catMaybes) $ + forM tparams $ \tparam -> do + case tparam of + TypeParamType _ v _ -> do + v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) + TypeParamDim {} -> + pure Nothing let t' = applySubst (`lookup` substs) t pure (names, t') @@ -575,16 +585,24 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM Type -checkApply _ _ (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do - ctEq a $ expType arg - pure $ toType b +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap SComp) checkApply loc _ ftype arg = do - a <- newType loc "arg" - b <- newTyVar loc "res" - ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) - ctEq a (expType arg) - pure $ tyVarType b + (a, b) <- split ftype + r <- newSVar loc "R" + m <- newSVar loc "M" + let s_r = Shape $ pure $ SVar r + s_m = Shape $ pure $ SVar m + ctAM r m + ctEq (arrayOf s_r $ toType $ typeOf arg) (arrayOf s_m a) + pure (arrayOf s_m b, AutoMap {autoRep = s_r, autoMap = s_m, autoFrame = mempty}) + where + split (Scalar (Arrow _ _ _ a (RetType _ b))) = + pure (a, toType b) + split ftype' = do + a <- newType loc "arg" + b <- newTyVar loc "res" + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) + pure (a, tyVarType b) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex @@ -780,7 +798,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do onArg (i, f_t) (_, arg) = do arg' <- checkExp arg - rt <- checkApply loc (fname, i) (toType f_t) arg' + (rt, am) <- checkApply loc (fname, i) (toType f_t) arg' pure ( (i + 1, rt), (Info (Nothing, mempty), arg') @@ -791,8 +809,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e1' <- checkExp e1 e2' <- checkExp e2 - rt1 <- checkApply loc (Just op, 0) (toType ftype) e1' - rt2 <- checkApply loc (Just op, 1) rt1 e2' + (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) e1' + (rt2, am2) <- checkApply loc (Just op, 1) rt1 e2' rt2' <- asStructType loc rt2 pure $ From 55cb35465d4f8504af712db62e4c4e7469db22ec Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 11:00:06 -0800 Subject: [PATCH 017/258] Use `Shape Size` for AUTOMAP. --- src/Futhark/Internalise/Defunctionalise.hs | 4 ++-- src/Language/Futhark/Syntax.hs | 16 ++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 15 ++++++++++----- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 6cd66fbce6..d6b03a368b 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -910,7 +910,7 @@ liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> - (((Maybe VName, AutoMap Size), Exp), [ParamType]) -> + (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg @@ -991,7 +991,7 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap Size), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index d1e98a165f..f0f2a586df 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -261,17 +261,17 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap u = AutoMap - { autoRep :: Shape u, - autoMap :: Shape u, - autoFrame :: Shape u +data AutoMap = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size } deriving (Eq, Show, Ord) -instance Semigroup (AutoMap u) where +instance Semigroup AutoMap where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) -instance Monoid (AutoMap u) where +instance Monoid AutoMap where mempty = AutoMap mempty mempty mempty -- | The name (if any) of a function parameter. The 'Eq' and 'Ord' @@ -710,7 +710,7 @@ data AppExpBase f vn -- identical). Apply (ExpBase f vn) - (NE.NonEmpty (f (Maybe VName, AutoMap Size), ExpBase f vn)) + (NE.NonEmpty (f (Maybe VName, AutoMap), ExpBase f vn)) SrcLoc | Range (ExpBase f vn) @@ -1338,7 +1338,7 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) -- | Construct an 'Apply' node, with type information. -mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap Size, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap, ExpBase Info vn)] -> AppRes -> ExpBase Info vn mkApply f args (AppRes t ext) | Just args' <- NE.nonEmpty $ map onArg args = case f of diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index a037e0f2ec..e6e9b26776 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -585,17 +585,22 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap SComp) +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap) checkApply loc _ ftype arg = do (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" - let s_r = Shape $ pure $ SVar r - s_m = Shape $ pure $ SVar m + let unit_info = Info $ Scalar $ Prim Bool + r_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] r) unit_info mempty ctAM r m - ctEq (arrayOf s_r $ toType $ typeOf arg) (arrayOf s_m a) - pure (arrayOf s_m b, AutoMap {autoRep = s_r, autoMap = s_m, autoFrame = mempty}) + ctEq (arrayOf (toShape $ SVar r) $ toType $ typeOf arg) (arrayOf (toShape $ SVar m) a) + pure + ( arrayOf (toShape $ SVar m) b, + AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = mempty} + ) where + toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, toType b) split ftype' = do From a67a1158f7870565f3572c324aecb5b8f3d5502c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 13:20:39 -0800 Subject: [PATCH 018/258] Add rank analysis stuff. --- futhark.cabal | 1 + src/Futhark/Solve/LP.hs | 40 ++++---- src/Language/Futhark/TypeChecker/Rank.hs | 108 +++++++++++++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 4 files changed, 130 insertions(+), 21 deletions(-) create mode 100644 src/Language/Futhark/TypeChecker/Rank.hs diff --git a/futhark.cabal b/futhark.cabal index 1c0531a83b..80006d4cc9 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -418,6 +418,7 @@ library Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules Language.Futhark.TypeChecker.Monad + Language.Futhark.TypeChecker.Rank Language.Futhark.TypeChecker.Terms Language.Futhark.TypeChecker.Terms2 Language.Futhark.TypeChecker.Terms.Loop diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 11e943a1b1..3b46af1965 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -20,9 +20,9 @@ module Futhark.Solve.LP LinearProg (..), OptType (..), Constraint (..), - (==), - (<=), - (>=), + (~==~), + (~<=~), + (~>=~), rowEchelonLPE, ) where @@ -143,47 +143,47 @@ instance (Show v, Show a) => Show (LinearProg v a) where bigM :: (Num a) => a bigM = 10 ^ 3 -oneIsZero :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] -oneIsZero b1 b2 x1 x2 = +oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] +oneIsZero (b1, x1) (b2, x2) = mkC b1 x1 <> mkC b2 x2 - <> [(var b1 ~+~ var b2) <= constant 1] + <> [(var b1 ~+~ var b2) ~<=~ constant 1] where mkC b x = - [ var x <= bigM ~*~ var b + [ var x ~<=~ bigM ~*~ var b ] or :: (Eq a, Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] or b1 b2 c1 c2 = mkC b1 c1 <> mkC b2 c2 - <> [var b1 ~+~ var b2 <= constant 1] + <> [var b1 ~+~ var b2 ~<=~ constant 1] where mkC b (Constraint Equal l r) = - [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b), - l >= r ~-~ bigM ~*~ (constant 1 ~-~ var b) + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b), + l ~>=~ r ~-~ bigM ~*~ (constant 1 ~-~ var b) ] mkC b (Constraint LessEq l r) = - [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b) + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b) ] bin :: (Num a, Ord v) => v -> Constraint v a bin v = Constraint LessEq (var v) (constant 1) -(==) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a -l == r = Constraint Equal l r +(~==~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l ~==~ r = Constraint Equal l r -infix 4 == +infix 4 ~==~ -(<=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a -l <= r = Constraint LessEq l r +(~<=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l ~<=~ r = Constraint LessEq l r -infix 4 <= +infix 4 ~<=~ -(>=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a -l >= r = Constraint LessEq (neg l) (neg r) +(~>=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l ~>=~ r = Constraint LessEq (neg l) (neg r) -infix 4 >= +infix 4 ~>=~ normalize :: (Eq a, Num a) => LSum v a -> LSum v a normalize = LSum . Map.filter (/= 0) . lsum diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs new file mode 100644 index 0000000000..e9bf7de859 --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -0,0 +1,108 @@ +module Language.Futhark.TypeChecker.Rank (rankAnalysis) where + +import Control.Monad.State +import Data.Map (Map) +import Data.Map qualified as M +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.BranchAndBound +import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) +import Futhark.Solve.LP qualified as LP +import Language.Futhark hiding (ScalarType) +import Language.Futhark.TypeChecker.Constraints + +type LSum = LP.LSum VName Double + +type Constraint = LP.Constraint VName Double + +type LinearProg = LP.LinearProg VName Double + +type ScalarType = ScalarTypeBase SComp NoUniqueness + +class Rank a where + rank :: a -> LSum + +instance Rank SComp where + rank SDim = constant 1 + rank (SVar v) = var v + +instance Rank (Shape SComp) where + rank = foldr (\d r -> rank d ~+~ r) (constant 0) . shapeDims + +instance Rank ScalarType where + rank Prim {} = constant 0 + rank (TypeVar _ (QualName [] v) []) = var v + rank (Arrow {}) = constant 0 + rank t = error $ prettyString t + +instance Rank Type where + rank (Scalar t) = rank t + rank (Array _ shape t) = rank shape ~+~ rank t + +data RankState = RankState + { rankBinVars :: Map VName VName, + rankCounter :: !Int, + rankConstraints :: [Constraint] + } + +newtype RankM a = RankM {runRankM :: State RankState a} + deriving (Functor, Applicative, Monad, MonadState RankState) + +incCounter :: RankM Int +incCounter = do + s <- get + put s {rankCounter = rankCounter s + 1} + pure $ rankCounter s + +binVar :: VName -> RankM (VName) +binVar sv = do + mbv <- (M.!? sv) <$> gets rankBinVars + case mbv of + Nothing -> do + bv <- VName ("b_" <> baseName sv) <$> incCounter + modify $ \s -> + s + { rankBinVars = M.insert sv bv $ rankBinVars s, + rankConstraints = rankConstraints s ++ [bin bv] + } + pure bv + Just bv -> pure bv + +addConstraints :: [Constraint] -> RankM () +addConstraints cs = + modify $ \s -> s {rankConstraints = rankConstraints s ++ cs} + +addConstraint :: Constraint -> RankM () +addConstraint = addConstraints . pure + +addCt :: Ct -> RankM () +addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 +addCt (CtAM r m) = do + b_r <- binVar r + b_m <- binVar m + addConstraints $ oneIsZero (b_r, r) (b_m, m) + +mkLinearProg :: Int -> [Ct] -> LinearProg +mkLinearProg counter cs = + LP.LinearProg + { optType = Minimize, + objective = + let shape_vars = M.keys $ rankBinVars finalState + in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, + constraints = rankConstraints finalState + } + where + initState = + RankState + { rankBinVars = mempty, + rankCounter = counter, + rankConstraints = mempty + } + finalState = flip execState initState $ runRankM $ mapM_ addCt cs + +rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) +rankAnalysis counter cs = do + (_size, ranks) <- branchAndBound lp + pure $ (ranks V.!) <$> inv_var_map + where + (lp, var_map) = linearProgToLP $ mkLinearProg counter cs + inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e6e9b26776..a9f6444980 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -806,7 +806,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do (rt, am) <- checkApply loc (fname, i) (toType f_t) arg' pure ( (i + 1, rt), - (Info (Nothing, mempty), arg') + (Info (Nothing, am), arg') ) -- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do From 9318e211843774be85594f22097ca8b3ac75b0d5 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 22:32:41 -0800 Subject: [PATCH 019/258] Starting to integrate the rank solver. --- src/Futhark/Solve/LP.hs | 55 ++++++++++++---------- src/Language/Futhark/TypeChecker/Rank.hs | 32 ++++++++++++- src/Language/Futhark/TypeChecker/Terms2.hs | 33 +++++++++++++ 3 files changed, 93 insertions(+), 27 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 3b46af1965..d5c0ee6c5e 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -36,7 +36,9 @@ import Data.Vector.Unboxed qualified as V import Debug.Trace import Futhark.Solve.Matrix (Matrix (..)) import Futhark.Solve.Matrix qualified as M -import Prelude hiding (or, (<=), (==), (>=)) +import Futhark.Util.Pretty +import Language.Futhark.Pretty +import Prelude hiding (or) import Prelude qualified -- | A linear program. 'LP c a d' represents the program @@ -86,17 +88,16 @@ convert (LP c a d) = LPE c' a' d -- | Linear sum of variables. newtype LSum v a = LSum {lsum :: (Map (Maybe v) a)} - deriving (Eq) + deriving (Show, Eq) -instance (Show v, Show a) => Show (LSum v a) where - show (LSum m) = - L.intercalate - " + " +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where + pretty (LSum m) = + concatWith (surround " + ") $ map ( \(k, a) -> case k of - Nothing -> show a - Just k' -> show a <> "*" <> show k' + Nothing -> pretty a + Just k' -> (if a == 1 then mempty else pretty a <> "·") <> prettyName k' ) $ Map.toList m @@ -105,40 +106,44 @@ instance Functor (LSum v) where -- | Type of constraint data CType = Equal | LessEq - deriving (Eq) + deriving (Show, Eq) -instance Show CType where - show (Equal) = "=" - show (LessEq) = "<=" +instance Pretty CType where + pretty Equal = "=" + pretty LessEq = "<=" -- | A constraint for a linear program. data Constraint v a = Constraint CType (LSum v a) (LSum v a) - deriving (Eq) + deriving (Show, Eq) -instance (Show a, Show v) => Show (Constraint v a) where - show (Constraint t l r) = - show l <> " " <> show t <> " " <> show r +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (Constraint v a) where + pretty (Constraint t l r) = + pretty l <+> pretty t <+> pretty r data OptType = Maximize | Minimize deriving (Show, Eq) +instance Pretty OptType where + pretty Maximize = "maximize" + pretty Minimize = "minimize" + -- | A linear program. data LinearProg v a = LinearProg { optType :: OptType, objective :: LSum v a, constraints :: [Constraint v a] } - deriving (Eq) - -instance (Show v, Show a) => Show (LinearProg v a) where - show (LinearProg opt obj cs) = - unlines $ - [ show opt, - show obj, - "subject to:" + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where + pretty (LinearProg opt obj cs) = + vcat $ + [ pretty opt, + indent 2 $ pretty obj, + "subject to", + indent 2 $ vcat $ map pretty cs ] - ++ map show cs bigM :: (Num a) => a bigM = 10 ^ 3 diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index e9bf7de859..5ef9f72594 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -3,10 +3,13 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where import Control.Monad.State import Data.Map (Map) import Data.Map qualified as M +import Data.Maybe import Data.Vector.Unboxed qualified as V +import Debug.Trace import Futhark.Solve.BranchAndBound import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP +import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints @@ -38,6 +41,20 @@ instance Rank Type where rank (Scalar t) = rank t rank (Array _ shape t) = rank shape ~+~ rank t +class Distribute a where + distribute :: a -> a + +instance Distribute Type where + distribute = distributeOne + where + distributeOne (Array _ s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s $ tr) + distributeOne t = t + +instance Distribute Ct where + distribute (CtEq t1 t2) = distribute t1 `CtEq` distribute t2 + distribute c = c + data RankState = RankState { rankBinVars :: Map VName VName, rankCounter :: !Int, @@ -101,8 +118,19 @@ mkLinearProg counter cs = rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) rankAnalysis counter cs = do + traceM $ unlines $ concat $ map (\c -> [prettyString c, show c]) cs' + traceM $ prettyString prog (_size, ranks) <- branchAndBound lp - pure $ (ranks V.!) <$> inv_var_map + pure $ (fromJust . (ranks V.!?)) <$> inv_var_map where - (lp, var_map) = linearProgToLP $ mkLinearProg counter cs + splitFuncs + ( CtEq + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) + ) = + splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq (toType t1r) (toType t2r)) + splitFuncs c = [c] + cs' = foldMap (splitFuncs . distribute) cs + prog = mkLinearProg counter cs' + (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index a9f6444980..3f233a05c6 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -59,6 +59,7 @@ import Language.Futhark import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM +import Language.Futhark.TypeChecker.Rank import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify (Level) import Prelude hiding (mod) @@ -1070,3 +1071,35 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') + +-- checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do +-- bindParams tparams params $ \params' -> do +-- body' <- checkExp body +-- +-- retdecl' <- checkRetDecl body' retdecl +-- +-- cts <- gets termConstraints +-- +-- counter <- gets termCounter +-- +-- traceM $ unlines $ map prettyString cts +-- +-- case rankAnalysis counter cts of +-- Nothing -> error "" +-- Just rank_map -> do +-- tyvars <- gets termTyVars +-- +-- let solution = solve cts tyvars +-- +-- traceM $ +-- unlines +-- [ "# function " <> prettyNameString fname, +-- "## constraints:", +-- unlines $ map prettyString cts, +-- "## tyvars:", +-- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, +-- "## solution:", +-- let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t +-- in either T.unpack (unlines . map p . M.toList) solution +-- ] +-- pure (solution, params', retdecl', body') From e109d19e9cf03b40fb19dbe5f482b47040b1aac0 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 22:41:27 -0800 Subject: [PATCH 020/258] Add LP/ILP unit tests. --- futhark.cabal | 3 + .../Futhark/Solve/BranchAndBoundTests.hs | 120 +++++++++++ unittests/Futhark/Solve/SimplexTests.hs | 189 ++++++++++++++++++ unittests/futhark_tests.hs | 6 +- 4 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 unittests/Futhark/Solve/BranchAndBoundTests.hs create mode 100644 unittests/Futhark/Solve/SimplexTests.hs diff --git a/futhark.cabal b/futhark.cabal index 80006d4cc9..0b66baf4cc 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -523,6 +523,8 @@ test-suite unit Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests Futhark.Pkg.SolveTests Futhark.ProfileTests + Futhark.Solve.BranchAndBoundTests + Futhark.Solve.SimplexTests Language.Futhark.CoreTests Language.Futhark.PrimitiveTests Language.Futhark.SyntaxTests @@ -540,3 +542,4 @@ test-suite unit , tasty-hunit , tasty-quickcheck , text + , vector >=0.12 diff --git a/unittests/Futhark/Solve/BranchAndBoundTests.hs b/unittests/Futhark/Solve/BranchAndBoundTests.hs new file mode 100644 index 0000000000..10867a1bee --- /dev/null +++ b/unittests/Futhark/Solve/BranchAndBoundTests.hs @@ -0,0 +1,120 @@ +module Futhark.Solve.BranchAndBoundTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.BranchAndBound +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) +import Prelude qualified + +tests :: TestTree +tests = + testGroup + "BranchAndBoundTests" + [ -- testCase "1" $ + -- let lpe = + -- LPE + -- { pc = V.fromList [1, 1, 0, 0, 0], + -- pA = + -- M.fromLists + -- [ [-1, 1, 1, 0, 0], + -- [1, 0, 0, 1, 0], + -- [0, 1, 0, 0, 1] + -- ], + -- pd = V.fromList [1, 3, 2] + -- } + -- in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in branchAndBound lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in branchAndBound lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ branchAndBound lp) $ + case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (11.8 :: Double), + and $ zipWith (==) (V.toList sol) [1, 3] + ], + testCase "5" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> oneIsZero ("b1", "x1") ("b2", "x2") + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in assertBool + (unlines [show $ branchAndBound lp]) + $ case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (10 :: Double) + ] + -- testCase "6" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ] + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/Futhark/Solve/SimplexTests.hs b/unittests/Futhark/Solve/SimplexTests.hs new file mode 100644 index 0000000000..80eee3237e --- /dev/null +++ b/unittests/Futhark/Solve/SimplexTests.hs @@ -0,0 +1,189 @@ +module Futhark.Solve.SimplexTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Futhark.Solve.Simplex +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) +import Prelude qualified + +tests :: TestTree +tests = + testGroup + "SimplexTests" + [ testCase "1" $ + let lpe = + LPE + { pc = V.fromList [1, 1, 0, 0, 0], + pA = + M.fromLists + [ [-1, 1, 1, 0, 0], + [1, 0, 0, 1, 0], + [0, 1, 0, 0, 1] + ], + pd = V.fromList [1, 3, 2] + } + in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in simplexLP lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in simplexLP lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (14.08 :: Double), + and $ zipWith approxEq (V.toList sol) [1.3, 3.3] + ], + testCase "5" $ + let lp = + LP + { lpc = V.fromList [0], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [0, 0] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double), + and $ zipWith approxEq (V.toList sol) [0] + ], + testCase "6" $ + let lp = + LP + { lpc = V.fromList [1], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [5, 5] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (5 :: Double), + and $ zipWith approxEq (V.toList sol) [5] + ], + testCase "7" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1", + constraints = + [ var "x1" ~<=~ 10 ~*~ var "b1", + var "b1" ~+~ var "b2" ~<=~ constant 1 + ] + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in trace + (unlines [show prog, show lp, show idxmap, show lpe]) + ( assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (10 :: Double), + and $ zipWith (==) (V.toList sol) [1, 0, 10] + ] + ), + testCase "8" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> oneIsZero ("b1", "x1") ("b2", "x2") + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in trace + (unlines [show prog, show lp, show idxmap, show lpe]) + ( assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (15 :: Double) + ] + ), + testCase "9" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in trace + (unlines [show prog, show lp, show idxmap, show lpe]) + ( assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (15 :: Double) + ] + ) + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/futhark_tests.hs b/unittests/futhark_tests.hs index 32e22272cf..79986794b7 100644 --- a/unittests/futhark_tests.hs +++ b/unittests/futhark_tests.hs @@ -10,6 +10,8 @@ import Futhark.IR.Syntax.CoreTests qualified import Futhark.Internalise.TypesValuesTests qualified import Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests qualified import Futhark.Pkg.SolveTests qualified +import Futhark.Solve.BranchAndBoundTests qualified +import Futhark.Solve.SimplexTests qualified import Language.Futhark.PrimitiveTests qualified import Language.Futhark.SyntaxTests qualified import Language.Futhark.TypeCheckerTests qualified @@ -31,7 +33,9 @@ allTests = Language.Futhark.PrimitiveTests.tests, Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests.tests, Futhark.Analysis.AlgSimplifyTests.tests, - Language.Futhark.TypeCheckerTests.tests + Language.Futhark.TypeCheckerTests.tests, + Futhark.Solve.SimplexTests.tests, + Futhark.Solve.BranchAndBoundTests.tests ] main :: IO () From fd2cc94f34f4998263bd72dd93e95da5e39c9a9d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 11:28:02 -0800 Subject: [PATCH 021/258] Add conversion to PuLP for easier debugging. --- src/Futhark/Solve/LP.hs | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index d5c0ee6c5e..af9265f458 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -24,9 +24,11 @@ module Futhark.Solve.LP (~<=~), (~>=~), rowEchelonLPE, + linearProgToPulp, ) where +import Data.Char (isAscii) import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as Map @@ -97,7 +99,7 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where ( \(k, a) -> case k of Nothing -> pretty a - Just k' -> (if a == 1 then mempty else pretty a <> "·") <> prettyName k' + Just k' -> (if a == 1 then mempty else pretty a <> "*") <> prettyName k' ) $ Map.toList m @@ -109,7 +111,7 @@ data CType = Equal | LessEq deriving (Show, Eq) instance Pretty CType where - pretty Equal = "=" + pretty Equal = "==" pretty LessEq = "<=" -- | A constraint for a linear program. @@ -145,6 +147,38 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where indent 2 $ vcat $ map pretty cs ] +-- For debugging +linearProgToPulp :: (Unbox a, IsName v, Ord v, Pretty a, Eq a, Num a) => LinearProg v a -> String +linearProgToPulp prog = + map rm_subscript $ + unlines + [ "from pulp import *", + "prob = LpProblem('', " <> lptype <> ")", + unlines vars, + unlines $ map (("prob += " <>) . prettyString) $ constraints prog, + "status = prob.solve()", + "print(f'status: {status}')", + unlines res + ] + where + lptype = + case optType prog of + Maximize -> "LpMaximize" + Minimize -> "LpMinimize" + prog_vars = Map.elems $ snd $ linearProgToLP prog + vars = + map + ( \v -> + show (prettyName v) + <> " = " + <> "LpVariable(" + <> show (show (prettyName v)) + <> ", lowBound = 0, cat = 'Integer')" + ) + prog_vars + res = map (\v -> "print(f'" <> show (prettyName v) <> ": {value(" <> show (prettyName v) <> ")}')") prog_vars + rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" + bigM :: (Num a) => a bigM = 10 ^ 3 From ffcf337129cf71292d0a0889bcb8ad434c93269f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 21:30:48 +0100 Subject: [PATCH 022/258] Dummy handler for CtAM. --- src/Language/Futhark/TypeChecker/Constraints.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b6b5507283..28567a9a5b 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -171,9 +171,10 @@ unify t1 t2 unify _ _ = Nothing solveCt :: Ct -> SolveM () -solveCt ct = do - let CtEq t1 t2 = ct - solveCt' (t1, t2) +solveCt ct = + case ct of + CtEq t1 t2 -> solveCt' (t1, t2) + CtAM _ _ -> pure () -- Good vibes only. where bad = throwError $ "Unsolvable: " <> prettyText ct solveCt' (t1, t2) = do From 93b11e8acb3fc7fe8b14029f8f1361025b609ecf Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 21:30:55 +0100 Subject: [PATCH 023/258] Style fixes. --- src/Language/Futhark/TypeChecker/Constraints.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 28567a9a5b..7f43aaf323 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -34,7 +34,7 @@ data SComp deriving (Eq, Ord, Show) instance Pretty SComp where - pretty (SDim) = "[]" + pretty SDim = "[]" pretty (SVar x) = pretty x instance Pretty (Shape SComp) where @@ -158,7 +158,7 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = Just $ M.elems $ M.intersectionWith (,) fs1 fs2 unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) - | M.keys cs1 == M.keys cs2 = do + | M.keys cs1 == M.keys cs2 = fmap concat . forM (M.elems $ M.intersectionWith (,) cs1 cs2) $ \(ts1, ts2) -> do From de224f7b9e9155fee2b93b34dbec0f15ac04c222 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 21:33:22 +0100 Subject: [PATCH 024/258] Better prettyprinting. --- src/Language/Futhark/TypeChecker/Constraints.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 7f43aaf323..8ef97e2287 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -35,10 +35,10 @@ data SComp instance Pretty SComp where pretty SDim = "[]" - pretty (SVar x) = pretty x + pretty (SVar x) = brackets $ pretty x instance Pretty (Shape SComp) where - pretty = mconcat . map (brackets . pretty) . shapeDims + pretty = mconcat . map pretty . shapeDims -- | The type representation used by the constraint solver. Agnostic -- to sizes. From cbc4356ce144fd6e4cabc5ffc3c3698b6d41c120 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 22:21:58 +0100 Subject: [PATCH 025/258] Better prettyprinting of arrays of functions. --- src/Language/Futhark/Pretty.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index d81553f5b8..81ca4a152f 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -158,7 +158,7 @@ instance (Pretty (Shape dim), Pretty u) => Pretty (ScalarTypeBase dim u) where prettyType :: (Pretty (Shape dim), Pretty u) => Int -> TypeBase dim u -> Doc a prettyType _ (Array u shape at) = - pretty u <> pretty shape <> align (prettyScalarType 1 at) + pretty u <> pretty shape <> align (prettyScalarType 2 at) prettyType p (Scalar t) = prettyScalarType p t From 2a65a7660e4baa500ee04489cc71d66e78894ba3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 22:59:36 +0100 Subject: [PATCH 026/258] Make toType require Size sizes. This makes it harder to accidentally throw away shape variables. --- .../Futhark/TypeChecker/Constraints.hs | 11 +- src/Language/Futhark/TypeChecker/Rank.hs | 5 +- src/Language/Futhark/TypeChecker/Terms2.hs | 148 +++++++++--------- 3 files changed, 89 insertions(+), 75 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 8ef97e2287..a81812def5 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -44,8 +44,10 @@ instance Pretty (Shape SComp) where -- to sizes. type Type = TypeBase SComp NoUniqueness -toType :: TypeBase d u -> Type -toType = bimap (const SDim) (const NoUniqueness) +-- | Careful when using this on something that already has an SComp +-- size: it will throw away information by converting them to SDim. +toType :: TypeBase Size u -> TypeBase SComp u +toType = first (const SDim) data Ct = CtEq Type Type @@ -153,7 +155,10 @@ unify :: Type -> Type -> Maybe [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Just [] unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Just [(t1a, t2a), (toType t1r, toType t2r)] + Just [(t1a, t2a), (t1r', t2r')] + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = Just $ M.elems $ M.intersectionWith (,) fs1 fs2 diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 5ef9f72594..ee9eaafdcc 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -128,7 +128,10 @@ rankAnalysis counter cs = do (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) ) = - splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq (toType t1r) (toType t2r)) + splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq t1r' t2r') + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] cs' = foldMap (splitFuncs . distribute) cs prog = mkLinearProg counter cs' diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 3f233a05c6..c79c67f5e2 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -224,25 +224,26 @@ newTyVar loc desc = newTyVarWith loc desc TyVarFree newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) newType loc desc = tyVarType <$> newTyVar loc desc -newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> TypeBase dim u -> TermM (TypeBase dim u) +newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> Type -> TermM (TypeBase d u) newTypeWithField loc desc k t = - tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k $ toType t) + tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) -newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [Type] -> TermM (TypeBase dim u) +newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) newTypeWithConstr loc desc k ts = - tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts) + tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') + where + ts' = map (`setUniqueness` NoUniqueness) ts -newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase dim u) +newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d u) newTypeOverloaded loc name pts = tyVarType <$> newTyVarWith loc name (TyVarPrim pts) newSVar :: (Located loc) => loc -> Name -> TermM SVar newSVar _loc desc = do i <- incCounter - v <- newID $ mkTypeVarName desc i - pure v + newID $ mkTypeVarName desc i -asStructType :: (Monoid u) => SrcLoc -> TypeBase d u -> TermM (TypeBase Size u) +asStructType :: (Monoid u) => SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do @@ -251,20 +252,20 @@ asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' asStructType loc t = do t' <- newType loc "artificial" - ctEq t' t + ctEq (toType t' `setUniqueness` NoUniqueness) (t `setUniqueness` NoUniqueness) pure t' addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} -ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () +ctEq :: TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () ctEq t1 t2 = -- As a minor optimisation, do not add constraint if the types are -- equal. unless (t1' == t2') $ addCt $ CtEq t1' t2' where - t1' = toType t1 - t2' = toType t2 + t1' = t1 `setUniqueness` NoUniqueness + t2' = t2 `setUniqueness` NoUniqueness ctAM :: SVar -> SVar -> TermM () ctAM r m = addCt $ CtAM r m @@ -377,24 +378,23 @@ lookupMod qn@(QualName _ name) = do lookupVar :: SrcLoc -> QualName VName -> TermM StructType lookupVar loc qn@(QualName qs name) = do scope <- lookupQualNameEnv qn - asStructType loc - =<< case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams t) -> do - if null tparams && null qs - then pure t - else do - (tnames, t') <- instTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newType loc "t" - pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeOverloaded loc "t" ts - let (pts', rt') = instOverloaded (argtype :: StructType) pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams t) -> do + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newType loc "t" + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeOverloaded loc "t" ts + let (pts', rt') = instOverloaded argtype pts rt + pure $ foldFunType pts' $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -452,9 +452,10 @@ checkPat' p@(TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") - ctEq (Scalar (tupleRecord ps_t)) t - checkPat' p $ Ascribed $ toParam Observe $ Scalar $ tupleRecord ps_t + ps_t :: [Type] <- replicateM (length ps) (newType loc "t") + ctEq (Scalar (tupleRecord ps_t)) (toType t) + st <- asStructType loc $ Scalar $ tupleRecord ps_t + checkPat' p $ Ascribed $ toParam Observe st checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) (Ascribed t) @@ -462,9 +463,10 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) L.sort (map fst p_fs) == L.sort (M.keys t_fs) = RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do - p_fs' <- traverse (const $ newType loc "t") $ M.fromList p_fs - ctEq (Scalar (Record p_fs') :: ParamType) t - checkPat' p $ Ascribed $ toParam Observe $ Scalar (Record p_fs') + p_fs' :: M.Map Name Type <- traverse (const $ newType loc "t") $ M.fromList p_fs + ctEq (Scalar (Record p_fs')) $ toType t + st <- asStructType loc $ Scalar (Record p_fs') + checkPat' p $ Ascribed $ toParam Observe st where check t_fs = traverse (uncurry checkPat') $ @@ -478,7 +480,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do case maybe_outer_t of Ascribed outer_t -> do - ctEq st outer_t + ctEq (toType st) (toType outer_t) PatAscription <$> checkPat' p (Ascribed (resToParam st)) <*> pure t' @@ -490,7 +492,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do t' <- patLitMkType l loc - addCt $ CtEq (toType t') (toType t) + ctEq (toType t') (toType t) pure $ PatLit l (Info t') loc checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc @@ -513,12 +515,14 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do p_t <- newType (srclocOf p) "t" checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' - ctEq t' t - pure $ PatConstr n (Info t') ps' loc + ctEq t' (toType t) + t'' <- asStructType loc t' + pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps t <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' - pure $ PatConstr n (Info t) ps' loc + t' <- asStructType loc t + pure $ PatConstr n (Info $ toParam Observe t') ps' loc checkPat :: PatBase NoInfo VName (TypeBase Size u) -> @@ -603,11 +607,11 @@ checkApply loc _ ftype arg = do where toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = - pure (a, toType b) + pure (a, b `setUniqueness` NoUniqueness) split ftype' = do a <- newType loc "arg" b <- newTyVar loc "res" - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b pure (a, tyVarType b) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] @@ -635,7 +639,7 @@ mustHaveFields loc t [f] ve_t = do ctEq t rt mustHaveFields loc t (f : fs) ve_t = do ft :: Type <- newType loc "ft" - rt <- newTypeWithField loc "rt" f ft + rt :: Type <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t ctEq t rt @@ -660,7 +664,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq c_t cs_t + ctEq (toType c_t) (toType cs_t) pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -698,7 +702,7 @@ checkRetDecl :: checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te - ctEq (typeOf body) st + ctEq (expType body) (toType st) pure $ Just te' checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) @@ -749,7 +753,7 @@ checkExp (ArrayLit es _ loc) = do et <- newType loc "et" es' <- forM es $ \e -> do e' <- checkExp e - ctEq (typeOf e') et + ctEq (expType e') (toType et) pure e' let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et pure $ ArrayLit es' (Info arr_t) loc @@ -804,7 +808,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do onArg (i, f_t) (_, arg) = do arg' <- checkExp arg - (rt, am) <- checkApply loc (fname, i) (toType f_t) arg' + (rt, am) <- checkApply loc (fname, i) f_t arg' pure ( (i + 1, rt), (Info (Nothing, am), arg') @@ -831,7 +835,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do let t1 = typeOf e' t2 <- newType loc "t2" rt <- newType loc "rt" - ctEq optype $ foldFunType [toParam Observe t1, t2] $ RetType [] rt + ctEq (toType optype) $ toType $ foldFunType [toParam Observe t1, t2] $ RetType [] $ rt `setUniqueness` Nonunique pure $ OpSectionLeft op @@ -849,7 +853,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do t1 <- newType loc "t" let t2 = typeOf e' rt <- newType loc "rt" - ctEq optype $ foldFunType [t1, toParam Observe t2] $ RetType [] rt + ctEq (toType optype) $ toType $ foldFunType [t1, toParam Observe t2] $ RetType [] $ rt `setUniqueness` Nonunique pure $ OpSectionRight op @@ -866,7 +870,7 @@ checkExp (ProjectSection fields NoInfo loc) = do a <- newType loc "a" b <- newType loc "b" mustHaveFields loc (toType a) fields (toType b) - let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b + let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc -- checkExp (Lambda params body retdecl NoInfo loc) = do @@ -916,7 +920,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do maybe_step' <- traverse checkExp' maybe_step end' <- traverse checkExp' end range_t <- newType loc "range" - ctEq range_t $ arrayOfRank 1 (toType (typeOf start')) + ctEq (toType range_t) (arrayOfRank 1 (expType start')) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] where checkExp' = require "use in range expression" anyIntType <=< checkExp @@ -924,9 +928,10 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do checkExp (Project k e NoInfo loc) = do e' <- checkExp e kt <- newType loc "kt" - t <- newTypeWithField loc "t" k kt - ctEq (typeOf e') t - pure $ Project k e' (Info kt) loc + t :: Type <- newTypeWithField loc "t" k kt + ctEq (expType e') t + kt' <- asStructType loc kt + pure $ Project k e' (Info kt') loc -- checkExp (RecordUpdate src fields ve NoInfo loc) = do src' <- checkExp src @@ -938,11 +943,12 @@ checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice index_arg_t <- newType loc "index" index_elem_t <- newType loc "index_elem" - index_res_t <- newType loc "index_res" + index_res_t :: Type <- newType loc "index_res" let num_slices = length $ filter isSlice slice - ctEq index_arg_t $ arrayOfRank num_slices index_elem_t + ctEq (toType index_arg_t) $ arrayOfRank num_slices index_elem_t ctEq index_res_t $ arrayOfRank (length slice) index_elem_t - let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ RetType [] index_res_t + index_res_t' <- asStructType loc index_res_t + let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' pure $ IndexSection slice' (Info ft) loc -- checkExp (AppExp (Index e slice loc) _) = do @@ -951,8 +957,8 @@ checkExp (AppExp (Index e slice loc) _) = do index_t <- newType loc "index" index_elem_t <- newType loc "index_elem" let num_slices = length $ filter isSlice slice - ctEq index_t $ arrayOfRank num_slices index_elem_t - ctEq (typeOf e') $ arrayOfRank (length slice) index_elem_t + ctEq (toType index_t) $ arrayOfRank num_slices index_elem_t + ctEq (expType e') $ arrayOfRank (length slice) index_elem_t pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) -- checkExp (Update src slice ve loc) = do @@ -961,8 +967,8 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq (typeOf src') $ arrayOfRank (length slice) update_elem_t - ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t + ctEq (expType src') $ arrayOfRank (length slice) update_elem_t + ctEq (expType ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do @@ -973,8 +979,8 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq src_t $ arrayOfRank (length slice) update_elem_t - ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t + ctEq (toType src_t) $ arrayOfRank (length slice) update_elem_t + ctEq (expType ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) @@ -984,8 +990,8 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do e2' <- checkExp e2 e3' <- checkExp e3 - ctEq (typeOf e1') (Scalar (Prim Bool) :: Type) - ctEq (typeOf e2') (typeOf e3') + ctEq (expType e1') (Scalar (Prim Bool)) + ctEq (expType e2') (expType e3') pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) -- @@ -1004,17 +1010,17 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do let i' = Ident i (Info (typeOf bound')) iloc bind [i'] $ do body' <- checkExp body - ctEq (typeOf arg') (typeOf body') + ctEq (expType arg') (expType body') pure (For i' bound', body') While cond -> do cond' <- checkExp cond body' <- checkExp body - ctEq (typeOf arg') (typeOf body') + ctEq (expType arg') (expType body') pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" - ctEq (typeOf arr') $ arrayOfRank 1 (toType elem_t) + ctEq (expType arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') @@ -1026,12 +1032,12 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do checkExp (Ascript e te loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te - ctEq (typeOf e') st + ctEq (expType e') (toType st) pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te - ctEq (typeOf e') st + ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc checkValDef :: From d2e28288c2f23df68cf515125c3983b68cbe3794 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 19:35:25 -0800 Subject: [PATCH 027/258] Fall back to non-integral row echelon transformation for now. --- src/Futhark/Solve/Matrix.hs | 57 +++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs index 90e1a3e126..ae3bdf6b7c 100644 --- a/src/Futhark/Solve/Matrix.hs +++ b/src/Futhark/Solve/Matrix.hs @@ -281,33 +281,8 @@ update_ m upds = (nrows m) (ncols m) --- TODO: maintain integrality of entries in the matrix --- rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a --- rowEchelon = rowEchelon' 0 0 --- where --- rowEchelon' h k m@(Matrix _ nr nc) --- | h < nr && k < nc = --- if m ! (pivot_row, k) == 0 --- then rowEchelon' h (k + 1) m --- else rowEchelon' (h + 1) (k + 1) clear_rows_below --- | otherwise = m --- where --- pivot_row = --- fst $ --- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ --- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] --- m' = swapRows h pivot_row m --- clear_rows_below = --- update m' $ --- V.fromList $ --- [((i, k), 0) | i <- [h + 1 .. nr - 1]] --- ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) --- | i <- [h + 1 .. nr - 1], --- let f = m' ! (i, k) / m' ! (h, k), --- j <- [k + 1 .. nc - 1] --- ] - -rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +-- This version doesn't maintain integrality of the entries. +rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a rowEchelon = rowEchelon' 0 0 where rowEchelon' h k m@(Matrix _ nr nc) @@ -326,11 +301,37 @@ rowEchelon = rowEchelon' 0 0 update m' $ V.fromList $ [((i, k), 0) | i <- [h + 1 .. nr - 1]] - ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) + ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) | i <- [h + 1 .. nr - 1], + let f = m' ! (i, k) / m' ! (h, k), j <- [k + 1 .. nc - 1] ] +-- TODO: fix. Something's wrong here, causes huge blow-up. +-- rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +-- rowEchelon = rowEchelon' 0 0 +-- where +-- rowEchelon' h k m@(Matrix _ nr nc) +-- | h < nr && k < nc = +-- if m ! (pivot_row, k) == 0 +-- then rowEchelon' h (k + 1) m +-- else rowEchelon' (h + 1) (k + 1) clear_rows_below +-- | otherwise = m +-- where +-- pivot_row = +-- fst $ +-- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ +-- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] +-- m' = swapRows h pivot_row m +-- clear_rows_below = +-- update m' $ +-- V.fromList $ +-- [((i, k), 0) | i <- [h + 1 .. nr - 1]] +-- ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) +-- | i <- [h + 1 .. nr - 1], +-- j <- [k + 1 .. nc - 1] +-- ] + filterRows :: (Unbox a) => (Vector a -> Bool) -> Matrix a -> Matrix a filterRows p = fromVectors . filter p . toList From 18665d86a68e28e8b8d24cdd1bf3286dbefcd1af Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 19:42:50 -0800 Subject: [PATCH 028/258] Some new tests. --- .../Futhark/Solve/BranchAndBoundTests.hs | 68 +++++++--- unittests/Futhark/Solve/SimplexTests.hs | 124 ++++++++++++------ 2 files changed, 131 insertions(+), 61 deletions(-) diff --git a/unittests/Futhark/Solve/BranchAndBoundTests.hs b/unittests/Futhark/Solve/BranchAndBoundTests.hs index 10867a1bee..ed7e04c715 100644 --- a/unittests/Futhark/Solve/BranchAndBoundTests.hs +++ b/unittests/Futhark/Solve/BranchAndBoundTests.hs @@ -92,28 +92,54 @@ tests = Just (z, sol) -> and [ z `approxEq` (10 :: Double) + ], + -- testCase "6" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ] + + testCase "10" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "R2" ~+~ var "M3", + constraints = + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 + ] + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in assertBool + (unlines [show $ branchAndBound lp]) + $ case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double) ] - -- testCase "6" $ - -- let prog = - -- LinearProg - -- { optType = Maximize, - -- objective = var "x1" ~+~ var "x2", - -- constraints = - -- [ var "x1" ~<=~ constant 10, - -- var "x2" ~<=~ constant 5 - -- ] - -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) - -- } - -- (lp, idxmap) = linearProgToLP prog - -- lpe = convert lp - -- in assertBool - -- (unlines [show $ branchAndBound lp]) - -- $ case branchAndBound lp of - -- Nothing -> False - -- Just (z, sol) -> - -- and - -- [ z `approxEq` (10 :: Double) - -- ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool diff --git a/unittests/Futhark/Solve/SimplexTests.hs b/unittests/Futhark/Solve/SimplexTests.hs index 80eee3237e..1a52203d12 100644 --- a/unittests/Futhark/Solve/SimplexTests.hs +++ b/unittests/Futhark/Solve/SimplexTests.hs @@ -123,18 +123,15 @@ tests = } (lp, idxmap) = linearProgToLP prog lpe = convert lp - in trace - (unlines [show prog, show lp, show idxmap, show lpe]) - ( assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (10 :: Double), - and $ zipWith (==) (V.toList sol) [1, 0, 10] - ] - ), + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (10 :: Double), + and $ zipWith (==) (V.toList sol) [1, 0, 10] + ], testCase "8" $ let prog = LinearProg @@ -148,41 +145,88 @@ tests = } (lp, idxmap) = linearProgToLP prog lpe = convert lp - in trace - (unlines [show prog, show lp, show idxmap, show lpe]) - ( assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (15 :: Double) - ] - ), - testCase "9" $ + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (15 :: Double) + ], + -- testCase "9" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in trace + -- (unlines [show prog, show lp, show idxmap, show lpe]) + -- ( assertBool + -- (unlines [show $ simplexLP lp]) + -- $ case simplexLP lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (15 :: Double) + -- ] + -- ), + testCase "10" $ let prog = LinearProg - { optType = Maximize, - objective = var "x1" ~+~ var "x2", + { optType = Minimize, + objective = var "R2" ~+~ var "M3", constraints = - [ var "x1" ~<=~ constant 10, - var "x2" ~<=~ constant 5 + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 ] - <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) } (lp, idxmap) = linearProgToLP prog lpe = convert lp - in trace - (unlines [show prog, show lp, show idxmap, show lpe]) - ( assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (15 :: Double) - ] - ) + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double) + ], + testCase "11" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "4R" ~+~ var "5M", + constraints = + [ var "6artifical" ~==~ constant 1 ~+~ var "2t", + constant 1 ~+~ var "3num" ~==~ constant 1 ~+~ var "2t", + var "0b_R" ~<=~ constant 1, + var "1b_M" ~<=~ constant 1, + var "4R" ~<=~ 1000 ~*~ var "0b_R", + var "5M" ~<=~ 1000 ~*~ var "1b_M", + var "0b_R" ~+~ var "1b_M" ~<=~ constant 1 + ] + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double) + ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool From 6380dd1c24db872b2fb50c06274446d01cfb0c7d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 20:14:05 -0800 Subject: [PATCH 029/258] Use frame-based AUTOMAP; removes need to distribute over function tys. --- src/Language/Futhark/Prop.hs | 6 +++++ src/Language/Futhark/TypeChecker/Rank.hs | 5 ++-- src/Language/Futhark/TypeChecker/Terms2.hs | 31 ++++++++++++---------- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 181f5d4135..65602c4689 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -32,6 +32,7 @@ module Language.Futhark.Prop funType, stripExp, similarExps, + frameOf, -- * Queries on patterns and params patIdents, @@ -1435,6 +1436,11 @@ similarExps (IndexSection slice1 _ _) (IndexSection slice2 _ _) = similarSlices slice1 slice2 similarExps _ _ = Nothing +frameOf :: Exp -> Shape Size +frameOf (AppExp (Apply _ args _) _) = + ((\(_, am) -> autoFrame am) . unInfo . fst) $ NE.last args +frameOf _ = mempty + -- | An identifier with type- and aliasing information. type Ident = IdentBase Info VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index ee9eaafdcc..8031982f21 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -118,8 +118,7 @@ mkLinearProg counter cs = rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) rankAnalysis counter cs = do - traceM $ unlines $ concat $ map (\c -> [prettyString c, show c]) cs' - traceM $ prettyString prog + traceM $ unlines ["rankAnalysis prog:", prettyString prog] (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map where @@ -133,7 +132,7 @@ rankAnalysis counter cs = do t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] - cs' = foldMap (splitFuncs . distribute) cs + cs' = foldMap splitFuncs cs prog = mkLinearProg counter cs' (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index c79c67f5e2..fda11bd604 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -590,21 +590,25 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap) -checkApply loc _ ftype arg = do +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) +checkApply loc _ ftype fframe arg = do (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] r) unit_info mempty + lhs = arrayOf (toShape (SVar r) <> (toSComp <$> frameOf arg)) $ toType $ typeOf arg + rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a ctAM r m - ctEq (arrayOf (toShape $ SVar r) $ toType $ typeOf arg) (arrayOf (toShape $ SVar m) a) + ctEq lhs rhs pure - ( arrayOf (toShape $ SVar m) b, - AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = mempty} + ( b, + AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where + toSComp (Var (QualName [] x) _ _) = SVar x + toSComp _ = error "" toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) @@ -797,7 +801,7 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - ((_, rt), args') <- mapAccumLM onArg (0, expType fe') args + ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args rt' <- asStructType loc rt pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] where @@ -806,11 +810,11 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do Var v _ _ -> Just v _ -> Nothing - onArg (i, f_t) (_, arg) = do + onArg (i, f_t, f_f) (_, arg) = do arg' <- checkExp arg - (rt, am) <- checkApply loc (fname, i) f_t arg' + (rt, am) <- checkApply loc (fname, i) f_t f_f arg' pure - ( (i + 1, rt), + ( (i + 1, rt, autoFrame am), (Info (Nothing, am), arg') ) -- @@ -819,8 +823,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e1' <- checkExp e1 e2' <- checkExp e2 - (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) e1' - (rt2, am2) <- checkApply loc (Just op, 1) rt1 e2' + (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) mempty e1' + (rt2, am2) <- checkApply loc (Just op, 1) rt1 mempty e2' rt2' <- asStructType loc rt2 pure $ @@ -831,7 +835,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - void $ checkApply loc (Just op, 0) (toType optype) e' + void $ checkApply loc (Just op, 0) (toType optype) mempty e' let t1 = typeOf e' t2 <- newType loc "t2" rt <- newType loc "rt" @@ -1088,11 +1092,10 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do -- -- counter <- gets termCounter -- --- traceM $ unlines $ map prettyString cts --- -- case rankAnalysis counter cts of -- Nothing -> error "" -- Just rank_map -> do +-- traceM $ prettyString $ M.toList rank_map -- tyvars <- gets termTyVars -- -- let solution = solve cts tyvars From 37236d11dfcf624f9af0adcb4865de36e80cf5b1 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 20:27:26 -0800 Subject: [PATCH 030/258] Support `TyVarInfo` info in rank analysis && use rank analysis in the checker. --- src/Language/Futhark/TypeChecker/Rank.hs | 24 ++++++-- src/Language/Futhark/TypeChecker/Terms2.hs | 67 +++++++--------------- 2 files changed, 40 insertions(+), 51 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 8031982f21..adc0d32f72 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -24,6 +24,9 @@ type ScalarType = ScalarTypeBase SComp NoUniqueness class Rank a where rank :: a -> LSum +instance Rank VName where + rank = var + instance Rank SComp where rank SDim = constant 1 rank (SVar v) = var v @@ -98,8 +101,14 @@ addCt (CtAM r m) = do b_m <- binVar m addConstraints $ oneIsZero (b_r, r) (b_m, m) -mkLinearProg :: Int -> [Ct] -> LinearProg -mkLinearProg counter cs = +addTyVarInfo :: TyVar -> TyVarInfo -> RankM () +addTyVarInfo tv (TyVarFree) = pure () +addTyVarInfo tv (TyVarPrim _) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo _ _ = error "Unhandled" + +mkLinearProg :: Int -> [Ct] -> TyVars -> LinearProg +mkLinearProg counter cs tyVars = LP.LinearProg { optType = Minimize, objective = @@ -114,10 +123,13 @@ mkLinearProg counter cs = rankCounter = counter, rankConstraints = mempty } - finalState = flip execState initState $ runRankM $ mapM_ addCt cs + buildLP = do + mapM_ addCt cs + mapM_ (uncurry addTyVarInfo) $ M.toList tyVars + finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) -rankAnalysis counter cs = do +rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe (Map VName Int) +rankAnalysis counter cs tyVars = do traceM $ unlines ["rankAnalysis prog:", prettyString prog] (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map @@ -133,6 +145,6 @@ rankAnalysis counter cs = do t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] cs' = foldMap splitFuncs cs - prog = mkLinearProg counter cs' + prog = mkLinearProg counter cs' tyVars (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index fda11bd604..d7484d7407 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1065,50 +1065,27 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do retdecl' <- checkRetDecl body' retdecl cts <- gets termConstraints + + counter <- gets termCounter + tyvars <- gets termTyVars - let solution = solve cts tyvars - - traceM $ - unlines - [ "# function " <> prettyNameString fname, - "## constraints:", - unlines $ map prettyString cts, - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, - "## solution:", - let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList) solution - ] - pure (solution, params', retdecl', body') - --- checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do --- bindParams tparams params $ \params' -> do --- body' <- checkExp body --- --- retdecl' <- checkRetDecl body' retdecl --- --- cts <- gets termConstraints --- --- counter <- gets termCounter --- --- case rankAnalysis counter cts of --- Nothing -> error "" --- Just rank_map -> do --- traceM $ prettyString $ M.toList rank_map --- tyvars <- gets termTyVars --- --- let solution = solve cts tyvars --- --- traceM $ --- unlines --- [ "# function " <> prettyNameString fname, --- "## constraints:", --- unlines $ map prettyString cts, --- "## tyvars:", --- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, --- "## solution:", --- let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t --- in either T.unpack (unlines . map p . M.toList) solution --- ] --- pure (solution, params', retdecl', body') + case rankAnalysis counter cts tyvars of + Nothing -> error "" + Just rank_map -> do + traceM $ prettyString $ M.toList rank_map + + let solution = solve cts tyvars + + traceM $ + unlines + [ "# function " <> prettyNameString fname, + "## constraints:", + unlines $ map prettyString cts, + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + "## solution:", + let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList) solution + ] + pure (solution, params', retdecl', body') From 8e8db497ddf13548d6f5ed6112564a6abb1506a4 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 22:43:45 -0800 Subject: [PATCH 031/258] Basic support for substituting in rank info. --- .../Futhark/TypeChecker/Constraints.hs | 4 ++ src/Language/Futhark/TypeChecker/Rank.hs | 64 ++++++++++++++++++- src/Language/Futhark/TypeChecker/Terms2.hs | 10 ++- 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index a81812def5..d5b13aa049 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -8,6 +8,7 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, + TyVarSol (..), Solution, solve, ) @@ -70,6 +71,8 @@ data TyVarInfo TyVarRecord (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum (M.Map Name [Type]) + | -- | Must have at least this rank. + TyVarRank Int deriving (Show) instance Pretty TyVarInfo where @@ -77,6 +80,7 @@ instance Pretty TyVarInfo where pretty (TyVarPrim pts) = "∈" <+> pretty pts pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs + pretty (TyVarRank x) = "rank ≥" <+> pretty x type TyVar = VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index adc0d32f72..9a4eca4010 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -6,6 +6,8 @@ import Data.Map qualified as M import Data.Maybe import Data.Vector.Unboxed qualified as V import Debug.Trace +-- import Futhark.FreshNames qualified as FreshNames +-- import Futhark.MonadFreshNames hiding (newName) import Futhark.Solve.BranchAndBound import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP @@ -13,6 +15,8 @@ import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints +-- import Language.Futhark.TypeChecker.Monad (mkTypeVarName) + type LSum = LP.LSum VName Double type Constraint = LP.Constraint VName Double @@ -128,12 +132,20 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe (Map VName Int) +rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars) rankAnalysis counter cs tyVars = do traceM $ unlines ["rankAnalysis prog:", prettyString prog] (_size, ranks) <- branchAndBound lp - pure $ (fromJust . (ranks V.!?)) <$> inv_var_map + let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map + let (cs', SubstState tyVars') = + flip runState (SubstState mempty) $ + runSubstM $ + substRanks rank_map $ + filter (not . isCtAM) cs + pure (cs', tyVars <> tyVars') where + isCtAM (CtAM {}) = True + isCtAM _ = False splitFuncs ( CtEq (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) @@ -148,3 +160,51 @@ rankAnalysis counter cs tyVars = do prog = mkLinearProg counter cs' tyVars (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] + +newtype SubstM a = SubstM {runSubstM :: State SubstState a} + deriving (Functor, Applicative, Monad, MonadState SubstState) + +data SubstState = SubstState + { substTyVars :: TyVars + } + +rankToShape :: Map VName Int -> VName -> Shape SComp +rankToShape rs x = Shape $ replicate (rs M.! x) SDim + +addRankInfo :: Map VName Int -> TyVar -> SubstM () +addRankInfo rs t = + modify $ \s -> s {substTyVars = M.insert t (TyVarRank $ rs M.! t) $ substTyVars s} + +class SubstRanks a where + substRanks :: Map VName Int -> a -> SubstM a + +instance (SubstRanks a) => SubstRanks [a] where + substRanks rs = mapM (substRanks rs) + +instance SubstRanks (Shape SComp) where + substRanks rs = pure . foldMap instDim + where + instDim (SDim) = Shape $ pure SDim + instDim (SVar x) = rankToShape rs x + +instance SubstRanks (TypeBase SComp u) where + substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) + | rs M.! x > 0 = do + addRankInfo rs x + pure t + substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = do + ta' <- substRanks rs ta + tr' <- substRanks rs tr + pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) + substRanks rs (Array u shape t) = do + shape' <- substRanks rs shape + t' <- substRanks rs (Scalar t) + pure $ Array u (shape' <> arrayShape t') (scalarType t') + where + scalarType (Array _ _ t) = t + scalarType (Scalar t) = t + substRanks _ t = pure t + +instance SubstRanks Ct where + substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 + substRanks _ _ = error "" diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d7484d7407..019e0b0d8b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1072,18 +1072,16 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do case rankAnalysis counter cts tyvars of Nothing -> error "" - Just rank_map -> do - traceM $ prettyString $ M.toList rank_map - - let solution = solve cts tyvars + Just (cts', tyvars') -> do + let solution = solve cts' tyvars' traceM $ unlines [ "# function " <> prettyNameString fname, "## constraints:", - unlines $ map prettyString cts, + unlines $ map prettyString cts', "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution From d4d19138c3db9cc06b9560018aa1425f5f9f42b7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 11:54:38 +0100 Subject: [PATCH 032/258] Correct construction of array types. --- src/Language/Futhark/Prop.hs | 4 +++- src/Language/Futhark/TypeChecker/Rank.hs | 5 +---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 65602c4689..cc5c40268a 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -316,7 +316,9 @@ arrayOfWithAliases :: arrayOfWithAliases u shape2 (Array _ shape1 et) = Array u (shape2 <> shape1) et arrayOfWithAliases u shape (Scalar t) = - Array u shape (second (const mempty) t) + if shapeRank shape == 0 + then Scalar t `setUniqueness` u + else Array u shape (second (const mempty) t) -- | @stripArray n t@ removes the @n@ outermost layers of the array. -- Essentially, it is the type of indexing an array of type @t@ with diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9a4eca4010..b282d704c3 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -199,10 +199,7 @@ instance SubstRanks (TypeBase SComp u) where substRanks rs (Array u shape t) = do shape' <- substRanks rs shape t' <- substRanks rs (Scalar t) - pure $ Array u (shape' <> arrayShape t') (scalarType t') - where - scalarType (Array _ _ t) = t - scalarType (Scalar t) = t + pure $ arrayOfWithAliases u shape' t' substRanks _ t = pure t instance SubstRanks Ct where From 320d95444677ff4d6e4b0c5aa06123ee672338b3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:09:36 +0100 Subject: [PATCH 033/258] Forget size variables here. --- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 019e0b0d8b..bcb97efcbf 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -387,7 +387,7 @@ lookupVar loc qn@(QualName qs name) = do else do (tnames, t') <- instTypeScheme qn loc tparams t outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' + asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do argtype <- newType loc "t" pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool From 46ba3f99e7216f9f7c07d5e5e270f46d88032a35 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:46:27 +0100 Subject: [PATCH 034/258] Basic tracking of levels. --- .../Futhark/TypeChecker/Constraints.hs | 69 +++++++++++-------- src/Language/Futhark/TypeChecker/Rank.hs | 10 +-- src/Language/Futhark/TypeChecker/Terms.hs | 8 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 8 ++- 4 files changed, 56 insertions(+), 39 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index d5b13aa049..d3d145c57e 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -18,6 +18,7 @@ import Control.Monad.Except import Control.Monad.State import Data.Bifunctor import Data.Map qualified as M +import Data.Maybe import Data.Text qualified as T import Futhark.IR.Pretty import Futhark.Util.Pretty @@ -84,30 +85,31 @@ instance Pretty TyVarInfo where type TyVar = VName --- | If a VName is not in this map, it is assumed to be rigid. -type TyVars = M.Map TyVar TyVarInfo +-- | If a VName is not in this map, it is assumed to be rigid. The +-- integer is the level. +type TyVars = M.Map TyVar (Int, TyVarInfo) data TyVarSol = -- | Has been substituted with this. - TyVarSol Type + TyVarSol Int Type | -- | Replaced by this other type variable. TyVarLink VName | -- | Not substituted yet; has this constraint. - TyVarUnsol TyVarInfo + TyVarUnsol Int TyVarInfo deriving (Show) newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} initialState :: TyVars -> SolverState -initialState tyvars = SolverState $ M.map TyVarUnsol tyvars +initialState tyvars = SolverState $ M.map (uncurry TyVarUnsol) tyvars substTyVars :: (Monoid u) => M.Map TyVar TyVarSol -> TypeBase SComp u -> TypeBase SComp u substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = case M.lookup v m of Just (TyVarLink v') -> substTyVars m $ Scalar $ TypeVar u (QualName qs v') args - Just (TyVarSol t') -> second (const mempty) t' - Just (TyVarUnsol _) -> t + Just (TyVarSol _ t') -> second (const mempty) t' + Just (TyVarUnsol {}) -> t Nothing -> t substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs @@ -118,10 +120,11 @@ substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt -- | A solution maps types to the set of type variables that must be --- substituted with this type. This slightly odd representation is --- needed to encode when two type variables are actually the same --- type. This matters when we start instanting the sizes of the type. -type Solution = M.Map Type [TyVar] +-- substituted with this type, as well as its binding level. This +-- slightly odd representation is needed to encode when two type +-- variables are actually the same type. This matters when we start +-- instanting the sizes of the type. +type Solution = M.Map Type (Int, [TyVar]) solution :: SolverState -> Solution solution s = @@ -132,23 +135,23 @@ solution s = M.toList $ solverTyVars s where - mkSubst (TyVarSol t) = Just (t, []) + mkSubst (TyVarSol lvl t) = Just (lvl, (t, [])) mkSubst _ = Nothing addLinks m (v1, TyVarLink v2) = case M.lookup v2 $ solverTyVars s of Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) _ -> case M.lookup v2 m of Nothing -> m - Just (t, vs) -> M.insert v2 (t, v1 : vs) m + Just (t, (lvl, vs)) -> M.insert v2 (t, (lvl, v1 : vs)) m addLinks m _ = m - adjust (v, (t, vs)) = (t, v : vs) + adjust (v, (lvl, (t, vs))) = (t, (lvl, v : vs)) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) -subTyVar :: VName -> Type -> SolveM () -subTyVar v t = - modify $ \s -> s {solverTyVars = M.insert v (TyVarSol t) $ solverTyVars s} +subTyVar :: VName -> Int -> Type -> SolveM () +subTyVar v lvl t = + modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} linkTyVar :: VName -> VName -> SolveM () linkTyVar v t = @@ -190,13 +193,13 @@ solveCt ct = tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of Just (TyVarLink v') -> flexible v' - Just (TyVarUnsol _) -> True - Just (TyVarSol _) -> False - Nothing -> False + Just (TyVarUnsol lvl _) -> Just lvl + Just (TyVarSol _ _) -> Nothing + Nothing -> Nothing sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of Just (TyVarLink v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (TyVarSol t') -> sub t' + Just (TyVarSol _ t') -> sub t' _ -> t sub t = t case (sub t1, sub t2) of @@ -206,14 +209,22 @@ solveCt ct = | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> bad - (True, False) -> subTyVar v1 t2' - (False, True) -> subTyVar v2 t1' - (True, True) -> linkTyVar v1 v2 - (Scalar (TypeVar _ (QualName [] v1) []), t2') -> - if flexible v1 then subTyVar v1 t2' else bad - (t1', Scalar (TypeVar _ (QualName [] v2) [])) -> - if flexible v2 then subTyVar v2 t1' else bad + (Nothing, Nothing) -> bad + (Just lvl, Nothing) -> subTyVar v1 lvl t2' + (Nothing, Just lvl) -> subTyVar v2 lvl t1' + (Just lvl1, Just lvl2) + | lvl1 <= lvl2 -> linkTyVar v1 v2 + | otherwise -> linkTyVar v2 v1 + (Scalar (TypeVar _ (QualName [] v1) []), t2') + | Just lvl <- flexible v1 -> + subTyVar v1 lvl t2' + | otherwise -> + bad + (t1', Scalar (TypeVar _ (QualName [] v2) [])) + | Just lvl <- flexible v2 -> + subTyVar v2 lvl t1' + | otherwise -> + bad (t1', t2') -> case unify t1' t2' of Nothing -> bad Just eqs -> mapM_ solveCt' eqs diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index b282d704c3..bcb67d2a13 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -105,9 +105,9 @@ addCt (CtAM r m) = do b_m <- binVar m addConstraints $ oneIsZero (b_r, r) (b_m, m) -addTyVarInfo :: TyVar -> TyVarInfo -> RankM () -addTyVarInfo tv (TyVarFree) = pure () -addTyVarInfo tv (TyVarPrim _) = +addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () +addTyVarInfo tv (_, TyVarFree) = pure () +addTyVarInfo tv (_, TyVarPrim _) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo _ _ = error "Unhandled" @@ -173,7 +173,9 @@ rankToShape rs x = Shape $ replicate (rs M.! x) SDim addRankInfo :: Map VName Int -> TyVar -> SubstM () addRankInfo rs t = - modify $ \s -> s {substTyVars = M.insert t (TyVarRank $ rs M.! t) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (lvl, TyVarRank $ rs M.! t) $ substTyVars s} + where + lvl = 0 -- FIXME class SubstRanks a where substRanks :: Map VName Int -> a -> SubstM a diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2a7bbbd6da..22331fe113 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1622,15 +1622,15 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t -addInitialConstraints :: M.Map (TypeBase () NoUniqueness) [VName] -> TermTypeM () +addInitialConstraints :: M.Map (TypeBase () NoUniqueness) (Int, [VName]) -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where - addConstraint v c = modifyConstraints $ M.insert v (0, c) + addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) usage = mkUsage (mempty :: Loc) - f (t, vs) = do + f (t, (lvl, vs)) = do (t', _) <- allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t forM_ vs $ \v -> - addConstraint v $ Constraint (RetType [] t') $ usage $ prettyNameText v + addConstraint v lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index bcb97efcbf..6c2194511a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -202,6 +202,9 @@ runTermM (TermM m) = do incLevel :: TermM a -> TermM a incLevel = local $ \env -> env {termLevel = termLevel env + 1} +curLevel :: TermM Int +curLevel = asks termLevel + incCounter :: TermM Int incCounter = do s <- get @@ -215,7 +218,8 @@ newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar newTyVarWith _loc desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i - modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} + lvl <- curLevel + modify $ \s -> s {termTyVars = M.insert v (lvl, info) $ termTyVars s} pure v newTyVar :: (Located loc) => loc -> Name -> TermM TyVar @@ -1083,7 +1087,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", - let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t + let p (t, (lvl, vs)) = unwords (show [lvl] : map prettyNameString vs) <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') From 957b46232b513f754078ae84fc46bc91e636303f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:51:22 +0100 Subject: [PATCH 035/258] Create sizes at right level. --- src/Language/Futhark/TypeChecker/Terms.hs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 22331fe113..96eb651bb3 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1622,13 +1622,18 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t +-- A hack we need to create size variables for types at the right +-- level. +atLevel :: Int -> TermTypeM a -> TermTypeM a +atLevel lvl = local $ \env -> env {termLevel = lvl} + addInitialConstraints :: M.Map (TypeBase () NoUniqueness) (Int, [VName]) -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) usage = mkUsage (mempty :: Loc) f (t, (lvl, vs)) = do - (t', _) <- allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t + (t', _) <- atLevel lvl $ allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t forM_ vs $ \v -> addConstraint v lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v From ceaa72fa3b570baa7d0dae6fff4e8b5578b57ab2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:55:46 +0100 Subject: [PATCH 036/258] Handle more cases. --- src/Language/Futhark/TypeChecker/Rank.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index bcb67d2a13..d107f6465b 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -40,9 +40,11 @@ instance Rank (Shape SComp) where instance Rank ScalarType where rank Prim {} = constant 0 - rank (TypeVar _ (QualName [] v) []) = var v + rank (TypeVar _ (QualName [] v) []) = var v -- FIXME - might not be a type variable. + rank (TypeVar {}) = constant 0 rank (Arrow {}) = constant 0 - rank t = error $ prettyString t + rank (Record {}) = constant 0 + rank (Sum {}) = constant 0 instance Rank Type where rank (Scalar t) = rank t From 3d8a1a14dc4348986e531d69e85dfb7ac22da80a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:57:34 +0100 Subject: [PATCH 037/258] Consistent printing. --- src/Language/Futhark/TypeChecker/Rank.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index d107f6465b..547c2f2cab 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -136,7 +136,7 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars) rankAnalysis counter cs tyVars = do - traceM $ unlines ["rankAnalysis prog:", prettyString prog] + traceM $ unlines ["## rankAnalysis prog", prettyString prog] (_size, ranks) <- branchAndBound lp let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map let (cs', SubstState tyVars') = diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6c2194511a..227642e1b5 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1074,6 +1074,8 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do tyvars <- gets termTyVars + traceM $ "# function " <> prettyNameString fname + case rankAnalysis counter cts tyvars of Nothing -> error "" Just (cts', tyvars') -> do @@ -1081,8 +1083,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do traceM $ unlines - [ "# function " <> prettyNameString fname, - "## constraints:", + [ "## constraints:", unlines $ map prettyString cts', "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', From 64ae342ea8946e19ff94d4720c42daff7407a35e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 13:19:03 +0100 Subject: [PATCH 038/258] Fix construction of solution. --- src/Language/Futhark/TypeChecker/Constraints.hs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index d3d145c57e..af55a6d243 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -17,6 +17,7 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor +import Data.List qualified as L import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T @@ -128,12 +129,11 @@ type Solution = M.Map Type (Int, [TyVar]) solution :: SolverState -> Solution solution s = - M.fromList $ - map adjust $ - M.toList $ - foldl addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ - M.toList $ - solverTyVars s + L.foldl' byType mempty $ + M.toList $ + L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ + M.toList $ + solverTyVars s where mkSubst (TyVarSol lvl t) = Just (lvl, (t, [])) mkSubst _ = Nothing @@ -144,7 +144,9 @@ solution s = Nothing -> m Just (t, (lvl, vs)) -> M.insert v2 (t, (lvl, v1 : vs)) m addLinks m _ = m - adjust (v, (lvl, (t, vs))) = (t, (lvl, v : vs)) + byType m (v, (lvl, (t, vs))) = M.insertWith comb t (lvl, v : vs) m + where + comb (lvl1, ts1) (lvl2, ts2) = (min lvl1 lvl2, ts1 <> ts2) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) From ca5c562b92a0162ec2111b17188f4344fbabb924 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 04:44:21 -0800 Subject: [PATCH 039/258] Spaghetti code to better add in rank info. --- .../Futhark/TypeChecker/Constraints.hs | 4 - src/Language/Futhark/TypeChecker/Rank.hs | 197 ++++++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 7 +- 3 files changed, 163 insertions(+), 45 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index af55a6d243..67131e67cc 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -8,7 +8,6 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, - TyVarSol (..), Solution, solve, ) @@ -73,8 +72,6 @@ data TyVarInfo TyVarRecord (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum (M.Map Name [Type]) - | -- | Must have at least this rank. - TyVarRank Int deriving (Show) instance Pretty TyVarInfo where @@ -82,7 +79,6 @@ instance Pretty TyVarInfo where pretty (TyVarPrim pts) = "∈" <+> pretty pts pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs - pretty (TyVarRank x) = "rank ≥" <+> pretty x type TyVar = VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 547c2f2cab..c87e10a402 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -1,21 +1,21 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where +import Control.Monad.Reader import Control.Monad.State import Data.Map (Map) import Data.Map qualified as M import Data.Maybe import Data.Vector.Unboxed qualified as V import Debug.Trace --- import Futhark.FreshNames qualified as FreshNames --- import Futhark.MonadFreshNames hiding (newName) +import Futhark.FreshNames qualified as FreshNames +import Futhark.MonadFreshNames hiding (newName) import Futhark.Solve.BranchAndBound import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints - --- import Language.Futhark.TypeChecker.Monad (mkTypeVarName) +import Language.Futhark.TypeChecker.Monad (mkTypeVarName) type LSum = LP.LSum VName Double @@ -134,17 +134,29 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars) -rankAnalysis counter cs tyVars = do +rankAnalysis :: VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) +rankAnalysis vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] (_size, ranks) <- branchAndBound lp let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map - let (cs', SubstState tyVars') = - flip runState (SubstState mempty) $ - runSubstM $ - substRanks rank_map $ - filter (not . isCtAM) cs - pure (cs', tyVars <> tyVars') + initEnv = + SubstEnv + { envTyVars = tyVars, + envRanks = rank_map + } + + initState = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNameSource = vns, + substCounter = counter + } + (cs', state') = + runSubstM initEnv initState $ + substRanks $ + filter (not . isCtAM) cs + pure (cs', substTyVars state' <> tyVars, substNameSource state', substCounter state') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -163,49 +175,156 @@ rankAnalysis counter cs tyVars = do (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] -newtype SubstM a = SubstM {runSubstM :: State SubstState a} - deriving (Functor, Applicative, Monad, MonadState SubstState) +newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) + deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) -data SubstState = SubstState - { substTyVars :: TyVars +runSubstM :: SubstEnv -> SubstState -> SubstM a -> (a, SubstState) +runSubstM initEnv initState (SubstM m) = + runReader (runStateT m initState) initEnv + +data SubstEnv = SubstEnv + { envTyVars :: TyVars, + envRanks :: Map VName Int } -rankToShape :: Map VName Int -> VName -> Shape SComp -rankToShape rs x = Shape $ replicate (rs M.! x) SDim +data SubstState = SubstState + { substTyVars :: TyVars, + substNewVars :: Map TyVar TyVar, + substNameSource :: VNameSource, + substCounter :: !Int + } -addRankInfo :: Map VName Int -> TyVar -> SubstM () -addRankInfo rs t = - modify $ \s -> s {substTyVars = M.insert t (lvl, TyVarRank $ rs M.! t) $ substTyVars s} +substIncCounter :: SubstM Int +substIncCounter = do + s <- get + put s {substCounter = substCounter s + 1} + pure $ substCounter s + +newTyVar :: TyVar -> SubstM TyVar +newTyVar t = do + i <- substIncCounter + t' <- newID $ mkTypeVarName (baseName t) i + modify $ \s -> s {substNewVars = M.insert t t' $ substNewVars s} + pure t' + where + newID x = do + s <- get + let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 + put $ s {substNameSource = src'} + pure v' + +rankToShape :: VName -> SubstM (Shape SComp) +rankToShape x = do + rs <- asks envRanks + pure $ Shape $ replicate (rs M.! x) SDim + +addRankInfo :: TyVar -> SubstM TyVar +addRankInfo t = do + rs <- asks envRanks + if rs M.! t == 0 + then pure t + else do + new_vars <- gets substNewVars + maybe new_var pure $ new_vars M.!? t where lvl = 0 -- FIXME + new_var = do + t' <- newTyVar t + old_tyvars <- asks envTyVars + case old_tyvars M.!? t of + Nothing -> pure t' + Just info -> do + modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} + pure t' class SubstRanks a where - substRanks :: Map VName Int -> a -> SubstM a + substRanks :: a -> SubstM a instance (SubstRanks a) => SubstRanks [a] where - substRanks rs = mapM (substRanks rs) + substRanks = mapM substRanks instance SubstRanks (Shape SComp) where - substRanks rs = pure . foldMap instDim + substRanks = foldM (\s d -> (s <>) <$> instDim d) mempty where - instDim (SDim) = Shape $ pure SDim - instDim (SVar x) = rankToShape rs x + instDim (SDim) = pure $ Shape $ pure SDim + instDim (SVar x) = rankToShape x instance SubstRanks (TypeBase SComp u) where - substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) - | rs M.! x > 0 = do - addRankInfo rs x - pure t - substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = do - ta' <- substRanks rs ta - tr' <- substRanks rs tr + substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = do + x' <- addRankInfo x + pure $ (Scalar (TypeVar u (QualName [] x') [])) + substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do + ta' <- substRanks ta + tr' <- substRanks tr pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) - substRanks rs (Array u shape t) = do - shape' <- substRanks rs shape - t' <- substRanks rs (Scalar t) + substRanks (Array u shape t) = do + shape' <- substRanks shape + t' <- substRanks $ Scalar t pure $ arrayOfWithAliases u shape' t' - substRanks _ t = pure t + substRanks t = pure t instance SubstRanks Ct where - substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 - substRanks _ _ = error "" + substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 + substRanks _ = error "" + +-- data SubstState = SubstState +-- { substTyVars :: Map TyVar TyVarSol, +-- substNameSource :: VNameSource, +-- substCounter :: !Int +-- } +-- +-- newtype SubstM a = SubstM {runSubstM :: State SubstState a} +-- deriving (Functor, Applicative, Monad, MonadState SubstState) +-- +-- substIncCounter :: SubstM Int +-- substIncCounter = do +-- s <- get +-- put s {substCounter = substCounter s + 1} +-- pure $ substCounter s +-- +-- newTyVar :: Name -> SubstM TyVar +-- newTyVar desc = do +-- i <- substIncCounter +-- newID $ mkTypeVarName desc i +-- where +-- newID x = do +-- s <- get +-- let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 +-- put $ s {substNameSource = src'} +-- pure v' +-- +-- addTyVarSol :: TyVar -> Shape SComp -> SubstM TyVar +-- addTyVarSol t shape = do +-- m <- subsTyVars gets +-- case m M.!? t of +-- Nothing -> do +-- t' <- newTyVar $ baseName t +-- modify $ \s -> s {substTyVars = M.insert t () $ substTyVars s} +-- Just t' -> pure t' +-- +-- rankToShape :: Map VName Int -> VName -> Shape SComp +-- rankToShape rs x = Shape $ replicate (rs M.! x) SDim +-- +-- class SubstRanks a where +-- substRanks :: Map VName Int -> a -> SubstM a +-- +-- instance SubstRanks (Shape SComp) where +-- substRanks rs = pure . foldMap instDim +-- where +-- instDim (SDim) = Shape $ pure SDim +-- instDim (SVar x) = rankToShape rs x +-- +-- instance SubstRanks (TypeBase SComp u) where +-- substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) +-- | rs M.! x > 0 = do +-- t' <- newTyVar $ baseName t +-- t' <- addTyVarSol +-- arrayOfWithAliases u (rankToShape rs x) t +-- substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = +-- Scalar (Arrow u p d (substRanks rs ta) (RetType retdims (substRanks rs tr))) +-- substRanks _ t = t +-- +-- instance SubstRanks Ct where +-- substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 +-- substRanks _ _ = error "" diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 227642e1b5..4fa481140c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1076,9 +1076,12 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do traceM $ "# function " <> prettyNameString fname - case rankAnalysis counter cts tyvars of + vns <- gets termNameSource + + case rankAnalysis vns counter cts tyvars of Nothing -> error "" - Just (cts', tyvars') -> do + Just (cts', tyvars', vns', counter') -> do + modify $ \s -> s {termCounter = counter', termNameSource = vns'} let solution = solve cts' tyvars' traceM $ From e18fe2e7678df35b821fe44f9b1eaab4a90c9232 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 04:50:08 -0800 Subject: [PATCH 040/258] Print out the rank map too. --- src/Language/Futhark/TypeChecker/Rank.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c87e10a402..4293097d96 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -139,7 +139,8 @@ rankAnalysis vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] (_size, ranks) <- branchAndBound lp let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map - initEnv = + traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + let initEnv = SubstEnv { envTyVars = tyVars, envRanks = rank_map From 344b747b9c8bad5acdbf89e1376559d611726c9c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 05:05:55 -0800 Subject: [PATCH 041/258] Forgot to actually generate the new type variable constraints. --- src/Language/Futhark/TypeChecker/Rank.hs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 4293097d96..9412553f82 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -151,13 +151,14 @@ rankAnalysis vns counter cs tyVars = do { substTyVars = mempty, substNewVars = mempty, substNameSource = vns, - substCounter = counter + substCounter = counter, + substNewCts = mempty } (cs', state') = runSubstM initEnv initState $ substRanks $ filter (not . isCtAM) cs - pure (cs', substTyVars state' <> tyVars, substNameSource state', substCounter state') + pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -192,7 +193,8 @@ data SubstState = SubstState { substTyVars :: TyVars, substNewVars :: Map TyVar TyVar, substNameSource :: VNameSource, - substCounter :: !Int + substCounter :: !Int, + substNewCts :: [Ct] } substIncCounter :: SubstM Int @@ -205,7 +207,17 @@ newTyVar :: TyVar -> SubstM TyVar newTyVar t = do i <- substIncCounter t' <- newID $ mkTypeVarName (baseName t) i - modify $ \s -> s {substNewVars = M.insert t t' $ substNewVars s} + shape <- rankToShape t + modify $ \s -> + s + { substNewVars = M.insert t t' $ substNewVars s, + substNewCts = + substNewCts s + ++ [ CtEq + (Scalar (TypeVar mempty (QualName [] t) [])) + (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) + ] + } pure t' where newID x = do From b2427fa241f608f29c8da5f52efe34a24ebd712f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 05:20:32 -0800 Subject: [PATCH 042/258] Bug fixes. --- src/Language/Futhark/TypeChecker/Rank.hs | 85 +++--------------------- 1 file changed, 9 insertions(+), 76 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9412553f82..cf774947ff 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -231,25 +231,20 @@ rankToShape x = do rs <- asks envRanks pure $ Shape $ replicate (rs M.! x) SDim -addRankInfo :: TyVar -> SubstM TyVar +addRankInfo :: TyVar -> SubstM () addRankInfo t = do rs <- asks envRanks - if rs M.! t == 0 - then pure t - else do - new_vars <- gets substNewVars - maybe new_var pure $ new_vars M.!? t + unless (rs M.! t == 0) $ do + new_vars <- gets substNewVars + maybe new_var (const $ pure ()) $ new_vars M.!? t where lvl = 0 -- FIXME new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars - case old_tyvars M.!? t of - Nothing -> pure t' - Just info -> do - modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} - pure t' + let info = old_tyvars M.! t + modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} class SubstRanks a where substRanks :: a -> SubstM a @@ -264,9 +259,8 @@ instance SubstRanks (Shape SComp) where instDim (SVar x) = rankToShape x instance SubstRanks (TypeBase SComp u) where - substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = do - x' <- addRankInfo x - pure $ (Scalar (TypeVar u (QualName [] x') [])) + substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = + addRankInfo x >> pure t substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do ta' <- substRanks ta tr' <- substRanks tr @@ -280,64 +274,3 @@ instance SubstRanks (TypeBase SComp u) where instance SubstRanks Ct where substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" - --- data SubstState = SubstState --- { substTyVars :: Map TyVar TyVarSol, --- substNameSource :: VNameSource, --- substCounter :: !Int --- } --- --- newtype SubstM a = SubstM {runSubstM :: State SubstState a} --- deriving (Functor, Applicative, Monad, MonadState SubstState) --- --- substIncCounter :: SubstM Int --- substIncCounter = do --- s <- get --- put s {substCounter = substCounter s + 1} --- pure $ substCounter s --- --- newTyVar :: Name -> SubstM TyVar --- newTyVar desc = do --- i <- substIncCounter --- newID $ mkTypeVarName desc i --- where --- newID x = do --- s <- get --- let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 --- put $ s {substNameSource = src'} --- pure v' --- --- addTyVarSol :: TyVar -> Shape SComp -> SubstM TyVar --- addTyVarSol t shape = do --- m <- subsTyVars gets --- case m M.!? t of --- Nothing -> do --- t' <- newTyVar $ baseName t --- modify $ \s -> s {substTyVars = M.insert t () $ substTyVars s} --- Just t' -> pure t' --- --- rankToShape :: Map VName Int -> VName -> Shape SComp --- rankToShape rs x = Shape $ replicate (rs M.! x) SDim --- --- class SubstRanks a where --- substRanks :: Map VName Int -> a -> SubstM a --- --- instance SubstRanks (Shape SComp) where --- substRanks rs = pure . foldMap instDim --- where --- instDim (SDim) = Shape $ pure SDim --- instDim (SVar x) = rankToShape rs x --- --- instance SubstRanks (TypeBase SComp u) where --- substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) --- | rs M.! x > 0 = do --- t' <- newTyVar $ baseName t --- t' <- addTyVarSol --- arrayOfWithAliases u (rankToShape rs x) t --- substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = --- Scalar (Arrow u p d (substRanks rs ta) (RetType retdims (substRanks rs tr))) --- substRanks _ t = t --- --- instance SubstRanks Ct where --- substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 --- substRanks _ _ = error "" From b00301c840e2be18abe3995c2022650a5cf11dbb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 14:28:03 +0100 Subject: [PATCH 043/258] Unused now. --- src/Language/Futhark/TypeChecker/Rank.hs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index cf774947ff..7582b27157 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -238,7 +238,6 @@ addRankInfo t = do new_vars <- gets substNewVars maybe new_var (const $ pure ()) $ new_vars M.!? t where - lvl = 0 -- FIXME new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars From 3b125122a42b42446a6da976358aadd2cb30f029 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 14:52:56 +0100 Subject: [PATCH 044/258] Has to be this way around for dumb reasons. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 96eb651bb3..bdde3a119d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -525,7 +525,7 @@ checkExp (QualParens (modname, modnameloc) e loc) = do "Module" <+> pretty modname <+> " is a parametric module." checkExp (Var qn (Info t) loc) = do t' <- lookupVar loc qn - unify (mkUsage loc "inferred rank type") t t' + unify (mkUsage loc "inferred rank type") t' t pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg From 36ae5c9e9c0a13242fea27e2d673f84a7e229a3c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 14:55:45 +0100 Subject: [PATCH 045/258] Always connect rank type. --- src/Language/Futhark/TypeChecker/Terms.hs | 19 +++--- .../Futhark/TypeChecker/Terms/Monad.hs | 61 ++++++++++--------- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index bdde3a119d..816f334069 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -469,8 +469,8 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp (AppExp (BinOp (op, oploc) _ (e1, _) (e2, _) loc) _) = do - ftype <- lookupVar oploc op +checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do + ftype <- lookupVar oploc op op_t e1' <- checkExp e1 e2' <- checkExp e2 @@ -524,8 +524,7 @@ checkExp (QualParens (modname, modnameloc) e loc) = do typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." checkExp (Var qn (Info t) loc) = do - t' <- lookupVar loc qn - unify (mkUsage loc "inferred rank type") t' t + t' <- lookupVar loc qn t pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg @@ -715,11 +714,11 @@ checkExp (Lambda params body rettype_te _ loc) = do onDim _ = mempty pure $ RetType (S.toList $ foldMap onDim $ fvVars $ freeInType ret) ret -checkExp (OpSection op _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSection op (Info op_t) loc) = do + ftype <- lookupVar loc op op_t pure $ OpSection op (Info ftype) loc -checkExp (OpSectionLeft op _ e _ _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do + ftype <- lookupVar loc op op_t e' <- checkExp e (t1, rt, argext, retext) <- checkApply loc (Just op, 0) ftype e' case (ftype, rt) of @@ -735,8 +734,8 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (OpSectionRight op _ e _ _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do + ftype <- lookupVar loc op op_t e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index e8525e0dfd..f5ca46c85c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -448,38 +448,41 @@ instance MonadTypeChecker TermTypeM where Nothing -> throwError $ TypeError (locOf loc) notes s -lookupVar :: SrcLoc -> QualName VName -> TermTypeM StructType -lookupVar loc qn@(QualName qs name) = do +lookupVar :: SrcLoc -> QualName VName -> StructType -> TermTypeM StructType +lookupVar loc qn@(QualName qs name) t = do scope <- lookupQualNameEnv qn let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) - case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams t) -> do - when (null qs) . modify $ \s -> - s {stateUsed = S.insert name $ stateUsed s} - if null tparams && null qs - then pure t - else do - (tnames, t') <- instantiateTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newTypeVar loc "t" - equalityType usage argtype - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeVar loc "t" - mustBeOneOf ts usage argtype - let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + t' <- + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams bound_t) -> do + when (null qs) . modify $ \s -> + s {stateUsed = S.insert name $ stateUsed s} + if null tparams && null qs + then pure t + else do + (tnames, t') <- instantiateTypeScheme qn loc tparams bound_t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newTypeVar loc "t" + equalityType usage argtype + pure $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ + Scalar $ + Arrow mempty Unnamed Observe argtype $ + RetType [] $ + Scalar $ + Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeVar loc "t" + mustBeOneOf ts usage argtype + let (pts', rt') = instOverloaded argtype pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + unify (mkUsage loc "inferred rank type") t' t + pure t' where instOverloaded argtype pts rt = ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, From 7d4bad6a7de2f05d3c6843ac1e1d78e5981b4b8d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 16:16:33 +0100 Subject: [PATCH 046/258] Various fixes. --- .../Futhark/TypeChecker/Constraints.hs | 1 + src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 67131e67cc..1886be827c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -20,6 +20,7 @@ import Data.List qualified as L import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T +import Debug.Trace import Futhark.IR.Pretty import Futhark.Util.Pretty import Language.Futhark diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 4fa481140c..0d7ec4e01d 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -342,8 +342,7 @@ instance MonadTypeChecker TermM where --- All the general machinery goes above. arrayOfRank :: Int -> Type -> Type -arrayOfRank 0 t = t -arrayOfRank n t = arrayOf (Shape $ replicate n SDim) t +arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why pts e = do @@ -449,17 +448,16 @@ checkPat' (Wildcard _ loc) (Ascribed t) = checkPat' (Wildcard NoInfo loc) NoneInferred = do t <- newType loc "t" pure $ Wildcard (Info t) loc -checkPat' p@(TuplePat ps loc) (Ascribed t) +checkPat' (TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, length ts == length ps = TuplePat <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t :: [Type] <- replicateM (length ps) (newType loc "t") - ctEq (Scalar (tupleRecord ps_t)) (toType t) - st <- asStructType loc $ Scalar $ tupleRecord ps_t - checkPat' p $ Ascribed $ toParam Observe st + ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") + ctEq (toType (Scalar (tupleRecord ps_t))) (toType t) + TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) (Ascribed t) @@ -924,14 +922,16 @@ checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) (Info $ AppRes (typeOf body') []) -- checkExp (AppExp (Range start maybe_step end loc) _) = do - start' <- checkExp' start - maybe_step' <- traverse checkExp' maybe_step - end' <- traverse checkExp' end - range_t <- newType loc "range" - ctEq (toType range_t) (arrayOfRank 1 (expType start')) - pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] - where - checkExp' = require "use in range expression" anyIntType <=< checkExp + start' <- require "use in range expression" anyIntType =<< checkExp start + let check e = do + e' <- checkExp e + ctEq (expType start') (expType e') + pure e' + maybe_step' <- traverse check maybe_step + end' <- traverse check end + range_t <- newTyVar loc "range" + ctEq (tyVarType range_t :: Type) (arrayOfRank 1 (expType start')) + pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes (tyVarType range_t) [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e @@ -1074,7 +1074,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "# function " <> prettyNameString fname + traceM $ "\n# function " <> prettyNameString fname <> "\n" vns <- gets termNameSource From 6a958f0a558d49f4ce40ce0b0a84b98271b16c64 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 11:05:19 -0800 Subject: [PATCH 047/258] Add options to set all ranks to zero for debugging. --- src/Language/Futhark/TypeChecker/Rank.hs | 12 ++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 4 +++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 7582b27157..ed3c0e24d1 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -134,11 +134,15 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) -rankAnalysis vns counter cs tyVars = do +rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) +rankAnalysis debug_zero_ranks vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] - (_size, ranks) <- branchAndBound lp - let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map + rank_map <- + if debug_zero_ranks + then pure $ fmap (const 0) inv_var_map + else do + (_size, ranks) <- branchAndBound lp + pure $ (fromJust . (ranks V.!?)) <$> inv_var_map traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) let initEnv = SubstEnv diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 0d7ec4e01d..7021afd3d3 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1078,7 +1078,9 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do vns <- gets termNameSource - case rankAnalysis vns counter cts tyvars of + let debug_zero_ranks = True + + case rankAnalysis debug_zero_ranks vns counter cts tyvars of Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} From 989e1e30558f1bb038f3747638d3d71ae45c1b9a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 12:06:20 -0800 Subject: [PATCH 048/258] Use PuLP instead of setting ranks to zero. --- futhark.cabal | 2 + src/Futhark/Solve/LP.hs | 4 +- src/Language/Futhark/TypeChecker/Rank.hs | 44 +++++++++++++++++++--- src/Language/Futhark/TypeChecker/Terms2.hs | 10 ++++- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index ded1234442..c73b997f71 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -493,6 +493,8 @@ library , mwc-random , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 + -- remove me later + , process executable futhark import: common diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index af9265f458..a2b625a5e0 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -172,7 +172,9 @@ linearProgToPulp prog = show (prettyName v) <> " = " <> "LpVariable(" - <> show (show (prettyName v)) + <> "'" + <> show (prettyName v) + <> "_'" <> ", lowBound = 0, cat = 'Integer')" ) prog_vars diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index ed3c0e24d1..83d7d5dce8 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -2,6 +2,7 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where import Control.Monad.Reader import Control.Monad.State +import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M import Data.Maybe @@ -16,6 +17,8 @@ import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad (mkTypeVarName) +import System.IO.Unsafe +import System.Process type LSum = LP.LSum VName Double @@ -135,11 +138,16 @@ mkLinearProg counter cs tyVars = finalState = flip execState initState $ runRankM buildLP rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) -rankAnalysis debug_zero_ranks vns counter cs tyVars = do +rankAnalysis use_python vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- - if debug_zero_ranks - then pure $ fmap (const 0) inv_var_map + if use_python + then do + -- traceM $ linearProgToPulp prog + parseRes $ + unsafePerformIO $ + readProcess "python" [] $ + linearProgToPulp prog else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map @@ -181,6 +189,29 @@ rankAnalysis debug_zero_ranks vns counter cs tyVars = do (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] + rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" + vname_to_pulp_var = M.mapWithKey (\k _ -> map rm_subscript $ show $ prettyName k) inv_var_map + pulp_var_to_vname = + M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList vname_to_pulp_var] + + parseRes :: String -> Maybe (Map VName Int) + parseRes s = do + (status : vars) <- trimToStart $ lines s + if not (success status) + then Nothing + else do + pure $ M.fromList $ catMaybes $ map readVar vars + where + trimToStart [] = Nothing + trimToStart (l : ls) + | "status" `L.isPrefixOf` l = Just (l : ls) + | otherwise = trimToStart ls + success l = + (read $ drop (length ("status: " :: [Char])) l) == (1 :: Int) + readVar xs = + let (v, _ : value) = L.span (/= ':') xs + in Just (fromJust $ pulp_var_to_vname M.!? v, read value) + newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) @@ -233,19 +264,20 @@ newTyVar t = do rankToShape :: VName -> SubstM (Shape SComp) rankToShape x = do rs <- asks envRanks - pure $ Shape $ replicate (rs M.! x) SDim + pure $ Shape $ replicate (fromJust $ rs M.!? x) SDim addRankInfo :: TyVar -> SubstM () addRankInfo t = do rs <- asks envRanks - unless (rs M.! t == 0) $ do + -- unless (fromMaybe (error $ prettyString t) (rs M.!? t) == 0) $ do + unless (fromMaybe 0 (rs M.!? t) == 0) $ do new_vars <- gets substNewVars maybe new_var (const $ pure ()) $ new_vars M.!? t where new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars - let info = old_tyvars M.! t + let info = fromJust $ old_tyvars M.!? t modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 7021afd3d3..8fa0bf765a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1078,9 +1078,15 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do vns <- gets termNameSource - let debug_zero_ranks = True + let use_python = True - case rankAnalysis debug_zero_ranks vns counter cts tyvars of + traceM $ + unlines + [ "## cts:", + unlines $ map prettyString cts + ] + + case rankAnalysis use_python vns counter cts tyvars of Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} From 3c2e319b1928f6acfb6162ef7f13d380c806a0bc Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 12:16:47 -0800 Subject: [PATCH 049/258] Add PuLP stuff. --- shell.nix | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/shell.nix b/shell.nix index cf8424a2be..d5199b0c02 100644 --- a/shell.nix +++ b/shell.nix @@ -2,6 +2,20 @@ let sources = import ./nix/sources.nix; pkgs = import sources.nixpkgs {}; + pps = ps: with ps; [ + ( + buildPythonPackage rec { + pname = "PuLP"; + version = "2.7.0"; + src = fetchPypi { + inherit pname version; + sha256 = "sha256-5z7msy1jnJuM9LSt7TNLoVi+X4MTVE4Fb3lqzgoQrmM="; + }; + doCheck = false; + } + ) + ]; + python = pkgs.python3.withPackages pps; in pkgs.stdenv.mkDerivation { name = "futhark"; @@ -38,6 +52,10 @@ pkgs.stdenv.mkDerivation { python3Packages.sphinx python3Packages.sphinxcontrib-bibtex imagemagick # needed for literate tests + # remove (needed for PuLP) + python + cbc + glpk ] ++ lib.optionals (stdenv.isLinux) [ opencl-headers From 5ef513d89d8f879fae436ec37dde09b215560318 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 21:45:32 +0100 Subject: [PATCH 050/258] Improve Solution. --- .../Futhark/TypeChecker/Constraints.hs | 21 +++++++------------ src/Language/Futhark/TypeChecker/Terms.hs | 13 +++++++----- src/Language/Futhark/TypeChecker/Terms2.hs | 3 ++- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1886be827c..b504d3d663 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -117,33 +117,28 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt --- | A solution maps types to the set of type variables that must be --- substituted with this type, as well as its binding level. This --- slightly odd representation is needed to encode when two type +-- | A solution maps a type variable to its substitution, binding +-- level, and additional type variables that are linked to this type. +-- This slightly odd representation is needed to encode when two type -- variables are actually the same type. This matters when we start -- instanting the sizes of the type. -type Solution = M.Map Type (Int, [TyVar]) +type Solution = M.Map TyVar (Type, Int, [TyVar]) solution :: SolverState -> Solution solution s = - L.foldl' byType mempty $ + L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ M.toList $ - L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ - M.toList $ - solverTyVars s + solverTyVars s where - mkSubst (TyVarSol lvl t) = Just (lvl, (t, [])) + mkSubst (TyVarSol lvl t) = Just (t, lvl, []) mkSubst _ = Nothing addLinks m (v1, TyVarLink v2) = case M.lookup v2 $ solverTyVars s of Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) _ -> case M.lookup v2 m of Nothing -> m - Just (t, (lvl, vs)) -> M.insert v2 (t, (lvl, v1 : vs)) m + Just (t, lvl, vs) -> M.insert v2 (t, lvl, v1 : vs) m addLinks m _ = m - byType m (v, (lvl, (t, vs))) = M.insertWith comb t (lvl, v : vs) m - where - comb (lvl1, ts1) (lvl2, ts2) = (min lvl1 lvl2, ts1 <> ts2) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 816f334069..65c82ae0fd 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1626,15 +1626,15 @@ arrayOfM loc t shape = do atLevel :: Int -> TermTypeM a -> TermTypeM a atLevel lvl = local $ \env -> env {termLevel = lvl} -addInitialConstraints :: M.Map (TypeBase () NoUniqueness) (Int, [VName]) -> TermTypeM () +addInitialConstraints :: Terms2.Solution -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) usage = mkUsage (mempty :: Loc) - f (t, (lvl, vs)) = do + f (v, (t, lvl, vs)) = do (t', _) <- atLevel lvl $ allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t - forM_ vs $ \v -> - addConstraint v lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v + forM_ (v : vs) $ \v' -> + addConstraint v' lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v' -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant @@ -1660,7 +1660,10 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right tysubsts -> runTermTypeM checkExp $ do - addInitialConstraints $ M.mapKeys (first $ const ()) tysubsts + addInitialConstraints tysubsts + + traceM $ unlines $ map prettyString params + traceM $ prettyString body' (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 8fa0bf765a..d6e95785a7 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -37,6 +37,7 @@ -- inference, perhaps we can do it in a post-inference check. module Language.Futhark.TypeChecker.Terms2 ( checkValDef, + Solution, ) where @@ -1099,7 +1100,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", - let p (t, (lvl, vs)) = unwords (show [lvl] : map prettyNameString vs) <> " => " <> prettyString t + let p (v, (t, lvl, vs)) = unwords (show [lvl] : map prettyNameString (v : vs)) <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') From 678633bf18b56008b684695e2ac989d2930f4689 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 22:00:41 +0100 Subject: [PATCH 051/258] Preserve types better. --- src/Language/Futhark/TypeChecker/Terms.hs | 1 - src/Language/Futhark/TypeChecker/Terms/Pat.hs | 19 ++++++++----------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 65c82ae0fd..a5073ce194 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1662,7 +1662,6 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do Right tysubsts -> runTermTypeM checkExp $ do addInitialConstraints tysubsts - traceM $ unlines $ map prettyString params traceM $ prettyString body' (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 16ad00f710..3c80a3b5ca 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -115,9 +115,9 @@ bindingIdent ident = binding [ident] checkPat' :: [(SizeBinder VName, QualName VName)] -> - Pat ParamType -> + Pat (TypeBase Size u) -> Inferred ParamType -> - TermTypeM (Pat ParamType) + TermTypeM (Pat (TypeBase Size u)) checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = @@ -151,11 +151,6 @@ checkPat' sizes p@(RecordPat p_fs loc) (Ascribed t) RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do p_fs' <- traverse (const $ newTypeVar loc "t") $ M.fromList p_fs - - when (sort (M.keys p_fs') /= sort (map fst p_fs)) $ - typeError loc mempty $ - "Duplicate fields in record pattern" <+> pretty p <> "." - unify (mkUsage loc "matching a record pattern") (Scalar (Record p_fs')) (toStruct t) checkPat' sizes p $ Ascribed $ toParam Observe $ Scalar (Record p_fs') where @@ -193,12 +188,12 @@ checkPat :: [(SizeBinder VName, QualName VName)] -> Pat (TypeBase Size u) -> Inferred StructType -> - (Pat ParamType -> TermTypeM a) -> + (Pat (TypeBase Size u) -> TermTypeM a) -> TermTypeM a checkPat sizes p t m = do p' <- onFailure (CheckingPat (fmap toStruct p) t) $ - checkPat' sizes (fmap (toParam Observe) p) (fmap (toParam Observe) t) + checkPat' sizes p (fmap (toParam Observe) t) let explicit = mustBeExplicitInType $ patternStructType p' @@ -216,7 +211,7 @@ bindingPat :: [SizeBinder VName] -> Pat (TypeBase Size u) -> StructType -> - (Pat ParamType -> TermTypeM a) -> + (Pat (TypeBase Size u) -> TermTypeM a) -> TermTypeM a bindingPat sizes p t m = do substs <- mapM mkSizeSubst sizes @@ -240,7 +235,9 @@ bindingParams :: bindingParams tps orig_ps m = bindingTypeParams tps $ do let descend ps' (p : ps) = checkPat [] p NoneInferred $ \p' -> - binding (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps + binding (patIdents $ fmap toStruct p') $ + incLevel $ + descend (p' : ps') ps descend ps' [] = m $ reverse ps' incLevel $ descend [] orig_ps From c73e681ed03be44ee566f87750b2e4b45ae351db Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 14 Feb 2024 17:33:58 +0100 Subject: [PATCH 052/258] WIP on type checker integration. --- src/Language/Futhark/TypeChecker/Terms.hs | 26 +-- .../Futhark/TypeChecker/Terms/Monad.hs | 150 ++++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 3 files changed, 91 insertions(+), 87 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index a5073ce194..abf797470c 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1009,7 +1009,7 @@ checkApply loc (fname, prev_applied) ftype argexp = do -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM checkExp $ do +checkOneExp e = runTermTypeM checkExp mempty $ do e' <- checkExp $ undefined e let t = typeOf e' (tparams, _, _) <- @@ -1023,7 +1023,7 @@ checkOneExp e = runTermTypeM checkExp $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM checkExp $ do +checkSizeExp e = runTermTypeM checkExp mempty $ do e' <- checkExp $ undefined e let t = typeOf e' when (hasBinding e') $ @@ -1621,21 +1621,6 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t --- A hack we need to create size variables for types at the right --- level. -atLevel :: Int -> TermTypeM a -> TermTypeM a -atLevel lvl = local $ \env -> env {termLevel = lvl} - -addInitialConstraints :: Terms2.Solution -> TermTypeM () -addInitialConstraints = mapM_ f . M.toList - where - addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) - usage = mkUsage (mempty :: Loc) - f (v, (t, lvl, vs)) = do - (t', _) <- atLevel lvl $ allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t - forM_ (v : vs) $ \v' -> - addConstraint v' lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v' - -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant -- definitions, by treating them as 0-ary functions. @@ -1657,11 +1642,12 @@ checkFunDef :: checkFunDef (fname, retdecl, tparams, params, body, loc) = do (maybe_tysubsts, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + let adjust = M.fromList . concatMap f . M.toList + where + f (v, (t, _, vs)) = map (,first (const ()) t) (v : vs) case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp $ do - addInitialConstraints tysubsts - + Right tysubsts -> runTermTypeM checkExp (adjust tysubsts) $ do traceM $ prettyString body' (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index f5ca46c85c..b40e377530 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -50,17 +50,20 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State.Strict +import Data.Bifunctor import Data.Bitraversable -import Data.Char (isAscii) +import Data.Foldable import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T +import Debug.Trace import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals +import Language.Futhark.TypeChecker.Constraints (TyVar) import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod, stateNameSource) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types @@ -198,6 +201,7 @@ data TermEnv = TermEnv termLevel :: Level, termCheckExp :: ExpBase Info VName -> TermTypeM Exp, termOuterEnv :: Env, + termTyVars :: M.Map TyVar (TypeBase () NoUniqueness), termImportName :: ImportName } @@ -347,43 +351,59 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." --- | Create a new type name and insert it (unconstrained) in the --- substitution map. -instantiateTypeParam :: - (Monoid as) => - QualName VName -> - SrcLoc -> - TypeParam -> - TermTypeM (VName, Subst (RetTypeBase dim as)) -instantiateTypeParam qn loc tparam = do - i <- incCounter - let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newID $ mkTypeVarName name i - case tparam of - TypeParamType x _ _ -> do - constrain v . NoConstraint x . mkUsage loc . docText $ - "instantiated type parameter of " <> dquotes (pretty qn) - pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) - TypeParamDim {} -> do - constrain v . Size Nothing . mkUsage loc . docText $ - "instantiated size parameter of " <> dquotes (pretty qn) - pure (v, ExpSubst $ sizeFromName (qualName v) loc) - --- | Instantiate a type scheme with fresh type variables for its type --- parameters. Returns the names of the fresh type variables, the --- instance list, and the instantiated type. -instantiateTypeScheme :: +replaceTyVars :: SrcLoc -> TypeBase () NoUniqueness -> StructType -> TermTypeM StructType +replaceTyVars loc orig_t1 orig_t2 = do + tyvars <- asks termTyVars + let f :: (Monoid u) => TypeBase () u' -> TypeBase Size u -> TermTypeM (TypeBase Size u) + f + (Scalar (TypeVar _ (QualName [] v1) [])) + t2 + | Just t <- M.lookup v1 tyvars = + f t t2 + | otherwise = + pure $ Scalar (TypeVar (fold t2) (QualName [] v1) []) + f (Scalar (Record fs1)) (Scalar (Record fs2)) = + Scalar . Record <$> sequence (M.intersectionWith f fs1 fs2) + f (Scalar (Sum fs1)) (Scalar (Sum fs2)) = + Scalar . Sum <$> sequence (M.intersectionWith (zipWithM f) fs1 fs2) + f + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow u pname d t2a (RetType ext t2r))) = do + ta <- f t1a t2a + tr <- f t1r t2r + pure $ Scalar $ Arrow u pname d ta $ RetType ext tr + f + (Array _ (Shape (() : ds1)) t1) + (Array u (Shape (d : ds2)) t2) = + arrayOfWithAliases u (Shape [d]) + <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) + f _ t2 = pure t2 + f orig_t1 orig_t2 + +-- | Instantiate a type scheme with fresh size variables for its size +-- parameters. Replaces type parameters with their known +-- instantiations. Returns the names of the fresh size variables and +-- the instantiated type. +instTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> StructType -> + TypeBase () NoUniqueness -> TermTypeM ([VName], StructType) -instantiateTypeScheme qn loc tparams t = do - let tnames = map typeParamName tparams - (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams - let substs = M.fromList $ zip tnames tparam_substs - t' = applySubst (`M.lookup` substs) t - pure (tparam_names, t') +instTypeScheme qn loc tparams scheme_t inferred = do + (names, substs) <- fmap (unzip . catMaybes) $ + forM tparams $ \tparam -> do + case tparam of + TypeParamType {} -> pure Nothing + TypeParamDim v _ -> do + constrain v . Size Nothing . mkUsage loc . docText $ + "instantiated size parameter of " <> dquotes (pretty qn) + pure $ Just (v, (v, ExpSubst $ sizeFromName (qualName v) loc)) + + t' <- replaceTyVars loc inferred $ applySubst (`lookup` substs) scheme_t + + pure (names, t') lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) @@ -453,36 +473,33 @@ lookupVar loc qn@(QualName qs name) t = do scope <- lookupQualNameEnv qn let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) - t' <- - case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams bound_t) -> do - when (null qs) . modify $ \s -> - s {stateUsed = S.insert name $ stateUsed s} - if null tparams && null qs - then pure t - else do - (tnames, t') <- instantiateTypeScheme qn loc tparams bound_t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newTypeVar loc "t" - equalityType usage argtype - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeVar loc "t" - mustBeOneOf ts usage argtype - let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' - unify (mkUsage loc "inferred rank type") t' t - pure t' + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams bound_t) -> do + when (null qs) . modify $ \s -> + s {stateUsed = S.insert name $ stateUsed s} + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams bound_t $ first (const ()) t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newTypeVar loc "t" + equalityType usage argtype + pure $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ + Scalar $ + Arrow mempty Unnamed Observe argtype $ + RetType [] $ + Scalar $ + Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeVar loc "t" + mustBeOneOf ts usage argtype + let (pts', rt') = instOverloaded argtype pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, @@ -639,8 +656,8 @@ initialTermScope = Just (name, EqualityF) addIntrinsicF _ = Nothing -runTermTypeM :: (ExpBase Info VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a -runTermTypeM checker (TermTypeM m) = do +runTermTypeM :: (ExpBase Info VName -> TermTypeM Exp) -> M.Map TyVar (TypeBase () NoUniqueness) -> TermTypeM a -> TypeM a +runTermTypeM checker tyvars (TermTypeM m) = do initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv name <- askImportName outer_env <- askEnv @@ -652,7 +669,8 @@ runTermTypeM checker (TermTypeM m) = do termLevel = 0, termCheckExp = checker, termImportName = name, - termOuterEnv = outer_env + termOuterEnv = outer_env, + termTyVars = tyvars } initial_state = TermTypeState diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d6e95785a7..804fcf9c48 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -362,7 +362,7 @@ instTypeScheme :: TermM ([VName], StructType) instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ - forM tparams $ \tparam -> do + forM tparams $ \tparam -> case tparam of TypeParamType _ v _ -> do v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v From fb42c3245cc122851ca279f739e57584138dfbb3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 Feb 2024 11:24:16 +0100 Subject: [PATCH 053/258] Now we get quite far in type checking! --- .../Futhark/TypeChecker/Constraints.hs | 28 +++----- src/Language/Futhark/TypeChecker/Rank.hs | 1 + src/Language/Futhark/TypeChecker/Terms.hs | 72 +++++++++---------- .../Futhark/TypeChecker/Terms/Monad.hs | 35 +++++---- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 12 +++- src/Language/Futhark/TypeChecker/Terms2.hs | 15 +++- 6 files changed, 89 insertions(+), 74 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b504d3d663..b6920c0f5b 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -106,7 +106,7 @@ substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = case M.lookup v m of Just (TyVarLink v') -> substTyVars m $ Scalar $ TypeVar u (QualName qs v') args - Just (TyVarSol _ t') -> second (const mempty) t' + Just (TyVarSol _ t') -> second (const mempty) $ substTyVars m t' Just (TyVarUnsol {}) -> t Nothing -> t substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt @@ -117,28 +117,20 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt --- | A solution maps a type variable to its substitution, binding --- level, and additional type variables that are linked to this type. --- This slightly odd representation is needed to encode when two type --- variables are actually the same type. This matters when we start --- instanting the sizes of the type. -type Solution = M.Map TyVar (Type, Int, [TyVar]) +-- | A solution maps a type variable to its substitution. This substitution is complete, in the sense there are no right-hand sides that contain a type variable. +type Solution = M.Map TyVar (TypeBase () NoUniqueness) solution :: SolverState -> Solution solution s = - L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ - M.toList $ - solverTyVars s + M.mapMaybe mkSubst $ + solverTyVars s where - mkSubst (TyVarSol lvl t) = Just (t, lvl, []) + mkSubst (TyVarSol _lvl t) = Just $ first (const ()) $ substTyVars (solverTyVars s) t + mkSubst (TyVarLink v') = mkSubst =<< M.lookup v' (solverTyVars s) + mkSubst (TyVarUnsol _ (TyVarPrim pts)) + | Signed Int32 `elem` pts = + Just (Scalar (Prim (Signed Int32))) -- XXX - we need warnings and things! mkSubst _ = Nothing - addLinks m (v1, TyVarLink v2) = - case M.lookup v2 $ solverTyVars s of - Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) - _ -> case M.lookup v2 m of - Nothing -> m - Just (t, lvl, vs) -> M.insert v2 (t, lvl, v1 : vs) m - addLinks m _ = m newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 83d7d5dce8..56d1be4e56 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -138,6 +138,7 @@ mkLinearProg counter cs tyVars = finalState = flip execState initState $ runRankM buildLP rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) +rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_python vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index abf797470c..01b8cfdc3d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -352,14 +352,17 @@ unscopeType tloc unscoped = checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc -checkExp (Hole info loc) = - pure $ Hole info loc +checkExp (Hole (Info t) loc) = do + t' <- replaceTyVars loc t t + pure $ Hole (Info t') loc checkExp (StringLit vs loc) = pure $ StringLit vs loc -checkExp (IntLit val info loc) = - pure $ IntLit val info loc -checkExp (FloatLit val info loc) = - pure $ FloatLit val info loc +checkExp (IntLit val (Info t) loc) = do + t' <- replaceTyVars loc t t + pure $ IntLit val (Info t') loc +checkExp (FloatLit val (Info t) loc) = do + t' <- replaceTyVars loc t t + pure $ FloatLit val (Info t') loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc checkExp (RecordLit fs loc) = @@ -662,14 +665,14 @@ checkExp (Assert e1 e2 _ loc) = do e1' <- require "being asserted" [Bool] =<< checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc -checkExp (Lambda params body rettype_te _ loc) = do +checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do (params', body', rettype', RetType dims ty) <- incLevel . bindingParams [] params $ \params' -> do rettype_checked <- traverse checkTypeExpNonrigid rettype_te - let declared_rettype = - case rettype_checked of - Just (_, st, _) -> Just st - Nothing -> Nothing + declared_rettype <- + case rettype_checked of + Just (_, st, _) -> Just <$> replaceTyVars loc rt st + Nothing -> pure Nothing body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' @@ -677,11 +680,13 @@ checkExp (Lambda params body rettype_te _ loc) = do (rettype', rettype_st) <- case rettype_checked of - Just (te, st, ext) -> - pure (Just te, RetType ext st) + Just (te, st, ext) -> do + st' <- replaceTyVars loc rt st + pure (Just te, RetType ext st') Nothing -> do - ret <- inferReturnSizes params'' $ toRes Nonunique body_t - pure (Nothing, ret) + RetType ext ret <- inferReturnSizes params'' $ toRes Nonunique body_t + ret' <- replaceTyVars loc rt ret + pure (Nothing, RetType ext ret') pure (params'', body', rettype', rettype_st) @@ -851,14 +856,10 @@ instance Pretty (Unmatched (Pat StructType)) where checkSlice :: SliceBase Info VName -> TermTypeM [DimIndex] checkSlice = mapM checkDimIndex where - checkDimIndex (DimFix i) = do - DimFix <$> (require "use as index" anySignedType =<< checkExp i) + checkDimIndex (DimFix i) = + DimFix <$> checkExp i checkDimIndex (DimSlice i j s) = - DimSlice <$> check i <*> check j <*> check s - - check = - maybe (pure Nothing) $ - fmap Just . unifies "use as index" (Scalar $ Prim $ Signed Int64) <=< checkExp + DimSlice <$> traverse checkExp i <*> traverse checkExp j <*> traverse checkExp s -- The number of dimensions affected by this slice (so the minimum -- rank of the array we are slicing). @@ -1023,14 +1024,18 @@ checkOneExp e = runTermTypeM checkExp mempty $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM checkExp mempty $ do - e' <- checkExp $ undefined e - let t = typeOf e' - when (hasBinding e') $ - typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $ - "Size expression with binding is forbidden." - unify (mkUsage e' "Size expression") t (Scalar (Prim (Signed Int64))) - updateTypes e' +checkSizeExp e = do + (maybe_tysubsts, e') <- Terms2.checkSingleExp e + case maybe_tysubsts of + Left err -> typeError e' mempty $ pretty err + Right tysubsts -> runTermTypeM checkExp tysubsts $ do + e'' <- checkExp e' + let t = typeOf e'' + when (hasBinding e'') $ + typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ + "Size expression with binding is forbidden." + unify (mkUsage e'' "Size expression") t (Scalar (Prim (Signed Int64))) + updateTypes e'' -- Verify that all sum type constructors and empty array literals have -- a size that is known (rigid or a type parameter). This is to @@ -1642,14 +1647,9 @@ checkFunDef :: checkFunDef (fname, retdecl, tparams, params, body, loc) = do (maybe_tysubsts, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - let adjust = M.fromList . concatMap f . M.toList - where - f (v, (t, _, vs)) = map (,first (const ()) t) (v : vs) case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp (adjust tysubsts) $ do - traceM $ prettyString body' - + Right tysubsts -> runTermTypeM checkExp tysubsts $ do (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index b40e377530..2b7e2cf9fd 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -25,6 +25,7 @@ module Language.Futhark.TypeChecker.Terms.Monad constrain, newArrayType, allDimsFreshInType, + replaceTyVars, updateTypes, Names, @@ -351,33 +352,37 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." -replaceTyVars :: SrcLoc -> TypeBase () NoUniqueness -> StructType -> TermTypeM StructType +replaceTyVars :: + SrcLoc -> + TypeBase d u1 -> + TypeBase Size u2 -> + TermTypeM (TypeBase Size u1) replaceTyVars loc orig_t1 orig_t2 = do tyvars <- asks termTyVars - let f :: (Monoid u) => TypeBase () u' -> TypeBase Size u -> TermTypeM (TypeBase Size u) + let f :: TypeBase d u1 -> TypeBase Size u2 -> TermTypeM (TypeBase Size u1) f - (Scalar (TypeVar _ (QualName [] v1) [])) + (Scalar (TypeVar u (QualName [] v1) [])) t2 | Just t <- M.lookup v1 tyvars = - f t t2 - | otherwise = - pure $ Scalar (TypeVar (fold t2) (QualName [] v1) []) + f (second (const u) t) t2 f (Scalar (Record fs1)) (Scalar (Record fs2)) = Scalar . Record <$> sequence (M.intersectionWith f fs1 fs2) f (Scalar (Sum fs1)) (Scalar (Sum fs2)) = Scalar . Sum <$> sequence (M.intersectionWith (zipWithM f) fs1 fs2) f - (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) - (Scalar (Arrow u pname d t2a (RetType ext t2r))) = do + (Scalar (Arrow u _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ pname d t2a (RetType ext t2r))) = do ta <- f t1a t2a tr <- f t1r t2r pure $ Scalar $ Arrow u pname d ta $ RetType ext tr f - (Array _ (Shape (() : ds1)) t1) - (Array u (Shape (d : ds2)) t2) = + (Array u (Shape (_ : ds1)) t1) + (Array _ (Shape (d : ds2)) t2) = arrayOfWithAliases u (Shape [d]) <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) - f _ t2 = pure t2 + f t1 _ = + fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1 + f orig_t1 orig_t2 -- | Instantiate a type scheme with fresh size variables for its size @@ -469,7 +474,7 @@ instance MonadTypeChecker TermTypeM where throwError $ TypeError (locOf loc) notes s lookupVar :: SrcLoc -> QualName VName -> StructType -> TermTypeM StructType -lookupVar loc qn@(QualName qs name) t = do +lookupVar loc qn@(QualName qs name) inst_t = do scope <- lookupQualNameEnv qn let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) @@ -480,11 +485,11 @@ lookupVar loc qn@(QualName qs name) t = do when (null qs) . modify $ \s -> s {stateUsed = S.insert name $ stateUsed s} if null tparams && null qs - then pure t + then pure bound_t else do - (tnames, t') <- instTypeScheme qn loc tparams bound_t $ first (const ()) t + (tnames, t) <- instTypeScheme qn loc tparams bound_t $ first (const ()) inst_t outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' + pure $ qualifyTypeVars outer_env tnames qs t Just EqualityF -> do argtype <- newTypeVar loc "t" equalityType usage argtype diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 3c80a3b5ca..f8bd0a42f6 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -122,9 +122,17 @@ checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc -checkPat' _ (Id name (Info t) loc) _ = +checkPat' _ (Id name (Info t) loc) NoneInferred = do + t' <- replaceTyVars loc (first (const ()) t) t + pure $ Id name (Info t') loc +checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do + t <- replaceTyVars loc (first (const ()) t1) t2 pure $ Id name (Info t) loc -checkPat' _ (Wildcard (Info t) loc) _ = +checkPat' _ (Wildcard (Info t) loc) NoneInferred = do + t' <- replaceTyVars loc (first (const ()) t) t + pure $ Wildcard (Info t') loc +checkPat' _ (Wildcard (Info t1) loc) (Ascribed t2) = do + t <- replaceTyVars loc (first (const ()) t1) t2 pure $ Wildcard (Info t) loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 804fcf9c48..e4ed9daef9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -37,6 +37,7 @@ -- inference, perhaps we can do it in a post-inference check. module Language.Futhark.TypeChecker.Terms2 ( checkValDef, + checkSingleExp, Solution, ) where @@ -1063,7 +1064,7 @@ checkValDef :: Maybe (TypeExp Exp VName), Exp ) -checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do +checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bindParams tparams params $ \params' -> do body' <- checkExp body @@ -1075,7 +1076,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "\n# function " <> prettyNameString fname <> "\n" + traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" vns <- gets termNameSource @@ -1100,7 +1101,15 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", - let p (v, (t, lvl, vs)) = unwords (show [lvl] : map prettyNameString (v : vs)) <> " => " <> prettyString t + let p (v, t) = prettyNameString v <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') + +checkSingleExp :: ExpBase NoInfo VName -> TypeM (Either T.Text Solution, Exp) +checkSingleExp e = runTermM $ do + e' <- checkExp e + cts <- gets termConstraints + tyvars <- gets termTyVars + let solution = solve cts tyvars + pure (solution, e') From 333ca8de3e2bbf7d8fc0ac39c33e19457ddcbf53 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 Feb 2024 15:00:17 +0100 Subject: [PATCH 054/258] Instantiate sizes properly. --- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 2b7e2cf9fd..0b1efd86b7 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -397,14 +397,15 @@ instTypeScheme :: TypeBase () NoUniqueness -> TermTypeM ([VName], StructType) instTypeScheme qn loc tparams scheme_t inferred = do - (names, substs) <- fmap (unzip . catMaybes) $ - forM tparams $ \tparam -> do - case tparam of - TypeParamType {} -> pure Nothing - TypeParamDim v _ -> do - constrain v . Size Nothing . mkUsage loc . docText $ - "instantiated size parameter of " <> dquotes (pretty qn) - pure $ Just (v, (v, ExpSubst $ sizeFromName (qualName v) loc)) + (names, substs) <- fmap (unzip . catMaybes) . forM tparams $ \tparam -> do + case tparam of + TypeParamType {} -> pure Nothing + TypeParamDim v _ -> do + i <- incCounter + v' <- newID $ mkTypeVarName (baseName v) i + constrain v' . Size Nothing . mkUsage loc . docText $ + "instantiated size parameter of " <> dquotes (pretty qn) + pure $ Just (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) t' <- replaceTyVars loc inferred $ applySubst (`lookup` substs) scheme_t From b4983c63cfccb7ab59ff57be39f6b8c922d58bb1 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 06:08:04 -0800 Subject: [PATCH 055/258] Don't normalize/forget variables. --- src/Futhark/Solve/LP.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index a2b625a5e0..c3e321dfae 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -239,7 +239,8 @@ cval :: (Num a, Ord v) => LSum v a -> a cval = (! Nothing) (~+~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a -(LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y +-- (LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y +(LSum x) ~+~ (LSum y) = LSum $ Map.unionWith (+) x y infixl 6 ~+~ @@ -249,7 +250,8 @@ x ~-~ y = x ~+~ (neg y) infixl 6 ~-~ (~*~) :: (Eq a, Num a, Ord v) => a -> LSum v a -> LSum v a -a ~*~ s = normalize $ fmap (a *) s +-- a ~*~ s = normalize $ fmap (a *) s +a ~*~ s = fmap (a *) s infixl 7 ~*~ From b122b274b0ea47ff376ee064413373ea0e5fee2a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 Feb 2024 15:11:24 +0100 Subject: [PATCH 056/258] Preserve uniqueness. --- src/Language/Futhark/TypeChecker/Constraints.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b6920c0f5b..665e67af3f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -113,7 +113,7 @@ substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs substTyVars m (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars m) cs substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = - Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 + Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 `setUniqueness` uniqueness t2 substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt From fc083add39863f5a3631e9a348f71432dfedb86f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 07:15:35 -0800 Subject: [PATCH 057/258] Jank frame fix. --- src/Language/Futhark/TypeChecker/Terms2.hs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e4ed9daef9..8005ffbcc8 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -596,7 +596,7 @@ bindParams tps orig_ps m = bindTypeParams tps $ do checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) checkApply loc _ ftype fframe arg = do - (a, b) <- split ftype + (a, b) <- split $ stripFrame fframe ftype r <- newSVar loc "R" m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool @@ -607,10 +607,17 @@ checkApply loc _ ftype fframe arg = do ctAM r m ctEq lhs rhs pure - ( b, + ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where + stripFrame :: Shape Size -> Type -> Type + stripFrame frame (Array u ds t) = + let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) + in case mnew_shape of + Nothing -> Scalar t + Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t + stripFrame _ t = t toSComp (Var (QualName [] x) _ _) = SVar x toSComp _ = error "" toShape = Shape . pure From 1ed744b04674eb5a70f06c52b7fb03a9774badb8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 09:50:04 -0800 Subject: [PATCH 058/258] SPEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEED --- futhark.cabal | 2 ++ src/Futhark/Solve/LP.hs | 19 ++++++++++++++ src/Language/Futhark/TypeChecker/Rank.hs | 30 +++------------------- src/Language/Futhark/TypeChecker/Terms2.hs | 4 +-- 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index c73b997f71..8cf87483f3 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -371,6 +371,7 @@ library Futhark.Pkg.Types Futhark.Profile Futhark.Script + Futhark.Solve.GLPK Futhark.Solve.LP Futhark.Solve.Matrix Futhark.Solve.Simplex @@ -495,6 +496,7 @@ library , prettyprinter-ansi-terminal >= 1.1 -- remove me later , process + , glpk-hs executable futhark import: common diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index c3e321dfae..7623033e7c 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -20,6 +20,8 @@ module Futhark.Solve.LP LinearProg (..), OptType (..), Constraint (..), + Vars (..), + CType (..), (~==~), (~<=~), (~>=~), @@ -28,11 +30,14 @@ module Futhark.Solve.LP ) where +import Control.Monad.LPMonad import Data.Char (isAscii) import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as Map import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V import Debug.Trace @@ -106,6 +111,12 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where instance Functor (LSum v) where fmap f (LSum m) = LSum $ fmap f m +class Vars a v where + vars :: a -> Set v + +instance (Ord v) => Vars (LSum v a) v where + vars = S.fromList . catMaybes . Map.keys . lsum + -- | Type of constraint data CType = Equal | LessEq deriving (Show, Eq) @@ -123,6 +134,9 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (Constraint v a) where pretty (Constraint t l r) = pretty l <+> pretty t <+> pretty r +instance (Ord v) => Vars (Constraint v a) v where + vars (Constraint _ l r) = vars l <> vars r + data OptType = Maximize | Minimize deriving (Show, Eq) @@ -147,6 +161,11 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where indent 2 $ vcat $ map pretty cs ] +instance (Ord v) => Vars (LinearProg v a) v where + vars lp = + vars (objective lp) + <> foldMap vars (constraints lp) + -- For debugging linearProgToPulp :: (Unbox a, IsName v, Ord v, Pretty a, Eq a, Num a) => LinearProg v a -> String linearProgToPulp prog = diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 56d1be4e56..d255f3106f 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -11,6 +11,7 @@ import Debug.Trace import Futhark.FreshNames qualified as FreshNames import Futhark.MonadFreshNames hiding (newName) import Futhark.Solve.BranchAndBound +import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Futhark.Solve.Simplex @@ -139,16 +140,11 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) -rankAnalysis use_python vns counter cs tyVars = do +rankAnalysis use_glpk vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- - if use_python - then do - -- traceM $ linearProgToPulp prog - parseRes $ - unsafePerformIO $ - readProcess "python" [] $ - linearProgToPulp prog + if use_glpk + then snd <$> (unsafePerformIO $ glpk prog) else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map @@ -195,24 +191,6 @@ rankAnalysis use_python vns counter cs tyVars = do pulp_var_to_vname = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList vname_to_pulp_var] - parseRes :: String -> Maybe (Map VName Int) - parseRes s = do - (status : vars) <- trimToStart $ lines s - if not (success status) - then Nothing - else do - pure $ M.fromList $ catMaybes $ map readVar vars - where - trimToStart [] = Nothing - trimToStart (l : ls) - | "status" `L.isPrefixOf` l = Just (l : ls) - | otherwise = trimToStart ls - success l = - (read $ drop (length ("status: " :: [Char])) l) == (1 :: Int) - readVar xs = - let (v, _ : value) = L.span (/= ':') xs - in Just (fromJust $ pulp_var_to_vname M.!? v, read value) - newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 8005ffbcc8..6fefb80673 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1087,7 +1087,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do vns <- gets termNameSource - let use_python = True + let use_glpk = True traceM $ unlines @@ -1095,7 +1095,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map prettyString cts ] - case rankAnalysis use_python vns counter cts tyvars of + case rankAnalysis use_glpk vns counter cts tyvars of Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} From 6eb745ef7ab1cd998b9f810b43b6d649ac6be086 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 11:47:50 -0800 Subject: [PATCH 059/258] Add `GLPK.hs`. --- src/Futhark/Solve/GLPK.hs | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/Futhark/Solve/GLPK.hs diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs new file mode 100644 index 0000000000..7b27408a27 --- /dev/null +++ b/src/Futhark/Solve/GLPK.hs @@ -0,0 +1,47 @@ +module Futhark.Solve.GLPK (glpk) where + +import Data.LinearProgram +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.Solve.LP qualified as F + +linearProgToGLPK :: (Show v, Ord v, Eq a, Num a, Group a) => F.LinearProg v a -> (LP v a) +linearProgToGLPK prog = + LP + { direction = cOptType $ F.optType prog, + objective = cObj $ F.objective prog, + constraints = map cConstraint $ F.constraints prog, + varBounds = bounds, + varTypes = kinds + } + where + cOptType F.Maximize = Max + cOptType F.Minimize = Min + cObj = fst . cLSum + + cLSum (F.LSum m) = + ( M.mapKeys fromJust $ M.filterWithKey (\k _ -> isJust k) m, + fromMaybe 0 (m M.!? Nothing) + ) + + cConstraint (F.Constraint ctype l r) = + let (linfunc, c) = cLSum $ l F.~-~ r + bound = + case ctype of + F.Equal -> Equ (-c) + F.LessEq -> UBound (-c) + in Constr Nothing linfunc bound + + bounds = M.fromList $ (,LBound 0) <$> varList + kinds = M.fromList $ (,IntVar) <$> varList + + varList = S.toList $ F.vars prog + +glpk :: + (Show v, Ord v, Show a, Eq a, Real a, Group a) => + F.LinearProg v a -> + IO (Maybe (Int, M.Map v Int)) +glpk lp = do + (_, mres) <- glpSolveVars mipDefaults $ linearProgToGLPK lp + pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres From 1d772818bf52914d46ab70006410f1c43a5ab957 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 16 Feb 2024 15:21:10 +0100 Subject: [PATCH 060/258] Basic things work now. --- src/Language/Futhark/TypeChecker/Rank.hs | 5 +- src/Language/Futhark/TypeChecker/Terms.hs | 19 ++--- .../Futhark/TypeChecker/Terms/Monad.hs | 70 +++++++++++++++---- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 10 +-- 4 files changed, 78 insertions(+), 26 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index d255f3106f..23d295f8ee 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -115,7 +115,10 @@ addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () addTyVarInfo tv (_, TyVarFree) = pure () addTyVarInfo tv (_, TyVarPrim _) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo _ _ = error "Unhandled" +addTyVarInfo tv (_, TyVarRecord _) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarSum _) = + addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: Int -> [Ct] -> TyVars -> LinearProg mkLinearProg counter cs tyVars = diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 01b8cfdc3d..919ebda406 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -353,15 +353,15 @@ checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc checkExp (Hole (Info t) loc) = do - t' <- replaceTyVars loc t t + t' <- replaceTyVars loc t pure $ Hole (Info t') loc checkExp (StringLit vs loc) = pure $ StringLit vs loc checkExp (IntLit val (Info t) loc) = do - t' <- replaceTyVars loc t t + t' <- replaceTyVars loc t pure $ IntLit val (Info t') loc checkExp (FloatLit val (Info t) loc) = do - t' <- replaceTyVars loc t t + t' <- replaceTyVars loc t pure $ FloatLit val (Info t') loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc @@ -668,10 +668,13 @@ checkExp (Assert e1 e2 _ loc) = do checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do (params', body', rettype', RetType dims ty) <- incLevel . bindingParams [] params $ \params' -> do + rt' <- replaceTyVars loc rt rettype_checked <- traverse checkTypeExpNonrigid rettype_te declared_rettype <- case rettype_checked of - Just (_, st, _) -> Just <$> replaceTyVars loc rt st + Just (_, st, _) -> do + unify (mkUsage body "lambda return type ascription") (toStruct rt') (toStruct st) + pure $ Just st Nothing -> pure Nothing body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' @@ -680,13 +683,11 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do (rettype', rettype_st) <- case rettype_checked of - Just (te, st, ext) -> do - st' <- replaceTyVars loc rt st - pure (Just te, RetType ext st') + Just (te, _, ext) -> + pure (Just te, RetType ext rt') Nothing -> do RetType ext ret <- inferReturnSizes params'' $ toRes Nonunique body_t - ret' <- replaceTyVars loc rt ret - pure (Nothing, RetType ext ret') + pure (Nothing, RetType ext ret) pure (params'', body', rettype', rettype_st) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 0b1efd86b7..87697bc359 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -58,7 +58,6 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified import Futhark.Util.Pretty hiding (space) @@ -352,14 +351,47 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." -replaceTyVars :: +replaceTyVars :: SrcLoc -> TypeBase Size u -> TermTypeM (TypeBase Size u) +replaceTyVars loc orig_t = do + tyvars <- asks termTyVars + let f :: TypeBase Size u -> TermTypeM (TypeBase Size u) + f (Scalar (Prim t)) = pure $ Scalar $ Prim t + f + (Scalar (TypeVar u (QualName [] v) [])) + | Just t <- M.lookup v tyvars = + fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" (second (const u) t) + | otherwise = + pure $ Scalar (TypeVar u (QualName [] v) []) + f (Scalar (TypeVar u qn targs)) = + Scalar . TypeVar u qn <$> mapM onTyArg targs + where + onTyArg (TypeArgDim e) = pure $ TypeArgDim e + onTyArg (TypeArgType t) = TypeArgType <$> f t + f (Scalar (Record fs)) = + Scalar . Record <$> traverse f fs + f (Scalar (Sum fs)) = + Scalar . Sum <$> traverse (mapM f) fs + f (Scalar (Arrow u pname d ta (RetType ext tr))) = do + ta' <- f ta + tr' <- f tr + pure $ Scalar $ Arrow u pname d ta' $ RetType ext tr' + f (Array u shape t) = + arrayOfWithAliases u shape <$> f (Scalar t) + + f orig_t + +instTyVars :: SrcLoc -> - TypeBase d u1 -> - TypeBase Size u2 -> - TermTypeM (TypeBase Size u1) -replaceTyVars loc orig_t1 orig_t2 = do + [VName] -> + TypeBase () u -> + TypeBase Size u -> + TermTypeM (TypeBase Size u) +instTyVars loc names orig_t1 orig_t2 = do tyvars <- asks termTyVars - let f :: TypeBase d u1 -> TypeBase Size u2 -> TermTypeM (TypeBase Size u1) + let f :: + TypeBase d u -> + TypeBase Size u -> + StateT (M.Map VName (TypeBase Size NoUniqueness)) TermTypeM (TypeBase Size u) f (Scalar (TypeVar u (QualName [] v1) [])) t2 @@ -380,10 +412,23 @@ replaceTyVars loc orig_t1 orig_t2 = do (Array _ (Shape (d : ds2)) t2) = arrayOfWithAliases u (Shape [d]) <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) - f t1 _ = - fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1 - - f orig_t1 orig_t2 + f t1 t2 = do + let mkNew = + fst <$> lift (allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1) + case t2 of + Scalar (TypeVar u (QualName [] v2) []) + | v2 `elem` names -> do + seen <- get + case M.lookup v2 seen of + Nothing -> do + t <- mkNew + modify $ M.insert v2 $ second (const NoUniqueness) t + pure t + Just t -> + pure $ second (const u) t + _ -> mkNew + + evalStateT (f orig_t1 orig_t2) mempty -- | Instantiate a type scheme with fresh size variables for its size -- parameters. Replaces type parameters with their known @@ -407,7 +452,8 @@ instTypeScheme qn loc tparams scheme_t inferred = do "instantiated size parameter of " <> dquotes (pretty qn) pure $ Just (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) - t' <- replaceTyVars loc inferred $ applySubst (`lookup` substs) scheme_t + let tp_names = map typeParamName $ filter isTypeParam tparams + t' <- instTyVars loc tp_names inferred $ applySubst (`lookup` substs) scheme_t pure (names, t') diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index f8bd0a42f6..b1a2f59a8d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -123,16 +123,18 @@ checkPat' sizes (PatParens p loc) t = checkPat' sizes (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc checkPat' _ (Id name (Info t) loc) NoneInferred = do - t' <- replaceTyVars loc (first (const ()) t) t + t' <- replaceTyVars loc t pure $ Id name (Info t') loc checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc (first (const ()) t1) t2 + t <- replaceTyVars loc t1 + unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) pure $ Id name (Info t) loc checkPat' _ (Wildcard (Info t) loc) NoneInferred = do - t' <- replaceTyVars loc (first (const ()) t) t + t' <- replaceTyVars loc t pure $ Wildcard (Info t') loc checkPat' _ (Wildcard (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc (first (const ()) t1) t2 + t <- replaceTyVars loc t1 + unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) pure $ Wildcard (Info t) loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, From cb21531f32edbe2226e760f3c9eca1ff1b97baf3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 16 Feb 2024 16:46:18 +0100 Subject: [PATCH 061/258] Fix typo. --- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6fefb80673..f2f7f8f327 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -601,7 +601,7 @@ checkApply loc _ ftype fframe arg = do m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty - m_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r) <> (toSComp <$> frameOf arg)) $ toType $ typeOf arg rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a ctAM r m From 4d673d3eba2777cb3d8a2976b7819224e9470b8a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 16 Feb 2024 17:51:26 +0100 Subject: [PATCH 062/258] Working AUTOMAP (not really). --- prelude/zip.fut | 2 +- .../Futhark/TypeChecker/Consumption.hs | 34 +++++++++-------- src/Language/Futhark/TypeChecker/Terms.hs | 38 ++++++++++++++----- 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/prelude/zip.fut b/prelude/zip.fut index 1171820307..18361e545f 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -11,7 +11,7 @@ -- depended upon by soacs.fut. So we just define a quick-and-dirty -- internal one here that uses the intrinsic version. local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map f as + f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index 5c1198537f..5ebe2996df 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -486,9 +486,10 @@ consumeAsNeeded loc (Scalar (Record fs1)) (Scalar (Record fs2)) = consumeAsNeeded loc pt t = when (diet pt == Consume) $ consumeAliases loc $ aliases t -checkArg :: [(Exp, TypeAliases)] -> ParamType -> Exp -> CheckM (Exp, TypeAliases) -checkArg prev p_t e = do - ((e', e_als), e_cons) <- contain $ checkExp e +checkArg :: [(Exp, TypeAliases)] -> ParamType -> AutoMap -> Exp -> CheckM (Exp, TypeAliases) +checkArg prev p_t am e = do + ((e', e_als), e_cons) <- + contain $ if autoRep am == mempty then noAliases e else checkExp e consumed e_cons let e_t = typeOf e' when (e_cons /= mempty && not (orderZero e_t)) $ @@ -542,9 +543,11 @@ returnType appres (Scalar (Arrow _ v pd t1 (RetType dims t2))) Observe arg = returnType appres (Scalar (Sum cs)) d arg = Scalar $ Sum $ (fmap . fmap) (\et -> returnType appres et d arg) cs -applyArg :: TypeAliases -> TypeAliases -> TypeAliases -applyArg (Scalar (Arrow closure_als _ d _ (RetType _ rettype))) arg_als = - returnType closure_als rettype d arg_als +applyArg :: TypeAliases -> (AutoMap, TypeAliases) -> TypeAliases +applyArg (Scalar (Arrow closure_als _ d _ (RetType _ rettype))) (am, arg_als) = + if autoMap am /= mempty + then second (const mempty) rettype + else returnType closure_als rettype d arg_als applyArg t _ = error $ "applyArg: " <> show t boundFreeInExp :: Exp -> CheckM (M.Map VName TypeAliases) @@ -664,7 +667,7 @@ checkLoop loop_loc (param, arg, form, body) = do param' <- convergeLoopParam loop_loc param (M.keysSet body_cons) body_als let param_t = patternType param' - ((arg', arg_als), arg_cons) <- contain $ checkArg [] param_t arg + ((arg', arg_als), arg_cons) <- contain $ checkArg [] param_t mempty arg consumed arg_cons free_bound <- boundFreeInExp body @@ -685,7 +688,7 @@ checkLoop loop_loc (param, arg, form, body) = do `setAliases` S.singleton (AliasFree v) pure ( (param', arg', form', body'), - applyArg loopt arg_als `combineAliases` body_als + applyArg loopt (mempty, arg_als) `combineAliases` body_als ) checkFuncall :: @@ -693,7 +696,7 @@ checkFuncall :: SrcLoc -> Maybe (QualName VName) -> TypeAliases -> - f TypeAliases -> + f (AutoMap, TypeAliases) -> CheckM TypeAliases checkFuncall loc fname f_als arg_als = do v <- VName "internal_app_result" <$> incCounter @@ -707,15 +710,16 @@ checkExp :: Exp -> CheckM (Exp, TypeAliases) checkExp (AppExp (Apply f args loc) appres) = do (f', f_als) <- checkExp f (args', args_als) <- NE.unzip <$> checkArgs (toRes Nonunique f_als) args - res_als <- checkFuncall loc (fname f) f_als args_als + res_als <- + checkFuncall loc (fname f) f_als $ + NE.zip (fmap (snd . unInfo . fst) args') args_als pure (AppExp (Apply f' args' loc) appres, res_als) where - -- neUnzip3 xs = ((\(x, _, _) -> x) <$> xs, (\(_, y, _) -> y) <$> xs, (\(_, _, z) -> z) <$> xs) fname (Var v _ _) = Just v fname (AppExp (Apply e _ _) _) = fname e fname _ = Nothing checkArg' prev d (Info (p, am), e) = do - (e', e_als) <- checkArg prev (second (const d) (typeOf e)) e + (e', e_als) <- checkArg prev (second (const d) (typeOf e)) am e pure ((Info (p, am), e'), e_als) checkArgs (Scalar (Arrow _ _ d _ (RetType _ rt))) (x NE.:| args') = do @@ -807,9 +811,9 @@ checkExp (AppExp (LetFun fname (typarams, params, te, Info (RetType ext ret), fu checkExp (AppExp (BinOp (op, oploc) opt (x, xp) (y, yp) loc) appres) = do op_als <- observeVar (locOf oploc) (qualLeaf op) (unInfo opt) let at1 : at2 : _ = fst $ unfoldFunType op_als - (x', x_als) <- checkArg [] at1 x - (y', y_als) <- checkArg [(x', x_als)] at2 y - res_als <- checkFuncall loc (Just op) op_als [x_als, y_als] + (x', x_als) <- checkArg [] at1 mempty x + (y', y_als) <- checkArg [(x', x_als)] at2 mempty y + res_als <- checkFuncall loc (Just op) op_als [(mempty, x_als), (mempty, y_als)] pure ( AppExp (BinOp (op, oploc) opt (x', xp) (y', yp) loc) appres, res_als diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 919ebda406..6510037fb6 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -479,8 +479,8 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (_, rt, p1_ext, _) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext) <- checkApply loc (Just op, 1) rt e2' + (_, rt, p1_ext, _, _) <- checkApply loc (Just op, 0) ftype e1' + (_, rt', p2_ext, retext, _) <- checkApply loc (Just op, 1) rt e2' pure $ AppExp @@ -548,10 +548,10 @@ checkExp (AppExp (Apply fe args loc) _) = do pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts where onArg fname (i, all_exts, t) arg' = do - (_, rt, argext, exts) <- checkApply loc (fname, i) t arg' + (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' pure - ( (i + 1, all_exts <> exts, rt), - (Info (argext, mempty), arg') + ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), + (Info (argext, am), arg') ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e @@ -726,7 +726,7 @@ checkExp (OpSection op (Info op_t) loc) = do checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext) <- checkApply loc (Just op, 0) ftype e' + (t1, rt, argext, retext, _) <- checkApply loc (Just op, 0) ftype e' case (ftype, rt) of (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 rettype)) -> pure $ @@ -745,7 +745,7 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do - (t2', arrow', argext, _) <- + (t2', arrow', argext, _, _) <- checkApply loc (Just op, 1) @@ -923,16 +923,27 @@ dimUses = flip execState mempty . traverseDims f where fv = freeInExp e `freeWithout` bound +-- | Try to find out how many dimensions of the argument we are +-- mapping. Returns the shape mapped and the remaining type. +stripToMatch :: StructType -> StructType -> (Shape Size, StructType) +stripToMatch paramt argt | toStructural paramt == toStructural argt = (mempty, argt) +stripToMatch paramt (Array _ (Shape (d : ds)) argt) = + first (Shape [d] <>) $ stripToMatch paramt $ arrayOf (Shape ds) (Scalar argt) +stripToMatch _ argt = (mempty, argt) + checkApply :: SrcLoc -> ApplyOp -> StructType -> Exp -> - TermTypeM (StructType, StructType, Maybe VName, [VName]) + TermTypeM (StructType, StructType, Maybe VName, [VName], AutoMap) checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do - unify (mkUsage argexp "use as function argument") tp1 argtype + (am_map_shape, argtype_automap) <- + stripToMatch <$> normTypeFully tp1 <*> normTypeFully argtype + + unify (mkUsage argexp "use as function argument") tp1 argtype_automap -- Perform substitutions of instantiated variables in the types. (tp2', ext) <- instantiateDimsInReturnType loc fname =<< normTypeFully tp2 @@ -972,7 +983,14 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do in pure (Nothing, applySubst parsubst $ toStruct tp2') _ -> pure (Nothing, toStruct tp2') - pure (tp1, tp2'', argext, ext) + let am = + AutoMap + { autoMap = am_map_shape, + autoRep = mempty, + autoFrame = am_map_shape + } + + pure (tp1, tp2'', argext, ext, am) checkApply loc fname tfun@(Scalar TypeVar {}) arg = do tv <- newTypeVar loc "b" unify (mkUsage loc "use as function") tfun $ From 8fd5e188ca49f8a1ecad90fd3178d992a23eacff Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 16 Feb 2024 14:40:52 -0800 Subject: [PATCH 063/258] Support AUTOMAP on `BinOp`s. --- src/Futhark/Internalise/FullNormalise.hs | 4 ++-- src/Futhark/Internalise/Monomorphise.hs | 6 +++--- src/Language/Futhark/Interpreter.hs | 2 +- src/Language/Futhark/Prop.hs | 4 ++-- src/Language/Futhark/Syntax.hs | 4 ++-- src/Language/Futhark/TypeChecker/Terms.hs | 18 +++++++++--------- src/Language/Futhark/TypeChecker/Terms2.hs | 4 ++-- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index fea8000abd..91e16a9a53 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -298,7 +298,7 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do While e -> While <$> transformBody e body' <- transformBody body nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT -getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) loc) (Info resT)) = do +getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do expr' <- case (isOr, isAnd) of (True, _) -> do el' <- naming "or_lhs" $ getOrdering True el @@ -311,7 +311,7 @@ getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) lo (False, False) -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er - pure $ mkApply (Var op opT oloc) [(elp, mempty, el'), (erp, mempty, er')] resT + pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT nameExp final expr' where isOr = baseName (qualLeaf op) == "||" diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index e0d9834cfa..ee352f67c6 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -535,7 +535,7 @@ transformAppExp (Loop sparams pat e1 form body loc) res = do (pat_sizes, pat'') <- sizesForPat pat' res' <- transformAppRes res pure $ AppExp (Loop (sparams' ++ pat_sizes) pat'' e1' form' body' loc) (Info res') -transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do +transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, am1)) (e2, Info (d2, am2)) loc) res = do (AppRes ret ext) <- transformAppRes res fname' <- transformFName loc fname (toStruct t) e1' <- transformExp e1 @@ -570,8 +570,8 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do where applyOp ret ext fname' x y = mkApply - (mkApply fname' [(unInfo d1, mempty, x)] (AppRes ret mempty)) - [(unInfo d2, mempty, y)] + (mkApply fname' [(d1, am1, x)] (AppRes ret mempty)) + [(d2, am2, y)] (AppRes ret ext) makeVarParam arg = do diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index ac76cf6645..01e4bddb21 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -828,7 +828,7 @@ evalAppExp env (LetPat sizes p e body _) = do evalAppExp env (LetFun f (tparams, ps, _, Info ret, fbody) body _) = do binding <- evalFunctionBinding env tparams ps ret fbody eval (env {envTerm = M.insert f binding $ envTerm env}) body -evalAppExp env (BinOp (op, _) op_t (x, Info xext) (y, Info yext) loc) +evalAppExp env (BinOp (op, _) op_t (x, Info (xext, xam)) (y, Info (yext, yam)) loc) | baseString (qualLeaf op) == "&&" = do x' <- asBool <$> eval env x if x' diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index cc5c40268a..419e50ba6b 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -682,8 +682,8 @@ mkBinOp op t x y = ( BinOp (qualName (intrinsicVar op), mempty) (Info t) - (x, Info Nothing) - (y, Info Nothing) + (x, Info (Nothing, mempty)) + (y, Info (Nothing, mempty)) mempty ) (Info $ AppRes t []) diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index f0f2a586df..bd2133f017 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -744,8 +744,8 @@ data AppExpBase f vn | BinOp (QualName vn, SrcLoc) (f StructType) - (ExpBase f vn, f (Maybe VName)) - (ExpBase f vn, f (Maybe VName)) + (ExpBase f vn, f (Maybe VName, AutoMap)) + (ExpBase f vn, f (Maybe VName, AutoMap)) SrcLoc | LetWith (IdentBase f vn StructType) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 6510037fb6..967919c52d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -170,8 +170,8 @@ sliceShape r slice t@(Array u (Shape orig_dims) et) = ( BinOp (qualName (intrinsicVar "-"), mempty) sizeBinOpInfo - (j, Info Nothing) - (i, Info Nothing) + (j, Info (Nothing, mempty)) + (i, Info (Nothing, mempty)) mempty ) $ Info @@ -454,8 +454,8 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do ( BinOp (qualName (intrinsicVar op), mempty) sizeBinOpInfo - (x, Info Nothing) - (y, Info Nothing) + (x, Info (Nothing, mempty)) + (y, Info (Nothing, mempty)) mempty ) (Info $ AppRes t []) @@ -479,16 +479,16 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (_, rt, p1_ext, _, _) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext, _) <- checkApply loc (Just op, 1) rt e2' + (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' + (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) rt e2' pure $ AppExp ( BinOp (op, oploc) (Info ftype) - (e1', Info p1_ext) - (e2', Info p2_ext) + (e1', Info (p1_ext, am1)) + (e2', Info (p2_ext, am2)) loc ) (Info (AppRes rt' retext)) @@ -1143,7 +1143,7 @@ causalityCheck binding_body = do modify (new_known <>) onExp known - e@(AppExp (BinOp (f, floc) ft (x, Info xp) (y, Info yp) _) (Info res)) = do + e@(AppExp (BinOp (f, floc) ft (x, Info (xp, _)) (y, Info (yp, _)) _) (Info res)) = do args_known <- collectingNewKnown $ sequencePoint known x y $ catMaybes [xp, yp] void $ onExp (args_known <> known) (Var f ft floc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index f2f7f8f327..71318292ef 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -835,12 +835,12 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e2' <- checkExp e2 (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) mempty e1' - (rt2, am2) <- checkApply loc (Just op, 1) rt1 mempty e2' + (rt2, am2) <- checkApply loc (Just op, 1) rt1 (autoFrame am1) e2' rt2' <- asStructType loc rt2 pure $ AppExp - (BinOp (op, oploc) (Info ftype) (e1', Info Nothing) (e2', Info Nothing) loc) + (BinOp (op, oploc) (Info ftype) (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) (Info (AppRes rt2' [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do From 32c9a73f7d62b277e8f21a4b88f1f6447051f92e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 14:07:33 +0100 Subject: [PATCH 064/258] Some work on overloaded type variables. --- .../Futhark/TypeChecker/Constraints.hs | 8 ++-- src/Language/Futhark/TypeChecker/Terms.hs | 1 + src/Language/Futhark/TypeChecker/Terms2.hs | 45 ++++++++++++++++--- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 665e67af3f..2548c74008 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -118,18 +118,16 @@ substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt -- | A solution maps a type variable to its substitution. This substitution is complete, in the sense there are no right-hand sides that contain a type variable. -type Solution = M.Map TyVar (TypeBase () NoUniqueness) +type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) solution :: SolverState -> Solution solution s = M.mapMaybe mkSubst $ solverTyVars s where - mkSubst (TyVarSol _lvl t) = Just $ first (const ()) $ substTyVars (solverTyVars s) t + mkSubst (TyVarSol _lvl t) = Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t mkSubst (TyVarLink v') = mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (TyVarUnsol _ (TyVarPrim pts)) - | Signed Int32 `elem` pts = - Just (Scalar (Prim (Signed Int32))) -- XXX - we need warnings and things! + mkSubst (TyVarUnsol _ (TyVarPrim pts)) = Just $ Left pts mkSubst _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 967919c52d..4e0ea57060 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1666,6 +1666,7 @@ checkFunDef :: checkFunDef (fname, retdecl, tparams, params, body, loc) = do (maybe_tysubsts, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right tysubsts -> runTermTypeM checkExp tysubsts $ do diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 71318292ef..52933dd1a3 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -49,8 +49,10 @@ import Data.Bifunctor import Data.Char (isAscii) import Data.List qualified as L import Data.List.NonEmpty qualified as NE +import Data.Loc (Loc (NoLoc)) import Data.Map qualified as M import Data.Maybe +import Data.Set qualified as S import Data.Text qualified as T import Debug.Trace import Futhark.FreshNames qualified as FreshNames @@ -63,7 +65,7 @@ import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Rank import Language.Futhark.TypeChecker.Types -import Language.Futhark.TypeChecker.Unify (Level) +import Language.Futhark.TypeChecker.Unify (Level, mkUsage) import Prelude hiding (mod) data Inferred t @@ -430,7 +432,10 @@ patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v checkSizeExp :: ExpBase NoInfo VName -> TermM Exp -checkSizeExp = require "use as size" [Signed Int64] <=< checkExp +checkSizeExp e = do + e' <- checkExp e + ctEq (expType e') (Scalar (Prim (Signed Int64))) + pure e' checkPat' :: PatBase NoInfo VName ParamType -> @@ -1057,6 +1062,30 @@ checkExp (Coerce e te NoInfo loc) = do ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc +doDefaults :: + S.Set VName -> + VName -> + Either [PrimType] (TypeBase () NoUniqueness) -> + TermM (TypeBase () NoUniqueness) +doDefaults tyvars_at_toplevel v (Left pts) + | Signed Int32 `elem` pts = do + when (v `S.member` tyvars_at_toplevel) $ + warn usage "Defaulting ambiguous type to i32." + pure $ Scalar $ Prim $ Signed Int32 + | FloatType Float64 `elem` pts = do + when (v `S.member` tyvars_at_toplevel) $ + warn usage "Defaulting ambiguous type to f64." + pure $ Scalar $ Prim $ FloatType Float64 + | otherwise = + typeError usage mempty . withIndexLink "ambiguous-type" $ + "Type is ambiguous (could be one of" + <+> commasep (map pretty pts) + <> ")." + "Add a type annotation to disambiguate the type." + where + usage = mkUsage NoLoc "overload" +doDefaults _ _ (Right t) = pure t + checkValDef :: ( VName, Maybe (TypeExp (ExpBase NoInfo VName) VName), @@ -1066,7 +1095,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text Solution, + ( Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1099,7 +1128,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} - let solution = solve cts' tyvars' + + solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' traceM $ unlines @@ -1111,12 +1141,15 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do let p (v, t) = prettyNameString v <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] + pure (solution, params', retdecl', body') -checkSingleExp :: ExpBase NoInfo VName -> TypeM (Either T.Text Solution, Exp) +checkSingleExp :: + ExpBase NoInfo VName -> + TypeM (Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - let solution = solve cts tyvars + solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts tyvars pure (solution, e') From 5f45f198ff8819d425388260652c7f5eceff8548 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 14:14:40 +0100 Subject: [PATCH 065/258] Tear out some organs we probably do not need anymore. --- src/Language/Futhark/TypeChecker/Terms.hs | 25 +----- .../Futhark/TypeChecker/Terms/Loop.hs | 4 +- .../Futhark/TypeChecker/Terms/Monad.hs | 27 +----- src/Language/Futhark/TypeChecker/Unify.hs | 84 ------------------- 4 files changed, 9 insertions(+), 131 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 4e0ea57060..58e1b009c3 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -390,7 +390,7 @@ checkExp (ArrayLit all_es _ loc) = t <- arrayOfM loc et (Shape [sizeFromInteger (genericLength all_es) mempty]) pure $ ArrayLit (e' : es') (Info t) loc checkExp (AppExp (Range start maybe_step end loc) _) = do - start' <- require "use in range expression" anySignedType =<< checkExp start + start' <- checkExp start start_t <- expType start' maybe_step' <- case maybe_step of Nothing -> pure Nothing @@ -530,10 +530,10 @@ checkExp (Var qn (Info t) loc) = do t' <- lookupVar loc qn t pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do - arg' <- require "numeric negation" anyNumberType =<< checkExp arg + arg' <- checkExp arg pure $ Negate arg' loc checkExp (Not arg loc) = do - arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg + arg' <- checkExp arg pure $ Not arg' loc checkExp (AppExp (Apply fe args loc) _) = do fe' <- checkExp fe @@ -662,7 +662,7 @@ checkExp (AppExp (Index e slice loc) _) = do pure $ AppExp (Index e' slice' loc) (Info $ AppRes t' retext) checkExp (Assert e1 e2 _ loc) = do - e1' <- require "being asserted" [Bool] =<< checkExp e1 + e1' <- checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do @@ -1286,23 +1286,6 @@ fixOverloadedTypes :: Names -> TermTypeM () fixOverloadedTypes tyvars_at_toplevel = getConstraints >>= mapM_ fixOverloaded . M.toList . M.map snd where - fixOverloaded (v, Overloaded ots usage) - | Signed Int32 `elem` ots = do - unify usage (Scalar (TypeVar mempty (qualName v) [])) $ - Scalar (Prim $ Signed Int32) - when (v `S.member` tyvars_at_toplevel) $ - warn usage "Defaulting ambiguous type to i32." - | FloatType Float64 `elem` ots = do - unify usage (Scalar (TypeVar mempty (qualName v) [])) $ - Scalar (Prim $ FloatType Float64) - when (v `S.member` tyvars_at_toplevel) $ - warn usage "Defaulting ambiguous type to f64." - | otherwise = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (could be one of" - <+> commasep (map pretty ots) - <> ")." - "Add a type annotation to disambiguate the type." fixOverloaded (v, NoConstraint _ usage) = do -- See #1552. unify usage (Scalar (TypeVar mempty (qualName v) [])) $ diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index 349b105823..7cba8af7e8 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -220,9 +220,7 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do (sparams, mergepat', form', loopbody') <- case form of For i uboundexp -> do - uboundexp' <- - require "being the bound in a 'for' loop" anySignedType - =<< checkExp uboundexp + uboundexp' <- checkExp uboundexp bindingIdent i . bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do loopbody' <- checkExp loopbody diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 87697bc359..8bdbb81daf 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -31,7 +31,6 @@ module Language.Futhark.TypeChecker.Terms.Monad -- * Primitive checking unifies, - require, checkTypeExpNonrigid, lookupVar, lookupMod, @@ -537,21 +536,10 @@ lookupVar loc qn@(QualName qs name) inst_t = do (tnames, t) <- instTypeScheme qn loc tparams bound_t $ first (const ()) inst_t outer_env <- asks termOuterEnv pure $ qualifyTypeVars outer_env tnames qs t - Just EqualityF -> do - argtype <- newTypeVar loc "t" - equalityType usage argtype - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeVar loc "t" - mustBeOneOf ts usage argtype - let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + Just EqualityF -> + replaceTyVars loc inst_t + Just OverloadedF {} -> + replaceTyVars loc inst_t where instOverloaded argtype pts rt = ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, @@ -639,13 +627,6 @@ unifies why t e = do unify (mkUsage (srclocOf e) why) t . toStruct =<< expType e pure e --- | @require ts e@ causes a 'TypeError' if @expType e@ is not one of --- the types in @ts@. Otherwise, simply returns @e@. -require :: T.Text -> [PrimType] -> Exp -> TermTypeM Exp -require why ts e = do - mustBeOneOf ts (mkUsage (srclocOf e) why) . toStruct =<< expType e - pure e - checkExpForSize :: ExpBase Info VName -> TermTypeM Exp checkExpForSize e = do checker <- asks termCheckExp diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 1d26af8354..8e1c414e6c 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -18,7 +18,6 @@ module Language.Futhark.TypeChecker.Unify arrayElemType, mustHaveConstr, mustHaveField, - mustBeOneOf, equalityType, normType, normTypeFully, @@ -119,7 +118,6 @@ data Constraint = NoConstraint Liftedness Usage | ParamType Liftedness Loc | Constraint StructRetType Usage - | Overloaded [PrimType] Usage | HasFields Liftedness (M.Map Name StructType) Usage | Equality Usage | HasConstrs Liftedness (M.Map Name [StructType]) Usage @@ -138,7 +136,6 @@ instance Located Constraint where locOf (NoConstraint _ usage) = locOf usage locOf (ParamType _ usage) = locOf usage locOf (Constraint _ usage) = locOf usage - locOf (Overloaded _ usage) = locOf usage locOf (HasFields _ _ usage) = locOf usage locOf (Equality usage) = locOf usage locOf (HasConstrs _ _ usage) = locOf usage @@ -282,8 +279,6 @@ typeVarNotes v = maybe mempty (note . snd) . M.lookup v <$> getConstraints <+> "=" <+> hsep (map ppConstr (M.toList cs)) <+> "..." - note (Overloaded ts _) = - aNote $ prettyName v <+> "must be one of" <+> mconcat (punctuate ", " (map pretty ts)) note (HasFields _ fs _) = aNote $ prettyName v @@ -685,26 +680,6 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do Just (Equality _) -> do link equalityType usage tp - Just (Overloaded ts old_usage) - | tp `notElem` map (Scalar . Prim) ts -> do - link - case tp of - Scalar (TypeVar _ (QualName [] v) []) - | not $ isRigid v constraints -> - linkVarToTypes usage v ts - _ -> - unifyError usage mempty bcs $ - "Cannot instantiate" - <+> dquotes (prettyName vn) - <+> "with type" - indent 2 (pretty tp) - "as" - <+> dquotes (prettyName vn) - <+> "must be one of" - <+> commasep (map pretty ts) - "due to" - <+> pretty old_usage - <> "." Just (HasFields l required_fields old_usage) -> do when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp case tp of @@ -846,63 +821,6 @@ linkVarToDim usage bcs vn lvl e = do _ -> modifyConstraints $ M.insert dim' (lvl, c) checkVar _ _ = pure () --- | Assert that this type must be one of the given primitive types. -mustBeOneOf :: (MonadUnify m) => [PrimType] -> Usage -> StructType -> m () -mustBeOneOf [req_t] usage t = unify usage (Scalar (Prim req_t)) t -mustBeOneOf ts usage t = do - t' <- normType t - constraints <- getConstraints - let isRigid' v = isRigid v constraints - - case t' of - Scalar (TypeVar _ (QualName [] v) []) - | not $ isRigid' v -> linkVarToTypes usage v ts - Scalar (Prim pt) | pt `elem` ts -> pure () - _ -> failure - where - failure = - unifyError usage mempty noBreadCrumbs $ - "Cannot unify type" - <+> dquotes (pretty t) - <+> "with any of " - <> commasep (map pretty ts) - <> "." - -linkVarToTypes :: (MonadUnify m) => Usage -> VName -> [PrimType] -> m () -linkVarToTypes usage vn ts = do - vn_constraint <- M.lookup vn <$> getConstraints - case vn_constraint of - Just (lvl, Overloaded vn_ts vn_usage) -> - case ts `intersect` vn_ts of - [] -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <+> "but also one of" - <+> commasep (map pretty vn_ts) - <+> "due to" - <+> pretty vn_usage - <> "." - ts' -> modifyConstraints $ M.insert vn (lvl, Overloaded ts' usage) - Just (_, HasConstrs _ _ vn_usage) -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <> ", but also inferred to be sum type due to" - <+> pretty vn_usage - <> "." - Just (_, HasFields _ _ vn_usage) -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <> ", but also inferred to be record due to" - <+> pretty vn_usage - <> "." - Just (lvl, _) -> modifyConstraints $ M.insert vn (lvl, Overloaded ts usage) - Nothing -> - unifyError usage mempty noBreadCrumbs $ - "Cannot constrain type to one of" <+> commasep (map pretty ts) - -- | Assert that this type must support equality. equalityType :: (MonadUnify m, Pretty (Shape dim), Pretty u) => @@ -932,8 +850,6 @@ equalityType usage t = do | otherwise -> pure () Just (lvl, NoConstraint _ _) -> modifyConstraints $ M.insert vn (lvl, Equality usage) - Just (_, Overloaded _ _) -> - pure () -- All primtypes support equality. Just (_, Equality {}) -> pure () _ -> From ae1f529e544854f30d090360bb4647ed1b930061 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 14:43:06 +0100 Subject: [PATCH 066/258] Remove more guts. --- src/Language/Futhark/TypeChecker/Terms.hs | 49 +--- src/Language/Futhark/TypeChecker/Unify.hs | 310 +--------------------- 2 files changed, 19 insertions(+), 340 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 58e1b009c3..08c7585de6 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -56,12 +56,6 @@ hasBinding e = isNothing $ astMap m e m = identityMapper {mapOnExp = \e' -> if hasBinding e' then Nothing else Just e'} -overloadedTypeVars :: Constraints -> Names -overloadedTypeVars = mconcat . map f . M.elems - where - f (_, HasFields _ fs _) = mconcat $ map typeVars $ M.elems fs - f _ = mempty - --- Basic checking -- | Determine if the two types are identical, ignoring uniqueness. @@ -495,8 +489,11 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' - kt <- mustHaveField (mkUsage loc $ docText $ "projection of field " <> dquotes (pretty k)) k t - pure $ Project k e' (Info kt) loc + case t of + Scalar (Record fs) + | Just kt <- M.lookup k fs -> + pure $ Project k e' (Info kt) loc + _ -> error $ "checkExp Project: " <> show t checkExp (AppExp (If e1 e2 e3 loc) _) = do e1' <- checkExp e1 e2' <- checkExp e2 @@ -765,12 +762,9 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (ProjectSection fields _ loc) = do - a <- newTypeVar loc "a" - let usage = mkUsage loc "projection at" - b <- foldM (flip $ mustHaveField usage) a fields - let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ toRes Nonunique b - pure $ ProjectSection fields (Info ft) loc +checkExp (ProjectSection fields (Info t) loc) = do + t' <- replaceTyVars loc t + pure $ ProjectSection fields (Info t') loc checkExp (IndexSection slice _ loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' @@ -1292,23 +1286,6 @@ fixOverloadedTypes tyvars_at_toplevel = Scalar (tupleRecord []) when (v `S.member` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to ()." - fixOverloaded (_, Equality usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (must be equality type)." - "Add a type annotation to disambiguate the type." - fixOverloaded (_, HasFields _ fs usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous. Must be record with fields:" - indent 2 (stack $ map field $ M.toList fs) - "Add a type annotation to disambiguate the type." - where - field (l, t) = pretty l <> colon <+> align (pretty t) - fixOverloaded (_, HasConstrs _ cs usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (must be a sum type with constructors:" - <+> pretty (Sum cs) - <> ")." - "Add a type annotation to disambiguate the type." fixOverloaded (v, Size Nothing (Usage Nothing loc)) = typeError loc mempty . withIndexLink "ambiguous-size" $ "Ambiguous size" <+> dquotes (prettyName v) <> "." @@ -1552,18 +1529,12 @@ letGeneralise defname defloc tparams params restype = -- -- (2) are not used in the (new) definition of any type variables -- known before we checked this function. - -- - -- (3) are not referenced from an overloaded type (for example, - -- are the element types of an incompletely resolved record type). - -- This is a bit more restrictive than I'd like, and SML for - -- example does not have this restriction. - -- + -- Criteria (1) and (2) is implemented by looking at the binding -- level of the type variables. - let keep_type_vars = overloadedTypeVars now_substs cur_lvl <- curLevel - let candidate k (lvl, _) = (k `S.notMember` keep_type_vars) && lvl >= (cur_lvl - length params) + let candidate k (lvl, _) = lvl >= (cur_lvl - length params) new_substs = M.filterWithKey candidate now_substs (tparams', RetType ret_dims restype') <- diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 8e1c414e6c..4493b02b2d 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -16,9 +16,6 @@ module Language.Futhark.TypeChecker.Unify dimNotes, zeroOrderType, arrayElemType, - mustHaveConstr, - mustHaveField, - equalityType, normType, normTypeFully, unify, @@ -30,7 +27,7 @@ where import Control.Monad import Control.Monad.Except import Control.Monad.State -import Data.List (foldl', intersect) +import Data.List (foldl') import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S @@ -118,9 +115,6 @@ data Constraint = NoConstraint Liftedness Usage | ParamType Liftedness Loc | Constraint StructRetType Usage - | HasFields Liftedness (M.Map Name StructType) Usage - | Equality Usage - | HasConstrs Liftedness (M.Map Name [StructType]) Usage | ParamSize Loc | -- | Is not actually a type, but a term-level size, -- possibly already set to something specific. @@ -136,9 +130,6 @@ instance Located Constraint where locOf (NoConstraint _ usage) = locOf usage locOf (ParamType _ usage) = locOf usage locOf (Constraint _ usage) = locOf usage - locOf (HasFields _ _ usage) = locOf usage - locOf (Equality usage) = locOf usage - locOf (HasConstrs _ _ usage) = locOf usage locOf (ParamSize loc) = locOf loc locOf (Size _ usage) = locOf usage locOf (UnknownSize loc _) = locOf loc @@ -270,25 +261,6 @@ typeNotes ctx = . fvVars . freeInType -typeVarNotes :: (MonadUnify m) => VName -> m Notes -typeVarNotes v = maybe mempty (note . snd) . M.lookup v <$> getConstraints - where - note (HasConstrs _ cs _) = - aNote $ - prettyName v - <+> "=" - <+> hsep (map ppConstr (M.toList cs)) - <+> "..." - note (HasFields _ fs _) = - aNote $ - prettyName v - <+> "=" - <+> braces (mconcat (punctuate ", " (map ppField (M.toList fs)))) - note _ = mempty - - ppConstr (c, _) = "#" <> pretty c <+> "..." <+> "|" - ppField (f, _) = prettyName f <> ":" <+> "..." - -- | Monads that which to perform unification must implement this type -- class. class (Monad m) => MonadUnify m where @@ -354,12 +326,6 @@ unsharedConstructorsMsg cs1 cs2 = filter (`notElem` M.keys cs1) (M.keys cs2) ++ filter (`notElem` M.keys cs2) (M.keys cs1) --- | Is the given type variable the name of an abstract type or type --- parameter, which we cannot substitute? -isRigid :: VName -> Constraints -> Bool -isRigid v constraints = - maybe True (rigidConstraint . snd) $ M.lookup v constraints - -- | If the given type variable is nonrigid, what is its level? isNonRigid :: VName -> Constraints -> Maybe Level isNonRigid v constraints = do @@ -370,10 +336,6 @@ isNonRigid v constraints = do type UnifySizes m = BreadCrumbs -> [VName] -> (VName -> Maybe Int) -> Exp -> Exp -> m () -flipUnifySizes :: UnifySizes m -> UnifySizes m -flipUnifySizes onDims bcs bound nonrigid t1 t2 = - onDims bcs bound nonrigid t2 t1 - unifyWith :: (MonadUnify m) => UnifySizes m -> @@ -398,14 +360,7 @@ unifyWith onDims usage = subunify False failure = matchError (srclocOf usage) mempty bcs t1' t2' - link ord' = - linkVarToType linkDims usage bound bcs - where - -- We may have to flip the order of future calls to - -- onDims inside linkVarToType. - linkDims - | ord' = flipUnifySizes onDims - | otherwise = onDims + link = linkVarToType usage bound bcs unifyTypeArg bcs' (TypeArgDim d1) (TypeArgDim d2) = onDims' bcs' (swap ord d1 d2) @@ -452,17 +407,17 @@ unifyWith onDims usage = subunify False ) -> case (nonrigid v1, nonrigid v2) of (Nothing, Nothing) -> failure - (Just lvl1, Nothing) -> link ord v1 lvl1 t2' - (Nothing, Just lvl2) -> link (not ord) v2 lvl2 t1' + (Just lvl1, Nothing) -> link v1 lvl1 t2' + (Nothing, Just lvl2) -> link v2 lvl2 t1' (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> link ord v1 lvl1 t2' - | otherwise -> link (not ord) v2 lvl2 t1' + | lvl1 <= lvl2 -> link v1 lvl1 t2' + | otherwise -> link v2 lvl2 t1' (Scalar (TypeVar _ (QualName [] v1) []), _) | Just lvl <- nonrigid v1 -> - link ord v1 lvl t2' + link v1 lvl t2' (_, Scalar (TypeVar _ (QualName [] v2) [])) | Just lvl <- nonrigid v2 -> - link (not ord) v2 lvl t1' + link v2 lvl t1' ( Scalar (Arrow _ p1 d1 a1 (RetType b1_dims b1)), Scalar (Arrow _ p2 d2 a2 (RetType b2_dims b2)) ) @@ -625,7 +580,6 @@ scopeCheck usage bcs vn max_lvl tp = do linkVarToType :: (MonadUnify m) => - UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -633,7 +587,7 @@ linkVarToType :: Level -> StructType -> m () -linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do +linkVarToType usage bound bcs vn lvl tp_unnorm = do -- We have to expand anyway for the occurs check, so we might as -- well link the fully expanded type. tp <- normTypeFully tp_unnorm @@ -677,105 +631,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do <+> "cannot be instantiated with type containing anonymous sizes:" indent 2 (pretty tp) textwrap "This is usually because the size of an array returned by a higher-order function argument cannot be determined statically. This can also be due to the return size being a value parameter. Add type annotation to clarify." - Just (Equality _) -> do - link - equalityType usage tp - Just (HasFields l required_fields old_usage) -> do - when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of - Scalar (Record tp_fields) - | all (`M.member` tp_fields) $ M.keys required_fields -> do - required_fields' <- mapM normTypeFully required_fields - let tp' = Scalar $ Record $ required_fields <> tp_fields -- Crucially left-biased. - ext = filter (`S.member` fvVars (freeInType tp')) bound - modifyConstraints $ - M.insert vn (lvl, Constraint (RetType ext tp') usage) - unifySharedFields onDims usage bound bcs required_fields' tp_fields - Scalar (TypeVar _ (QualName [] v) []) -> do - case M.lookup v constraints of - Just (_, HasFields _ tp_fields _) -> - unifySharedFields onDims usage bound bcs required_fields tp_fields - Just (_, NoConstraint {}) -> pure () - Just (_, Equality {}) -> pure () - _ -> do - notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v - noRecordType notes - link - modifyConstraints $ - M.insertWith - combineFields - v - (lvl, HasFields l required_fields old_usage) - where - combineFields (_, HasFields l1 fs1 usage1) (_, HasFields l2 fs2 _) = - (lvl, HasFields (l1 `min` l2) (M.union fs1 fs2) usage1) - combineFields hasfs _ = hasfs - _ -> - unifyError usage mempty bcs $ - "Cannot instantiate" - <+> dquotes (prettyName vn) - <+> "with type" - indent 2 (pretty tp) - "as" - <+> dquotes (prettyName vn) - <+> "must be a record with fields" - indent 2 (pretty (Record required_fields)) - "due to" - <+> pretty old_usage - <> "." - -- See Note [Linking variables to sum types] - Just (HasConstrs l required_cs old_usage) -> do - when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of - Scalar (Sum ts) - | all (`M.member` ts) $ M.keys required_cs -> do - let tp' = Scalar $ Sum $ required_cs <> ts -- Crucially left-biased. - ext = filter (`S.member` fvVars (freeInType tp')) bound - modifyConstraints $ - M.insert vn (lvl, Constraint (RetType ext tp') usage) - unifySharedConstructors onDims usage bound bcs required_cs ts - | otherwise -> - unsharedConstructors required_cs ts =<< typeVarNotes vn - Scalar (TypeVar _ (QualName [] v) []) -> do - case M.lookup v constraints of - Just (_, HasConstrs _ v_cs _) -> - unifySharedConstructors onDims usage bound bcs required_cs v_cs - Just (_, NoConstraint {}) -> pure () - Just (_, Equality {}) -> pure () - _ -> do - notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v - noSumType notes - link - modifyConstraints $ - M.insertWith - combineConstrs - v - (lvl, HasConstrs l required_cs old_usage) - where - combineConstrs (_, HasConstrs l1 cs1 usage1) (_, HasConstrs l2 cs2 _) = - (lvl, HasConstrs (l1 `min` l2) (M.union cs1 cs2) usage1) - combineConstrs hasCs _ = hasCs - _ -> noSumType =<< typeVarNotes vn _ -> link - where - unsharedConstructors cs1 cs2 notes = - unifyError - usage - notes - bcs - (unsharedConstructorsMsg cs1 cs2) - noSumType notes = - unifyError - usage - notes - bcs - "Cannot unify a sum type with a non-sum type." - noRecordType notes = - unifyError - usage - notes - bcs - "Cannot unify a record type with a non-record type." linkVarToDim :: (MonadUnify m) => @@ -821,41 +677,6 @@ linkVarToDim usage bcs vn lvl e = do _ -> modifyConstraints $ M.insert dim' (lvl, c) checkVar _ _ = pure () --- | Assert that this type must support equality. -equalityType :: - (MonadUnify m, Pretty (Shape dim), Pretty u) => - Usage -> - TypeBase dim u -> - m () -equalityType usage t = do - unless (orderZero t) $ - unifyError usage mempty noBreadCrumbs $ - "Type " <+> dquotes (pretty t) <+> "does not support equality (may contain function)." - mapM_ mustBeEquality $ typeVars t - where - mustBeEquality vn = do - constraints <- getConstraints - case M.lookup vn constraints of - Just (_, Constraint (RetType [] (Scalar (TypeVar _ (QualName [] vn') []))) _) -> - mustBeEquality vn' - Just (_, Constraint (RetType _ vn_t) cusage) - | not $ orderZero vn_t -> - unifyError usage mempty noBreadCrumbs $ - "Type" - <+> dquotes (pretty t) - <+> "does not support equality." - "Constrained to be higher-order due to" - <+> pretty cusage - <+> "." - | otherwise -> pure () - Just (lvl, NoConstraint _ _) -> - modifyConstraints $ M.insert vn (lvl, Equality usage) - Just (_, Equality {}) -> - pure () - _ -> - unifyError usage mempty noBreadCrumbs $ - "Type" <+> prettyName vn <+> "does not support equality." - zeroOrderTypeWith :: (MonadUnify m) => Usage -> @@ -873,10 +694,6 @@ zeroOrderTypeWith usage bcs t = do case M.lookup vn constraints of Just (lvl, NoConstraint _ _) -> modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage) - Just (lvl, HasFields _ fs _) -> - modifyConstraints $ M.insert vn (lvl, HasFields Unlifted fs usage) - Just (lvl, HasConstrs _ cs _) -> - modifyConstraints $ M.insert vn (lvl, HasConstrs Unlifted cs usage) Just (_, ParamType Lifted ploc) -> unifyError usage mempty bcs $ "Type parameter" @@ -967,96 +784,6 @@ unifySharedConstructors onDims usage bound bcs cs1 cs2 = unifyError usage mempty bcs $ "Cannot unify constructor" <+> dquotes (prettyName c) <> "." --- | In @mustHaveConstr usage c t fs@, the type @t@ must have a --- constructor named @c@ that takes arguments of types @ts@. -mustHaveConstr :: - (MonadUnify m) => - Usage -> - Name -> - StructType -> - [StructType] -> - m () -mustHaveConstr usage c t fs = do - constraints <- getConstraints - case t of - Scalar (TypeVar _ (QualName _ tn) []) - | Just (lvl, NoConstraint l _) <- M.lookup tn constraints -> do - mapM_ (scopeCheck usage noBreadCrumbs tn lvl) fs - modifyConstraints $ M.insert tn (lvl, HasConstrs l (M.singleton c fs) usage) - | Just (lvl, HasConstrs l cs _) <- M.lookup tn constraints -> - case M.lookup c cs of - Nothing -> - modifyConstraints $ - M.insert tn (lvl, HasConstrs l (M.insert c fs cs) usage) - Just fs' - | length fs == length fs' -> zipWithM_ (unify usage) fs fs' - | otherwise -> - unifyError usage mempty noBreadCrumbs $ - "Different arity for constructor" <+> dquotes (pretty c) <> "." - Scalar (Sum cs) -> - case M.lookup c cs of - Nothing -> - unifyError usage mempty noBreadCrumbs $ - "Constuctor" <+> dquotes (pretty c) <+> "not present in type." - Just fs' - | length fs == length fs' -> zipWithM_ (unify usage) fs fs' - | otherwise -> - unifyError usage mempty noBreadCrumbs $ - "Different arity for constructor" <+> dquotes (pretty c) <+> "." - _ -> - unify usage t $ Scalar $ Sum $ M.singleton c fs - -mustHaveFieldWith :: - (MonadUnify m) => - UnifySizes m -> - Usage -> - [VName] -> - BreadCrumbs -> - Name -> - StructType -> - m StructType -mustHaveFieldWith onDims usage bound bcs l t = do - constraints <- getConstraints - l_type <- newTypeVar (locOf usage) "t" - case t of - Scalar (TypeVar _ (QualName _ tn) []) - | Just (lvl, NoConstraint {}) <- M.lookup tn constraints -> do - scopeCheck usage bcs tn lvl l_type - modifyConstraints $ M.insert tn (lvl, HasFields Lifted (M.singleton l l_type) usage) - pure l_type - | Just (lvl, HasFields lifted fields _) <- M.lookup tn constraints -> do - case M.lookup l fields of - Just t' -> unifyWith onDims usage bound bcs l_type t' - Nothing -> - modifyConstraints $ - M.insert - tn - (lvl, HasFields lifted (M.insert l l_type fields) usage) - pure l_type - Scalar (Record fields) - | Just t' <- M.lookup l fields -> do - unify usage l_type t' - pure t' - | otherwise -> - unifyError usage mempty bcs $ - "Attempt to access field" - <+> dquotes (pretty l) - <+> " of value of type" - <+> pretty (toStructural t) - <> "." - _ -> do - unify usage t $ Scalar $ Record $ M.singleton l l_type - pure l_type - --- | Assert that some type must have a field with this name and type. -mustHaveField :: - (MonadUnify m) => - Usage -> - Name -> - StructType -> - m StructType -mustHaveField usage = mustHaveFieldWith (unifySizes usage) usage mempty noBreadCrumbs - newDimOnMismatch :: (MonadUnify m) => Loc -> @@ -1180,22 +907,3 @@ doUnification loc rigid_tparams nonrigid_tparams t1 t2 = runUnifyM rigid_tparams nonrigid_tparams $ do unify (Usage Nothing (locOf loc)) t1 t2 normTypeFully t2 - --- Note [Linking variables to sum types] --- --- Consider the case when unifying a result type --- --- i32 -> ?[n].(#foo [n]bool) --- --- with --- --- i32 -> ?[k].a --- --- where 'a' has a HasConstrs constraint saying that it must have at --- least a constructor of type '#foo [0]bool'. --- --- This unification should succeed, but we must not merely link 'a' to --- '#foo [n]bool', as 'n' is not free. Instead we should instantiate --- 'a' to be a concrete sum type (because now we know exactly which --- constructor labels it must have), and unify each of its constructor --- payloads with the corresponding expected payload. From 0f6156a2dc914c94c53215f8517e8928ccf3c619 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 18:15:38 +0100 Subject: [PATCH 067/258] Must also unify here. --- src/Language/Futhark/TypeChecker/Terms.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 08c7585de6..28e2e9f2e5 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -676,6 +676,8 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' + unify (mkUsage body "inferred return type") (toStruct rt') body_t + params'' <- mapM updateTypes params' (rettype', rettype_st) <- @@ -1534,8 +1536,8 @@ letGeneralise defname defloc tparams params restype = -- level of the type variables. cur_lvl <- curLevel - let candidate k (lvl, _) = lvl >= (cur_lvl - length params) - new_substs = M.filterWithKey candidate now_substs + let candidate (lvl, _) = lvl >= (cur_lvl - length params) + new_substs = M.filter candidate now_substs (tparams', RetType ret_dims restype') <- closeOverTypes From 9bf72c86dbfb95bfb9106d2b9e508b9702895de6 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 18:50:48 +0100 Subject: [PATCH 068/258] AUTOMAP does not work yet. --- prelude/zip.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prelude/zip.fut b/prelude/zip.fut index 18361e545f..1171820307 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -11,7 +11,7 @@ -- depended upon by soacs.fut. So we just define a quick-and-dirty -- internal one here that uses the intrinsic version. local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - f as + intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = From 7d3bcc7e46942e08f1343f2590576021179baf4d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 17 Feb 2024 21:11:31 -0800 Subject: [PATCH 069/258] Defunctionalization and internalization AUTOMAP progress. --- src/Futhark/IR/Syntax/Core.hs | 6 ++ src/Futhark/Internalise/Defunctionalise.hs | 64 +++++++++-- src/Futhark/Internalise/Exps.hs | 119 +++++++++++++++++++-- src/Language/Futhark/Prop.hs | 14 +++ src/Language/Futhark/TypeChecker/Rank.hs | 4 +- src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++--- 6 files changed, 204 insertions(+), 35 deletions(-) diff --git a/src/Futhark/IR/Syntax/Core.hs b/src/Futhark/IR/Syntax/Core.hs index 227c25b23b..982fadcdec 100644 --- a/src/Futhark/IR/Syntax/Core.hs +++ b/src/Futhark/IR/Syntax/Core.hs @@ -15,6 +15,7 @@ module Futhark.IR.Syntax.Core ShapeBase (..), Shape, stripDims, + takeDims, Ext (..), ExtSize, ExtShape, @@ -128,6 +129,11 @@ instance Monoid (ShapeBase d) where stripDims :: Int -> ShapeBase d -> ShapeBase d stripDims n (Shape dims) = Shape $ drop n dims +-- | @takeDims n shape@ takes the outer @n@ dimensions from +-- @shape@. If @shape@ has m <= n dimensions, it returns $shape$. +takeDims :: Int -> ShapeBase d -> ShapeBase d +takeDims n (Shape dims) = Shape $ take n dims + -- | The size of an array as a list of subexpressions. If a variable, -- that variable must be in scope where this array is used. type Shape = ShapeBase SubExp diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 82cc845d69..d5894bec66 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -14,6 +14,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Debug.Trace import Futhark.IR.Pretty () import Futhark.MonadFreshNames import Futhark.Util (mapAccumLM, nubOrd) @@ -905,7 +906,7 @@ defuncApplyArg :: (Exp, StaticVal) -> (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) -defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do +defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, am), arg), _) = do (arg', arg_sv) <- defuncExp arg let env' = alwaysMatchPatSV pat arg_sv dims = mempty @@ -955,20 +956,29 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _ fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) callret <- unRetType lifted_rettype + traceM $ + unlines + [ "sv", + show sv, + "ret sv", + show $ autoMapSV (autoMap am) sv + ] + pure - ( mkApply fname' [(Nothing, mempty, f'), (argext, mempty, arg')] callret, - sv + ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, + autoMapSV (autoMap am) sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. -defuncApplyArg _ (f', DynamicFun _ sv) (((argext, _), arg), argtypes) = do +defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] - apply_e = mkApply f' [(argext, mempty, arg')] callret - pure (apply_e, sv) + apply_e = mkApply f' [(argext, am, arg')] callret + -- pure (apply_e, autoMapSV (autoRep am) sv) + pure (apply_e, autoMapSV (autoMap am) sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = error $ @@ -984,6 +994,11 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e +autoMapSV :: Shape Size -> StaticVal -> StaticVal +autoMapSV shape (Dynamic t) = + Dynamic $ arrayOfWithAliases (diet t) shape t +autoMapSV _ sv = sv + defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) @@ -999,10 +1014,39 @@ defuncApply f args appres loc = do _ -> do let fname = liftedName 0 f (argtypes, _) = unfoldFunType $ typeOf f - fmap (first $ updateReturn appres) $ - foldM (defuncApplyArg fname) (f', f_sv) $ - NE.zip args $ - NE.tails argtypes + (app, app_sv) <- + fmap (first $ updateReturn appres) $ + foldM (defuncApplyArg fname) (f', f_sv) $ + NE.zip args $ + NE.tails argtypes + + let (p_ts, _) = unfoldFunType $ typeOf f + arg_ts = typeOf . snd <$> args + -- am_dims = zipWith typeShapePrefix (NE.toList arg_ts) p_ts + -- ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) am_dims + ams = NE.toList $ autoMap . snd . fst <$> args + ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) ams + traceM $ + unlines + [ "## defuncApply", + "## f", + prettyString f, + "## args", + prettyString $ snd <$> args, + "## appres", + show appres, + "## app", + prettyString app, + "## app_sv", + show app_sv, + "## f type", + prettyString $ typeOf f, + "## arg types", + prettyString $ (typeOf . snd) <$> args, + "## ret_am", + prettyString ret_am + ] + pure (app, autoMapSV ret_am $ app_sv) where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index ec4adece0b..1b96819dc8 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -8,13 +8,15 @@ module Futhark.Internalise.Exps (transformProg) where import Control.Monad import Control.Monad.Reader import Data.Bifunctor +import Data.Either import Data.Foldable (toList) -import Data.List (elemIndex, find, intercalate, intersperse, transpose) +import Data.List (elemIndex, find, intercalate, intersperse, maximumBy, transpose, zip4) import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Set qualified as S import Data.Text qualified as T +import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings @@ -346,12 +348,15 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = let subst = map (,E.ExpSubst (E.sizeFromInteger 0 mempty)) ext et' = E.applySubst (`lookup` subst) et internaliseExp desc (E.Hole (Info et') loc) - (FunctionName qfname, args) -> do + (FunctionName qfname, argsam) -> do -- Argument evaluation is outermost-in so that any existential sizes -- created by function applications can be brought into scope. let fname = nameFromString $ prettyString $ baseName $ qualLeaf qfname loc = srclocOf e arg_desc = nameToString fname ++ "_arg" + args = map (\(a, b, _) -> (a, b)) argsam + ams = map (\(_, _, c) -> c) argsam + res_t = et -- Some functions are magical (overloaded) and we handle that here. case () of @@ -388,8 +393,16 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = let args'' = concatMap tag args' letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) - funcall desc qfname args' loc + traceM $ + unlines + [ "## qfname", + prettyString qfname + ] + -- args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) + -- funcall desc qfname args' loc + + withAutoMap_ ams arg_desc res_t args $ \args' -> + funcall desc qfname (concat args') loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = internalisePat desc sizes pat e $ internaliseExp desc body internaliseAppExp _ _ (E.LetFun ofname _ _ _) = @@ -890,6 +903,98 @@ internalisePatLit (E.PatLitFloat x) (E.Scalar (E.Prim (E.FloatType ft))) = internalisePatLit l t = error $ "Nonsensical pattern and type: " ++ show (l, t) +withAutoMap_ :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([[SubExp]] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] +withAutoMap_ ams arg_desc res_t args_e innerM = + withAutoMap ams arg_desc res_t args_e $ \args_stms -> do + let (args, stms) = unzip args_stms + mapM_ addStms $ reverse stms + innerM args + +withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] +withAutoMap ams arg_desc res_t args_e innerM = do + (args, stms) <- + foldM + ( \(args, stms) arg -> do + (arg', stms') <- inScopeOf (reverse stms) $ collectStms $ internaliseArg arg_desc arg + pure (arg' : args, stms' : stms) + ) + (mempty, mempty) + (reverse args_e) + argts <- inScopeOf (reverse stms) $ (mapM . mapM) subExpType args + expand args stms argts ams (maximum ds) + where + stripAutoMapDims i am = + am {autoMap = E.Shape $ drop i $ E.shapeDims $ autoMap am} + autoMapRank = E.shapeRank . autoMap + max_am = maximumBy (\x y -> E.shapeRank x `compare` E.shapeRank y) $ fmap autoMap ams + inner_t = E.stripArray (E.shapeRank max_am) res_t + ds = map autoMapRank ams + mkLambdaParams level (ses, ts, stm, d) + | d == level = + Left + <$> zipWithM + ( \se t -> do + let t' = I.stripArray 1 t + p <- newParam "x" t' + addStms stm + pure ((se, p), t') + ) + ses + ts + | otherwise = pure $ Right $ zip ses ts + + expand args stms argts ams' level + | level <= 0 = innerM $ zip args stms + | otherwise = do + let ds' = map autoMapRank ams' + arg_params <- mapM (mkLambdaParams level) $ zip4 args argts stms ds' + let argts' = map (either (map snd) (map snd)) arg_params + (ams'', stms') = + unzip $ + zipWith + ( \am stm -> + if autoMapRank am == level + then (stripAutoMapDims 1 am, mempty) + else (am, stm) + ) + ams' + stms + args' = map (either (map (I.Var . I.paramName . snd . fst)) (map fst)) arg_params + (map_ses, params) = unzip $ (concatMap . map) fst $ lefts arg_params + + ((ses, ses_ts), lam_stms) <- collectStms $ localScope (scopeOfLParams params) $ do + ses <- expand args' stms' argts' ams'' (level - 1) + ses_ts <- internaliseLambdaReturnType (E.toRes Nonunique inner_t) =<< mapM subExpType ses + pure (ses, ses_ts) + + case map_ses of + [] -> pure mempty + (map_se : _) -> do + outer_shape <- I.takeDims 1 . I.arrayShape <$> subExpType map_se + let I.Shape [outer_shape_se] = outer_shape + map_args <- forM map_ses $ \se -> do + se_t <- subExpType se + se_name <- letExp "map_arg" =<< toExp se + letExp "reshaped" $ + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter outer_shape 1 $ I.arrayShape se_t) + se_name + + letValExp' "automap" + . Op + . Screma outer_shape_se map_args + . mapSOAC + =<< mkLambda + params + ( ensureResultShape + (ErrorMsg [ErrorString "AutoMap: unexpected lambda result size"]) + mempty + ses_ts + =<< (addStms lam_stms >> pure (subExpsRes ses)) + ) + generateCond :: E.Pat StructType -> [I.SubExp] -> @@ -1477,14 +1582,14 @@ data Function | FunctionHole SrcLoc deriving (Show) -findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)]) +findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName, AutoMap)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) - | E.Hole (Info _) loc <- f = + | E.Hole (Info t) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info (argext, _), e) = (e, argext) + onArg (Info (argext, am), e) = (e, argext, am) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 419e50ba6b..6c69e1ef3d 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -33,6 +33,8 @@ module Language.Futhark.Prop stripExp, similarExps, frameOf, + shapePrefix, + typeShapePrefix, -- * Queries on patterns and params patIdents, @@ -1443,6 +1445,18 @@ frameOf (AppExp (Apply _ args _) _) = ((\(_, am) -> autoFrame am) . unInfo . fst) $ NE.last args frameOf _ = mempty +-- | @s1 `shapePrefix` s2@ assumes @s1 = prefix <> s2@ and +-- returns @prefix@. +shapePrefix :: Shape dim -> Shape dim -> Shape dim +shapePrefix (Shape ss1) (Shape ss2) = + Shape $ take (length ss1 - length ss2) ss1 + +typeShapePrefix :: TypeBase dim as1 -> TypeBase dim as2 -> Shape dim +typeShapePrefix (Array _ s _) Scalar {} = s +typeShapePrefix (Array _ s1 _) (Array _ s2 _) = + s1 `shapePrefix` s2 +typeShapePrefix _ _ = mempty + -- | An identifier with type- and aliasing information. type Ident = IdentBase Info VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 23d295f8ee..14215a7338 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -144,14 +144,14 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_glpk vns counter cs tyVars = do - traceM $ unlines ["## rankAnalysis prog", prettyString prog] + -- traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- if use_glpk then snd <$> (unsafePerformIO $ glpk prog) else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + -- traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) let initEnv = SubstEnv { envTyVars = tyVars, diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 52933dd1a3..d33a2ae17f 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1112,17 +1112,17 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + -- traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" vns <- gets termNameSource let use_glpk = True - traceM $ - unlines - [ "## cts:", - unlines $ map prettyString cts - ] + -- traceM $ + -- unlines + -- [ "## cts:", + -- unlines $ map prettyString cts + -- ] case rankAnalysis use_glpk vns counter cts tyvars of Nothing -> error "" @@ -1131,16 +1131,16 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' - traceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList) solution - ] + -- traceM $ + -- unlines + -- [ "## constraints:", + -- unlines $ map prettyString cts', + -- "## tyvars:", + -- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + -- "## solution:", + -- let p (v, t) = prettyNameString v <> " => " <> prettyString t + -- in either T.unpack (unlines . map p . M.toList) solution + -- ] pure (solution, params', retdecl', body') From 3e091b70b8588d8e389d3ac1b67ca3bdf35fb62e Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 18 Feb 2024 07:31:47 -0800 Subject: [PATCH 070/258] Oops. --- src/Language/Futhark/TypeChecker/Rank.hs | 4 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++++++++++----------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 14215a7338..23d295f8ee 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -144,14 +144,14 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_glpk vns counter cs tyVars = do - -- traceM $ unlines ["## rankAnalysis prog", prettyString prog] + traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- if use_glpk then snd <$> (unsafePerformIO $ glpk prog) else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - -- traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) let initEnv = SubstEnv { envTyVars = tyVars, diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d33a2ae17f..52933dd1a3 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1112,17 +1112,17 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - -- traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" vns <- gets termNameSource let use_glpk = True - -- traceM $ - -- unlines - -- [ "## cts:", - -- unlines $ map prettyString cts - -- ] + traceM $ + unlines + [ "## cts:", + unlines $ map prettyString cts + ] case rankAnalysis use_glpk vns counter cts tyvars of Nothing -> error "" @@ -1131,16 +1131,16 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' - -- traceM $ - -- unlines - -- [ "## constraints:", - -- unlines $ map prettyString cts', - -- "## tyvars:", - -- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - -- "## solution:", - -- let p (v, t) = prettyNameString v <> " => " <> prettyString t - -- in either T.unpack (unlines . map p . M.toList) solution - -- ] + traceM $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList) solution + ] pure (solution, params', retdecl', body') From 89939231df90c2b3e6208b64586432c7071636fe Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 18 Feb 2024 07:59:06 -0800 Subject: [PATCH 071/258] Basic map-only AUTOMAP seems to work now. --- src/Futhark/Internalise/Defunctionalise.hs | 28 ++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index d5894bec66..249616ef67 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -958,15 +958,23 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a traceM $ unlines - [ "sv", + [ "##defuncApplyArg LambdaSV", + "## fname", + fname_s, + "## f'", + prettyString f', + "## arg", + prettyString arg, + "## sv", show sv, - "ret sv", + "## ret sv", show $ autoMapSV (autoMap am) sv ] pure ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, autoMapSV (autoMap am) sv + -- sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially @@ -977,8 +985,20 @@ defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] apply_e = mkApply f' [(argext, am, arg')] callret - -- pure (apply_e, autoMapSV (autoRep am) sv) + traceM $ + unlines + [ "##defuncApplyArg DynamicFun", + "## f'", + prettyString f', + "## arg", + prettyString arg, + "## sv", + show sv, + "## ret sv", + show $ autoMapSV (autoMap am) sv + ] pure (apply_e, autoMapSV (autoMap am) sv) +-- pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = error $ @@ -1046,7 +1066,7 @@ defuncApply f args appres loc = do "## ret_am", prettyString ret_am ] - pure (app, autoMapSV ret_am $ app_sv) + pure (app, app_sv) where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. From 4c8a1248a282a53a229b5f6a10e8d8b0247bbebb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:28:57 +0100 Subject: [PATCH 072/258] Please shut up. --- src/Futhark/Solve/GLPK.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index 7b27408a27..fe7ac5d129 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -43,5 +43,7 @@ glpk :: F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) glpk lp = do - (_, mres) <- glpSolveVars mipDefaults $ linearProgToGLPK lp + (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres + where + opts = mipDefaults {msgLev = MsgOff} From 6407ff3ee9e8604364b0c939ed779be26b1695f1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:40:49 +0100 Subject: [PATCH 073/258] Handle automapped operand. --- src/Language/Futhark/TypeChecker/Terms.hs | 38 +++-------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 28e2e9f2e5..1f12e844d4 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -987,39 +987,11 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do } pure (tp1, tp2'', argext, ext, am) -checkApply loc fname tfun@(Scalar TypeVar {}) arg = do - tv <- newTypeVar loc "b" - unify (mkUsage loc "use as function") tfun $ - Scalar (Arrow mempty Unnamed Observe (typeOf arg) $ RetType [] $ paramToRes tv) - tfun' <- normType tfun - checkApply loc fname tfun' arg -checkApply loc (fname, prev_applied) ftype argexp = do - let fname' = maybe "expression" (dquotes . pretty) fname - - typeError loc mempty $ - if prev_applied == 0 - then - "Cannot apply" - <+> fname' - <+> "as function, as it has type:" - indent 2 (pretty ftype) - else - "Cannot apply" - <+> fname' - <+> "to argument #" - <> pretty (prev_applied + 1) - <+> dquotes (shorten $ group $ pretty argexp) - <> "," - "as" - <+> fname' - <+> "only takes" - <+> pretty prev_applied - <+> arguments - <> "." - where - arguments - | prev_applied == 1 = "argument" - | otherwise = "arguments" +checkApply loc fname (Array _ _ t) arg = + -- This implies the function is the result of an automap. + checkApply loc fname (Scalar t) arg +checkApply _ _ _ _ = + error "checkApply: impossible case" -- | Type-check a single expression in isolation. This expression may -- turn out to be polymorphic, in which case the list of type From 9f1093241a154bb2d6d1cf1d243bbce41b1ef503 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:45:10 +0100 Subject: [PATCH 074/258] Put these adjacent. --- src/Language/Futhark/TypeChecker/Terms.hs | 36 +++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 1f12e844d4..d8766a8b2b 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -466,6 +466,24 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc +checkExp (AppExp (Apply fe args loc) _) = do + fe' <- checkExp fe + args' <- mapM (checkExp . snd) args + t <- expType fe' + let fname = + case fe' of + Var v _ _ -> Just v + _ -> Nothing + ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' + + pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts + where + onArg fname (i, all_exts, t) arg' = do + (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' + pure + ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), + (Info (argext, am), arg') + ) checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do ftype <- lookupVar oploc op op_t e1' <- checkExp e1 @@ -532,24 +550,6 @@ checkExp (Negate arg loc) = do checkExp (Not arg loc) = do arg' <- checkExp arg pure $ Not arg' loc -checkExp (AppExp (Apply fe args loc) _) = do - fe' <- checkExp fe - args' <- mapM (checkExp . snd) args - t <- expType fe' - let fname = - case fe' of - Var v _ _ -> Just v - _ -> Nothing - ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' - - pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts - where - onArg fname (i, all_exts, t) arg' = do - (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' - pure - ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), - (Info (argext, am), arg') - ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e From 2ea1c6af433b2a53510aa9a4cab2100bf9720b44 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:48:23 +0100 Subject: [PATCH 075/258] Add frame to binop result. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d8766a8b2b..7c97574b15 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -503,7 +503,7 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do (e2', Info (p2_ext, am2)) loc ) - (Info (AppRes rt' retext)) + (Info (AppRes (arrayOf (autoFrame am2) rt') retext)) checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' From b644e3f6e7aba99a247d6b6c4004266596b7a608 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 18 Feb 2024 11:50:51 -0800 Subject: [PATCH 076/258] AUTOMAP `OpSection` support. --- src/Futhark/Internalise/Exps.hs | 27 ++++-- src/Futhark/Internalise/FullNormalise.hs | 8 +- src/Futhark/Internalise/Monomorphise.hs | 22 ++--- src/Language/Futhark/Interpreter.hs | 4 +- src/Language/Futhark/Syntax.hs | 4 +- src/Language/Futhark/Traversals.hs | 8 +- src/Language/Futhark/TypeChecker/Terms.hs | 14 ++-- src/Language/Futhark/TypeChecker/Terms2.hs | 96 +++++++++++++++++----- 8 files changed, 124 insertions(+), 59 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 1b96819dc8..2e3cabd4f5 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -381,17 +381,28 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - let prepareArg (arg, _) = - (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg - internalise =<< mapM prepareArg args + -- let prepareArg (arg, _) = + -- (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg + -- internalise =<< mapM prepareArg args + -- + withAutoMap_ ams arg_desc res_t args $ \args' -> do + let prepareArg (arg, _, am) arg' = + (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') + internalise $ zipWith prepareArg argsam args' | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do - let tag ses = [(se, I.Observe) | se <- ses] - args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) - let args'' = concatMap tag args' - letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) + -- Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do + -- let tag ses = [(se, I.Observe) | se <- ses] + -- args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) + -- let args'' = concatMap tag args' + -- letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) + -- + Just (rettype, _) <- M.lookup fname I.builtInFunctions -> + withAutoMap_ ams arg_desc res_t args $ \args' -> do + let tag ses = [(se, I.Observe) | se <- ses] + let args'' = concatMap tag args' + letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do traceM $ unlines diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 91e16a9a53..6ee354ea4f 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -203,13 +203,13 @@ getOrdering final (Lambda params body mte ret loc) = do nameExp final $ Lambda params body' mte ret loc getOrdering _ (OpSection qn ty loc) = pure $ Var qn ty loc -getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do +getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, xam), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do x <- getOrdering False e yn <- newNameFromString "y" let y = Var (qualName yn) (Info $ toStruct yty) mempty ret' = applySubst (pSubst x y) ret body = - mkApply (Var op ty mempty) [(xext, mempty, x), (Nothing, mempty, y)] $ + mkApply (Var op ty mempty) [(xext, xam, x), (Nothing, mempty, y)] $ AppRes (toStruct ret') exts nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc where @@ -217,12 +217,12 @@ getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (I | Named p <- xp, p == vn = Just $ ExpSubst x | Named p <- yp, p == vn = Just $ ExpSubst y | otherwise = Nothing -getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext)) (Info (RetType dims ret)) loc) = do +getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext, yam)) (Info (RetType dims ret)) loc) = do xn <- newNameFromString "x" y <- getOrdering False e let x = Var (qualName xn) (Info $ toStruct xty) mempty ret' = applySubst (pSubst x y) ret - body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, mempty, y)] $ AppRes (toStruct ret') [] + body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, yam, y)] $ AppRes (toStruct ret') [] nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc where pSubst x y vn diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index ee352f67c6..aada3924c0 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -664,27 +664,27 @@ transformExp (Lambda params e0 decl tp loc) = do transformExp (OpSection qn t loc) = transformExp $ Var qn t loc transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc) = do - let (Info (xp, xtype, xargext), Info (yp, ytype)) = arg + let (Info (xp, xtype, xargext, xam), Info (yp, ytype)) = arg e' <- transformExp e desugarBinOpSection fname (Just e') Nothing t - (xp, xtype, xargext) - (yp, ytype, Nothing) + (xp, xtype, xargext, xam) + (yp, ytype, Nothing, mempty) (rettype, retext) loc transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do - let (Info (xp, xtype), Info (yp, ytype, yargext)) = arg + let (Info (xp, xtype), Info (yp, ytype, yargext, yam)) = arg e' <- transformExp e desugarBinOpSection fname Nothing (Just e') t - (xp, xtype, Nothing) - (yp, ytype, yargext) + (xp, xtype, Nothing, mempty) + (yp, ytype, yargext, yam) (rettype, []) loc transformExp (ProjectSection fields (Info t) loc) = do @@ -735,12 +735,12 @@ desugarBinOpSection :: Maybe Exp -> Maybe Exp -> StructType -> - (PName, ParamType, Maybe VName) -> - (PName, ParamType, Maybe VName) -> + (PName, ParamType, Maybe VName, AutoMap) -> + (PName, ParamType, Maybe VName, AutoMap) -> (ResRetType, [VName]) -> SrcLoc -> MonoM Exp -desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (RetType dims rettype, retext) loc = do +desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, yext, yam) (RetType dims rettype, retext) loc = do t' <- transformType t op <- transformFName loc fname $ toStruct t (v1, wrap_left, e1, p1) <- makeVarParam e_left =<< transformType xtype @@ -748,7 +748,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( let apply_left = mkApply op - [(xext, mempty, e1)] + [(xext, xam, e1)] (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) []) onDim (Var d typ _) | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc @@ -757,7 +757,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( rettype' = first onDim rettype body <- scoping (S.fromList [v1, v2]) $ - mkApply apply_left [(yext, mempty, e2)] + mkApply apply_left [(yext, yam, e2)] <$> transformAppRes (AppRes (toStruct rettype') retext) rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype' pure . wrap_left . wrap_right $ diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 01e4bddb21..527c734cbf 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -1047,11 +1047,11 @@ eval env (Lambda ps body _ (Info (RetType _ rt)) _) = evalFunction env [] ps body rt eval env (OpSection qv (Info t) _) = evalTermVar env qv $ toStruct t -eval env (OpSectionLeft qv _ e (Info (_, _, argext), _) (Info (RetType _ t), _) loc) = do +eval env (OpSectionLeft qv _ e (Info (_, _, argext, _), _) (Info (RetType _ t), _) loc) = do v <- evalArg env e argext f <- evalTermVar env qv (toStruct t) apply loc env f v -eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext)) (Info (RetType _ t)) loc) = do +eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext, _)) (Info (RetType _ t)) loc) = do y <- evalArg env e argext pure $ ValueFun $ \x -> do diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index bd2133f017..ef7afa4d30 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -848,7 +848,7 @@ data ExpBase f vn (QualName vn) (f StructType) (ExpBase f vn) - (f (PName, ParamType, Maybe VName), f (PName, ParamType)) + (f (PName, ParamType, Maybe VName, AutoMap), f (PName, ParamType)) (f ResRetType, f [VName]) SrcLoc | -- | @+2@; first type is operand, second is result. @@ -856,7 +856,7 @@ data ExpBase f vn (QualName vn) (f StructType) (ExpBase f vn) - (f (PName, ParamType), f (PName, ParamType, Maybe VName)) + (f (PName, ParamType), f (PName, ParamType, Maybe VName, AutoMap)) (f ResRetType) SrcLoc | -- | Field projection as a section: @(.x.y.z)@. diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index 798edae981..fc20935c24 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -184,25 +184,25 @@ instance ASTMappable (ExpBase Info VName) where <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> pure loc - astMap tv (OpSectionLeft name t arg (Info (pa, t1a, argext), Info (pb, t1b)) (ret, retext) loc) = + astMap tv (OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc) = OpSectionLeft <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> mapOnExp tv arg <*> ( (,) - <$> (Info <$> ((pa,,) <$> mapOnParamType tv t1a <*> pure argext)) + <$> (Info <$> ((pa,,,) <$> mapOnParamType tv t1a <*> pure argext <*> pure am)) <*> (Info <$> ((pb,) <$> mapOnParamType tv t1b)) ) <*> ((,) <$> traverse (mapOnResRetType tv) ret <*> pure retext) <*> pure loc - astMap tv (OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext)) t2 loc) = + astMap tv (OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc) = OpSectionRight <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> mapOnExp tv arg <*> ( (,) <$> (Info <$> ((pa,) <$> mapOnParamType tv t1a)) - <*> (Info <$> ((pb,,) <$> mapOnParamType tv t1b <*> pure argext)) + <*> (Info <$> ((pb,,,) <$> mapOnParamType tv t1b <*> pure argext <*> pure am)) ) <*> traverse (mapOnResRetType tv) t2 <*> pure loc diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 7c97574b15..d3b59381a8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -725,16 +725,16 @@ checkExp (OpSection op (Info op_t) loc) = do checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext, _) <- checkApply loc (Just op, 0) ftype e' + (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' case (ftype, rt) of - (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 rettype)) -> + (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 (RetType ds rt2))) -> pure $ OpSectionLeft op (Info ftype) e' - (Info (m1, toParam d1 t1, argext), Info (m2, toParam d2 t2)) - (Info rettype, Info retext) + (Info (m1, toParam d1 t1, argext, am), Info (m2, toParam d2 t2)) + (Info $ RetType ds $ arrayOfWithAliases (uniqueness rt2) (autoFrame am) rt2, Info retext) loc _ -> typeError loc mempty $ @@ -744,7 +744,7 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do - (t2', arrow', argext, _, _) <- + (t2', arrow', argext, _, am) <- checkApply loc (Just op, 1) @@ -757,8 +757,8 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do op (Info ftype) e' - (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext)) - (Info $ RetType dims2' ret') + (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext, am)) + (Info $ RetType dims2' $ arrayOfWithAliases (uniqueness ret') (autoFrame am) ret') loc _ -> error $ "OpSectionRight: impossible type\n" <> prettyString arrow' _ -> diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 52933dd1a3..335c532cea 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -599,8 +599,59 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) -checkApply loc _ ftype fframe arg = do +checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> [(Shape Size, Type)] -> TermM (StructType, [AutoMap]) +checkApply loc fname (fframe, ftype) args = do + ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args + rt' <- asStructType loc rt + pure (rt', argts) + where + -- pure (asStructType loc rt, argts) + + onArg (i, f_f, f_t) (argframe, argtype) = do + (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) + pure + ( (i + 1, autoFrame am, rt), + am + ) + +checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) +checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do + (a, b) <- split $ stripFrame fframe ftype + r <- newSVar loc "R" + m <- newSVar loc "M" + let unit_info = Info $ Scalar $ Prim Bool + r_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] m) unit_info mempty + lhs = arrayOf (toShape (SVar r) <> (toSComp <$> argframe)) argtype + rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a + ctAM r m + ctEq lhs rhs + pure + ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, + AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} + ) + where + stripFrame :: Shape Size -> Type -> Type + stripFrame frame (Array u ds t) = + let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) + in case mnew_shape of + Nothing -> Scalar t + Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t + stripFrame _ t = t + toSComp (Var (QualName [] x) _ _) = SVar x + toSComp _ = error "" + toShape = Shape . pure + split (Scalar (Arrow _ _ _ a (RetType _ b))) = + pure (a, b `setUniqueness` NoUniqueness) + split ftype' = do + a <- newType loc "arg" + b <- newTyVar loc "res" + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b + pure (a, tyVarType b) + +-- To be removed (probably) +checkApply_ :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) +checkApply_ loc _ ftype fframe arg = do (a, b) <- split $ stripFrame fframe ftype r <- newSVar loc "R" m <- newSVar loc "M" @@ -828,7 +879,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do onArg (i, f_t, f_f) (_, arg) = do arg' <- checkExp arg - (rt, am) <- checkApply loc (fname, i) f_t f_f arg' + (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' pure ( (i + 1, rt, autoFrame am), (Info (Nothing, am), arg') @@ -838,52 +889,55 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op e1' <- checkExp e1 e2' <- checkExp e2 - - (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) mempty e1' - (rt2, am2) <- checkApply loc (Just op, 1) rt1 (autoFrame am1) e2' - rt2' <- asStructType loc rt2 + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, toType ftype) + [(frameOf e1', toType $ typeOf e1'), (frameOf e2', toType $ typeOf e2')] + let [am1, am2] = ams pure $ AppExp (BinOp (op, oploc) (Info ftype) (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) - (Info (AppRes rt2' [])) + (Info (AppRes rt [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - void $ checkApply loc (Just op, 0) (toType optype) mempty e' + t2 <- newType loc "t" + t2' <- asStructType loc t2 let t1 = typeOf e' - t2 <- newType loc "t2" - rt <- newType loc "rt" - ctEq (toType optype) $ toType $ foldFunType [toParam Observe t1, t2] $ RetType [] $ rt `setUniqueness` Nonunique + f1 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(f1, toType t1), (mempty, t2)] pure $ OpSectionLeft op (Info optype) e' - ( Info (Unnamed, toParam Observe t1, Nothing), - Info (Unnamed, t2) + ( Info (Unnamed, toParam Observe t1, Nothing, ams !! 0), -- fix + Info (Unnamed, toParam Observe t2') ) - (Info (RetType [] rt), Info []) + (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) loc --- checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e t1 <- newType loc "t" + t1' <- asStructType loc t1 let t2 = typeOf e' - rt <- newType loc "rt" - ctEq (toType optype) $ toType $ foldFunType [t1, toParam Observe t2] $ RetType [] $ rt `setUniqueness` Nonunique + f2 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(mempty, t1), (f2, toType t2)] pure $ OpSectionRight op (Info optype) e' -- Dummy types. - ( Info (Unnamed, toParam Observe t1), - Info (Unnamed, toParam Observe t2, Nothing) + ( Info (Unnamed, toParam Observe t1'), + Info (Unnamed, toParam Observe t2, Nothing, ams !! 1) -- fix ) - (Info $ RetType [] rt) + (Info $ RetType [] (rt `setUniqueness` Nonunique)) loc -- checkExp (ProjectSection fields NoInfo loc) = do From 14c5544d5816b208453483c98d1d8ed6f91c33fe Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 19 Feb 2024 16:13:06 +0100 Subject: [PATCH 077/258] Flail at the constraint solver. Now explicitly returns variables to be generalised. --- .../Futhark/TypeChecker/Constraints.hs | 23 +++++++++++++------ src/Language/Futhark/TypeChecker/Terms.hs | 15 ++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 20 ++++++++++------ 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 2548c74008..ddaedd4025 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -117,19 +117,28 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt --- | A solution maps a type variable to its substitution. This substitution is complete, in the sense there are no right-hand sides that contain a type variable. +-- | A solution maps a type variable to its substitution. This +-- substitution is complete, in the sense there are no right-hand +-- sides that contain a type variable. type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -solution :: SolverState -> Solution +solution :: SolverState -> ([VName], Solution) solution s = - M.mapMaybe mkSubst $ - solverTyVars s + ( mapMaybe unconstrained $ M.toList $ solverTyVars s, + M.mapMaybe mkSubst $ solverTyVars s + ) where - mkSubst (TyVarSol _lvl t) = Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t - mkSubst (TyVarLink v') = mkSubst =<< M.lookup v' (solverTyVars s) + mkSubst (TyVarSol _lvl t) = + Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t + mkSubst (TyVarLink v') = + Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ + mkSubst =<< M.lookup v' (solverTyVars s) mkSubst (TyVarUnsol _ (TyVarPrim pts)) = Just $ Left pts mkSubst _ = Nothing + unconstrained (v, TyVarUnsol _ TyVarFree) = Just v + unconstrained _ = Nothing + newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) @@ -213,7 +222,7 @@ solveCt ct = Nothing -> bad Just eqs -> mapM_ solveCt' eqs -solve :: Constraints -> TyVars -> Either T.Text Solution +solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d3b59381a8..eacb5b6964 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1015,7 +1015,7 @@ checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp tysubsts $ do + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' when (hasBinding e'') $ @@ -1216,14 +1216,6 @@ localChecks = void . check e <$ case ty of Info (Scalar (Prim t)) -> errorBounds (inBoundsI (-x) t) (-x) t (loc1 <> loc2) _ -> error "Inferred type of int literal is not a number" - check e@(AppExp (BinOp (QualName [] v, _) _ (x, _) _ loc) _) - | baseName v == "==", - Array {} <- typeOf x, - baseTag v <= maxIntrinsicTag = do - warn loc $ - textwrap - "Comparing arrays with \"==\" is deprecated and will stop working in a future revision of the language." - recurse e check e = recurse e recurse = astMap identityMapper {mapOnExp = check} @@ -1597,9 +1589,10 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp tysubsts $ do + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained (tparams', params'', retdecl'', RetType dims rettype', body'') <- - checkBinding (fname, retdecl', tparams, params', body', loc) + checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) -- Since this is a top-level function, we also resolve overloaded -- types, using either defaults or complaining about ambiguities. diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 335c532cea..594e9af0ac 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -46,6 +46,7 @@ import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor +import Data.Bitraversable import Data.Char (isAscii) import Data.List qualified as L import Data.List.NonEmpty qualified as NE @@ -396,7 +397,7 @@ lookupVar loc qn@(QualName qs name) = do outer_env <- asks termOuterEnv asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do - argtype <- newType loc "t" + argtype <- newTypeOverloaded loc "t" anyPrimType pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts @@ -915,7 +916,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do op (Info optype) e' - ( Info (Unnamed, toParam Observe t1, Nothing, ams !! 0), -- fix + ( Info (Unnamed, toParam Observe t1, Nothing, head ams), -- fix Info (Unnamed, toParam Observe t2') ) (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) @@ -1149,7 +1150,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), + ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1183,7 +1184,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} - solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' + solution <- + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ + solve cts' tyvars' traceM $ unlines @@ -1193,17 +1196,20 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList) solution + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] pure (solution, params', retdecl', body') checkSingleExp :: ExpBase NoInfo VName -> - TypeM (Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts tyvars + solution <- + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ + solve cts tyvars pure (solution, e') From 52beb8c7c22ac3fd69d33fa906d349db631de5c7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 19 Feb 2024 16:23:53 +0100 Subject: [PATCH 078/258] Print this too. --- src/Language/Futhark/TypeChecker.hs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 879a27afc4..9d07bf5b53 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -26,6 +26,7 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Ord import Data.Set qualified as S +import Debug.Trace import Futhark.FreshNames hiding (newName) import Futhark.Util.Pretty hiding (space) import Language.Futhark @@ -704,6 +705,9 @@ checkValBind vb = do _ -> pure () let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc + + traceM $ unlines ["Inferred:", prettyString vb'] + pure ( mempty { envVtable = From 8f4cb684e9f3ac8d53135cbf8bd4b316a8c1486a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 11:38:07 -0800 Subject: [PATCH 079/258] Add ambiguity checking. --- src/Language/Futhark/TypeChecker/Rank.hs | 89 +++++++++++++++--------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 23d295f8ee..bff7a08b84 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -92,7 +92,7 @@ binVar sv = do modify $ \s -> s { rankBinVars = M.insert sv bv $ rankBinVars s, - rankConstraints = rankConstraints s ++ [bin bv] + rankConstraints = rankConstraints s ++ [bin bv, var bv ~<=~ var sv] } pure bv Just bv -> pure bv @@ -141,36 +141,66 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP +ambigCheckLinearProg :: LinearProg -> (Double, Map VName Int) -> LinearProg +ambigCheckLinearProg prog (opt, ranks) = + prog + { constraints = + constraints prog + ++ [ lsum (var <$> M.keys one_bins) + ~-~ lsum (var <$> M.keys zero_bins) + ~<=~ constant (fromIntegral $ length one_bins) + ~-~ constant 1, + objective prog ~==~ constant opt + ] + } + where + -- We really need to track which variables are binary in the LinearProg + is_bin_var = ("b_" `L.isPrefixOf`) . baseString + one_bins = M.filterWithKey (\k v -> is_bin_var k && v == 1) ranks + zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks + lsum = foldr (~+~) (constant 0) + rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_glpk vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] - rank_map <- - if use_glpk - then snd <$> (unsafePerformIO $ glpk prog) - else do - (_size, ranks) <- branchAndBound lp - pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) - let initEnv = - SubstEnv - { envTyVars = tyVars, - envRanks = rank_map - } - - initState = - SubstState - { substTyVars = mempty, - substNewVars = mempty, - substNameSource = vns, - substCounter = counter, - substNewCts = mempty - } - (cs', state') = - runSubstM initEnv initState $ - substRanks $ - filter (not . isCtAM) cs - pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') + -- rank_map <- + -- if use_glpk + -- then snd <$> (unsafePerformIO $ glpk prog) + -- else do + -- (_size, ranks) <- branchAndBound lp + -- pure $ (fromJust . (ranks V.!?)) <$> inv_var_map + (size, rank_map) <- unsafePerformIO $ glpk prog + case unsafePerformIO $ glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) of + Just (size', rank_map') -> do + traceM $ + unlines $ + "## rank map" + : map prettyString (M.toList rank_map) + ++ "## ambig rank map" + : map prettyString (M.toList rank_map') + error "ambiguous" + Nothing -> do + traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + let initEnv = + SubstEnv + { envTyVars = tyVars, + envRanks = rank_map + } + + initState = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNameSource = vns, + substCounter = counter, + substNewCts = mempty + } + (cs', state') = + runSubstM initEnv initState $ + substRanks $ + filter (not . isCtAM) cs + pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -189,11 +219,6 @@ rankAnalysis use_glpk vns counter cs tyVars = do (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] - rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" - vname_to_pulp_var = M.mapWithKey (\k _ -> map rm_subscript $ show $ prettyName k) inv_var_map - pulp_var_to_vname = - M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList vname_to_pulp_var] - newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) From 6d91328e18839cfd68d5010071613ce86b5f6c48 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 11:41:23 -0800 Subject: [PATCH 080/258] Forgot the source. --- src/Language/Futhark/TypeChecker/Rank.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index bff7a08b84..d67032446e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -146,6 +146,7 @@ ambigCheckLinearProg prog (opt, ranks) = prog { constraints = constraints prog + -- https://yetanothermathprogrammingconsultant.blogspot.com/2011/10/integer-cuts.html ++ [ lsum (var <$> M.keys one_bins) ~-~ lsum (var <$> M.keys zero_bins) ~<=~ constant (fromIntegral $ length one_bins) From 2c19d786bc18216657b5aec61ff70b9d74a84966 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 19 Feb 2024 21:09:33 +0100 Subject: [PATCH 081/258] Add rep shapes here. --- src/Language/Futhark/TypeChecker/Terms.hs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index eacb5b6964..6b6bdcb16e 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -492,7 +492,7 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) rt e2' + (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) (arrayOf (autoFrame am1) rt) e2' pure $ AppExp @@ -987,9 +987,15 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do } pure (tp1, tp2'', argext, ext, am) -checkApply loc fname (Array _ _ t) arg = +checkApply loc fname (Array _ shape t) arg = do -- This implies the function is the result of an automap. - checkApply loc fname (Scalar t) arg + (t1, rt, argext, retext, am) <- checkApply loc fname (Scalar t) arg + let am' = + am + { autoRep = shape <> autoRep am, + autoFrame = shape <> autoFrame am + } + pure (t1, rt, argext, retext, am') checkApply _ _ _ _ = error "checkApply: impossible case" From 8dd13d4632b02dc3157ed9af9290ed3a04304d1d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 19:45:42 -0800 Subject: [PATCH 082/258] Make the design of `Rank.hs` less dumb. --- src/Language/Futhark/TypeChecker/Rank.hs | 174 ++++++++++----------- src/Language/Futhark/TypeChecker/Terms2.hs | 47 +++--- 2 files changed, 105 insertions(+), 116 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index d67032446e..eca628fdc1 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -6,20 +6,14 @@ import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M import Data.Maybe -import Data.Vector.Unboxed qualified as V import Debug.Trace -import Futhark.FreshNames qualified as FreshNames -import Futhark.MonadFreshNames hiding (newName) -import Futhark.Solve.BranchAndBound import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP -import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints -import Language.Futhark.TypeChecker.Monad (mkTypeVarName) +import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe -import System.Process type LSum = LP.LSum VName Double @@ -61,7 +55,7 @@ instance Distribute Type where distribute = distributeOne where distributeOne (Array _ s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s $ tr) + Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s tr) distributeOne t = t instance Distribute Ct where @@ -83,9 +77,9 @@ incCounter = do put s {rankCounter = rankCounter s + 1} pure $ rankCounter s -binVar :: VName -> RankM (VName) +binVar :: VName -> RankM VName binVar sv = do - mbv <- (M.!? sv) <$> gets rankBinVars + mbv <- gets ((M.!? sv) . rankBinVars) case mbv of Nothing -> do bv <- VName ("b_" <> baseName sv) <$> incCounter @@ -112,7 +106,7 @@ addCt (CtAM r m) = do addConstraints $ oneIsZero (b_r, r) (b_m, m) addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () -addTyVarInfo tv (_, TyVarFree) = pure () +addTyVarInfo _ (_, TyVarFree) = pure () addTyVarInfo tv (_, TyVarPrim _) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarRecord _) = @@ -120,8 +114,8 @@ addTyVarInfo tv (_, TyVarRecord _) = addTyVarInfo tv (_, TyVarSum _) = addConstraint $ rank tv ~==~ constant 0 -mkLinearProg :: Int -> [Ct] -> TyVars -> LinearProg -mkLinearProg counter cs tyVars = +mkLinearProg :: [Ct] -> TyVars -> LinearProg +mkLinearProg cs tyVars = LP.LinearProg { optType = Minimize, objective = @@ -133,7 +127,7 @@ mkLinearProg counter cs tyVars = initState = RankState { rankBinVars = mempty, - rankCounter = counter, + rankCounter = 0, rankConstraints = mempty } buildLP = do @@ -161,50 +155,41 @@ ambigCheckLinearProg prog (opt, ranks) = zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks lsum = foldr (~+~) (constant 0) -rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) -rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) -rankAnalysis use_glpk vns counter cs tyVars = do - traceM $ unlines ["## rankAnalysis prog", prettyString prog] - -- rank_map <- - -- if use_glpk - -- then snd <$> (unsafePerformIO $ glpk prog) - -- else do - -- (_size, ranks) <- branchAndBound lp - -- pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - (size, rank_map) <- unsafePerformIO $ glpk prog - case unsafePerformIO $ glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) of - Just (size', rank_map') -> do - traceM $ - unlines $ - "## rank map" - : map prettyString (M.toList rank_map) - ++ "## ambig rank map" - : map prettyString (M.toList rank_map') - error "ambiguous" - Nothing -> do - traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) - let initEnv = - SubstEnv - { envTyVars = tyVars, - envRanks = rank_map - } - - initState = - SubstState - { substTyVars = mempty, - substNewVars = mempty, - substNameSource = vns, - substCounter = counter, - substNewCts = mempty - } - (cs', state') = - runSubstM initEnv initState $ - substRanks $ - filter (not . isCtAM) cs - pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') +checkProg :: (MonadTypeChecker m, Located loc) => loc -> LinearProg -> m (Map VName Int) +checkProg loc prog = do + traceM $ + unlines + [ "## checkProg", + prettyString prog + ] + case run_glpk prog of + Nothing -> typeError loc mempty "Rank ILP cannot be solved." + Just sol@(_size, rank_map) -> + case check_ambig sol of + Nothing -> do + traceM $ + unlines $ + "## rank map" : map prettyString (M.toList rank_map) + pure rank_map + Just (_, rank_map') -> do + traceM $ + unlines $ + "## rank map" + : map prettyString (M.toList rank_map) + ++ "## ambig rank map" + : map prettyString (M.toList rank_map') + typeError loc mempty "Rank ILP is ambiguous." + where + run_glpk = unsafePerformIO . glpk + check_ambig (size, rank_map) = + run_glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) + +rankAnalysis :: (MonadTypeChecker m, Located loc) => loc -> [Ct] -> TyVars -> m ([Ct], TyVars) +rankAnalysis _ [] tyVars = pure ([], tyVars) +rankAnalysis loc cs tyVars = do + checkProg loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + >>= substRankInfo cs tyVars where - isCtAM (CtAM {}) = True - isCtAM _ = False splitFuncs ( CtEq (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) @@ -215,17 +200,43 @@ rankAnalysis use_glpk vns counter cs tyVars = do t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] - cs' = foldMap splitFuncs cs - prog = mkLinearProg counter cs' tyVars - (lp, var_map) = linearProgToLP prog - inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] -newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) - deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) +substRankInfo :: (MonadTypeChecker m) => [Ct] -> TyVars -> Map VName Int -> m ([Ct], TyVars) +substRankInfo cs tyVars rankmap = do + (cs', new_cs, new_tyVars) <- + runSubstT tyVars rankmap $ + substRanks $ + filter (not . isCtAM) cs + pure (cs' <> new_cs, new_tyVars <> tyVars) + where + isCtAM (CtAM {}) = True + isCtAM _ = False + +runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) +runSubstT tyVars rankmap (SubstT m) = do + let env = + SubstEnv + { envTyVars = tyVars, + envRanks = rankmap + } -runSubstM :: SubstEnv -> SubstState -> SubstM a -> (a, SubstState) -runSubstM initEnv initState (SubstM m) = - runReader (runStateT m initState) initEnv + s = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNewCts = mempty + } + (a, s') <- runReaderT (runStateT m s) env + pure (a, substNewCts s', substTyVars s') + +newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) + deriving + ( Functor, + Applicative, + Monad, + MonadState SubstState, + MonadReader SubstEnv + ) data SubstEnv = SubstEnv { envTyVars :: TyVars, @@ -235,21 +246,15 @@ data SubstEnv = SubstEnv data SubstState = SubstState { substTyVars :: TyVars, substNewVars :: Map TyVar TyVar, - substNameSource :: VNameSource, - substCounter :: !Int, substNewCts :: [Ct] } -substIncCounter :: SubstM Int -substIncCounter = do - s <- get - put s {substCounter = substCounter s + 1} - pure $ substCounter s +instance MonadTrans SubstT where + lift = SubstT . lift . lift -newTyVar :: TyVar -> SubstM TyVar +newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar newTyVar t = do - i <- substIncCounter - t' <- newID $ mkTypeVarName (baseName t) i + t' <- lift $ newTypeName (baseName t) shape <- rankToShape t modify $ \s -> s @@ -262,22 +267,15 @@ newTyVar t = do ] } pure t' - where - newID x = do - s <- get - let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 - put $ s {substNameSource = src'} - pure v' -rankToShape :: VName -> SubstM (Shape SComp) +rankToShape :: (Monad m) => VName -> SubstT m (Shape SComp) rankToShape x = do rs <- asks envRanks pure $ Shape $ replicate (fromJust $ rs M.!? x) SDim -addRankInfo :: TyVar -> SubstM () +addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do rs <- asks envRanks - -- unless (fromMaybe (error $ prettyString t) (rs M.!? t) == 0) $ do unless (fromMaybe 0 (rs M.!? t) == 0) $ do new_vars <- gets substNewVars maybe new_var (const $ pure ()) $ new_vars M.!? t @@ -290,7 +288,7 @@ addRankInfo t = do modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} class SubstRanks a where - substRanks :: a -> SubstM a + substRanks :: (MonadTypeChecker m) => a -> SubstT m a instance (SubstRanks a) => SubstRanks [a] where substRanks = mapM substRanks @@ -298,11 +296,11 @@ instance (SubstRanks a) => SubstRanks [a] where instance SubstRanks (Shape SComp) where substRanks = foldM (\s d -> (s <>) <$> instDim d) mempty where - instDim (SDim) = pure $ Shape $ pure SDim + instDim SDim = pure $ Shape $ pure SDim instDim (SVar x) = rankToShape x instance SubstRanks (TypeBase SComp u) where - substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = + substRanks t@(Scalar (TypeVar _ (QualName [] x) [])) = addRankInfo x >> pure t substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do ta' <- substRanks ta diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 594e9af0ac..e5a25df58c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1163,44 +1163,35 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do cts <- gets termConstraints - counter <- gets termCounter - tyvars <- gets termTyVars traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - vns <- gets termNameSource - - let use_glpk = True - traceM $ unlines [ "## cts:", unlines $ map prettyString cts ] - case rankAnalysis use_glpk vns counter cts tyvars of - Nothing -> error "" - Just (cts', tyvars', vns', counter') -> do - modify $ \s -> s {termCounter = counter', termNameSource = vns'} - - solution <- - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ - solve cts' tyvars' - - traceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - - pure (solution, params', retdecl', body') + (cts', tyvars') <- rankAnalysis loc cts tyvars + + solution <- + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ + solve cts' tyvars' + + traceM $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution + ] + + pure (solution, params', retdecl', body') checkSingleExp :: ExpBase NoInfo VName -> From af0b0bb45e9be1c75d4bb5f20295a688fdda0d35 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 20:50:35 -0800 Subject: [PATCH 083/258] Pass a list of possible solutions around. --- src/Language/Futhark/TypeChecker/Rank.hs | 59 +++++++-------- src/Language/Futhark/TypeChecker/Terms.hs | 88 ++++++++++++---------- src/Language/Futhark/TypeChecker/Terms2.hs | 42 ++++++----- 3 files changed, 99 insertions(+), 90 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index eca628fdc1..729f58848d 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -135,7 +135,7 @@ mkLinearProg cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -ambigCheckLinearProg :: LinearProg -> (Double, Map VName Int) -> LinearProg +ambigCheckLinearProg :: LinearProg -> (Int, Map VName Int) -> LinearProg ambigCheckLinearProg prog (opt, ranks) = prog { constraints = @@ -145,7 +145,7 @@ ambigCheckLinearProg prog (opt, ranks) = ~-~ lsum (var <$> M.keys zero_bins) ~<=~ constant (fromIntegral $ length one_bins) ~-~ constant 1, - objective prog ~==~ constant opt + objective prog ~==~ constant (fromIntegral opt) ] } where @@ -155,40 +155,39 @@ ambigCheckLinearProg prog (opt, ranks) = zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks lsum = foldr (~+~) (constant 0) -checkProg :: (MonadTypeChecker m, Located loc) => loc -> LinearProg -> m (Map VName Int) -checkProg loc prog = do +-- We should probably cap the iteration on this +enumerateRankSols :: LinearProg -> [Map VName Int] +enumerateRankSols prog = + takeSolns $ + iterate next_sol $ + (prog,) <$> run_glpk prog + where + run_glpk = unsafePerformIO . glpk + next_sol m = do + (prog', sol') <- m + let prog'' = ambigCheckLinearProg prog' sol' + sol'' <- run_glpk prog'' + pure (prog'', sol'') + takeSolns [] = [] + takeSolns (Nothing : _) = [] + takeSolns (Just (_, (_, r)) : xs) = r : takeSolns xs + +solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] +solveRankILP loc prog = do traceM $ unlines - [ "## checkProg", + [ "## solveRankILP", prettyString prog ] - case run_glpk prog of - Nothing -> typeError loc mempty "Rank ILP cannot be solved." - Just sol@(_size, rank_map) -> - case check_ambig sol of - Nothing -> do - traceM $ - unlines $ - "## rank map" : map prettyString (M.toList rank_map) - pure rank_map - Just (_, rank_map') -> do - traceM $ - unlines $ - "## rank map" - : map prettyString (M.toList rank_map) - ++ "## ambig rank map" - : map prettyString (M.toList rank_map') - typeError loc mempty "Rank ILP is ambiguous." - where - run_glpk = unsafePerformIO . glpk - check_ambig (size, rank_map) = - run_glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) + case enumerateRankSols prog of + [] -> typeError loc mempty "Rank ILP cannot be solved." + rs -> pure rs -rankAnalysis :: (MonadTypeChecker m, Located loc) => loc -> [Ct] -> TyVars -> m ([Ct], TyVars) -rankAnalysis _ [] tyVars = pure ([], tyVars) +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> m [([Ct], TyVars)] +rankAnalysis _ [] tyVars = pure [([], tyVars)] rankAnalysis loc cs tyVars = do - checkProg loc (mkLinearProg (foldMap splitFuncs cs) tyVars) - >>= substRankInfo cs tyVars + solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + >>= mapM (substRankInfo cs tyVars) where splitFuncs ( CtEq diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 6b6bdcb16e..f80407b378 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1590,44 +1590,52 @@ checkFunDef :: Exp ) checkFunDef (fname, retdecl, tparams, params, body, loc) = do - (maybe_tysubsts, params', retdecl', body') <- + (maybe_tysubstss, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - - case maybe_tysubsts of - Left err -> typeError loc mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do - let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained - (tparams', params'', retdecl'', RetType dims rettype', body'') <- - checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) - - -- Since this is a top-level function, we also resolve overloaded - -- types, using either defaults or complaining about ambiguities. - fixOverloadedTypes $ - typeVars rettype' <> foldMap (typeVars . patternType) params'' - - -- Then replace all inferred types in the body and parameters. - body''' <- updateTypes body'' - params''' <- updateTypes params'' - retdecl''' <- traverse updateTypes retdecl'' - rettype'' <- normTypeFully rettype' - - -- Check if the function body can actually be evaluated. - causalityCheck body''' - - -- Check for various problems. - mapM_ (mustBeIrrefutable . fmap toStruct) params'' - localChecks body''' - - let ((body'''', updated_ret), errors) = - Consumption.checkValDef - ( fname, - params''', - body''', - RetType dims rettype'', - retdecl''', - loc - ) - - mapM_ throwError errors - - pure (tparams', params''', retdecl''', updated_ret, body'''') + case maybe_tysubstss of + [] -> error "impossible" + [maybe_tysubsts] -> doChecks (maybe_tysubsts, params', retdecl', body') + _ -> typeError loc mempty "Rank ILP is ambiguous" + where + -- TODO: Print out the possibilities. (And also potentially eliminate + --- some of the possibilities to disambiguate). + + doChecks (maybe_tysubsts, params', retdecl', body') = + case maybe_tysubsts of + Left err -> typeError loc mempty $ pretty err + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained + (tparams', params'', retdecl'', RetType dims rettype', body'') <- + checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) + + -- Since this is a top-level function, we also resolve overloaded + -- types, using either defaults or complaining about ambiguities. + fixOverloadedTypes $ + typeVars rettype' <> foldMap (typeVars . patternType) params'' + + -- Then replace all inferred types in the body and parameters. + body''' <- updateTypes body'' + params''' <- updateTypes params'' + retdecl''' <- traverse updateTypes retdecl'' + rettype'' <- normTypeFully rettype' + + -- Check if the function body can actually be evaluated. + causalityCheck body''' + + -- Check for various problems. + mapM_ (mustBeIrrefutable . fmap toStruct) params'' + localChecks body''' + + let ((body'''', updated_ret), errors) = + Consumption.checkValDef + ( fname, + params''', + body''', + RetType dims rettype'', + retdecl''', + loc + ) + + mapM_ throwError errors + + pure (tparams', params''', retdecl''', updated_ret, body'''') diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e5a25df58c..467bde81bf 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1150,7 +1150,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), + ( [Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness))], [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1173,25 +1173,27 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map prettyString cts ] - (cts', tyvars') <- rankAnalysis loc cts tyvars - - solution <- - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ - solve cts' tyvars' - - traceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - - pure (solution, params', retdecl', body') + cts_tyvars' <- rankAnalysis loc cts tyvars + + solutions <- + forM cts_tyvars' $ + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) + . uncurry solve + + forM (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> + traceM $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution + ] + + pure (solutions, params', retdecl', body') checkSingleExp :: ExpBase NoInfo VName -> From 012680d2cb46c96b33555ce5e3b073b643bff917 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 20:58:55 -0800 Subject: [PATCH 084/258] Cap the number of solutions. --- src/Language/Futhark/TypeChecker/Rank.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 729f58848d..247f0555f7 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -155,12 +155,12 @@ ambigCheckLinearProg prog (opt, ranks) = zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks lsum = foldr (~+~) (constant 0) --- We should probably cap the iteration on this enumerateRankSols :: LinearProg -> [Map VName Int] enumerateRankSols prog = - takeSolns $ - iterate next_sol $ - (prog,) <$> run_glpk prog + take 5 $ + takeSolns $ + iterate next_sol $ + (prog,) <$> run_glpk prog where run_glpk = unsafePerformIO . glpk next_sol m = do From 9ab28c07195c8e1ebc275a8769f1e172e6d72e4d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 05:37:42 -0800 Subject: [PATCH 085/258] Make big `M` actually big(ish). --- src/Futhark/Solve/LP.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 7623033e7c..ca2fbe73f0 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -201,7 +201,7 @@ linearProgToPulp prog = rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" bigM :: (Num a) => a -bigM = 10 ^ 3 +bigM = 10 ^ 6 oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = From 7616bcce1174243c3560e27f439394cbddf9220c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 05:42:00 -0800 Subject: [PATCH 086/258] AUTOMAP for short-circuiting operators. --- src/Futhark/Internalise/Exps.hs | 44 +++++++++++---------------------- 1 file changed, 14 insertions(+), 30 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 2e3cabd4f5..f8d20ee005 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -363,28 +363,26 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = () -- Short-circuiting operators are magical. | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "&&", - [(x, _), (y, _)] <- args -> - internaliseExp desc $ - E.AppExp - (E.If x y (E.Literal (E.BoolValue False) mempty) mempty) - (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) + baseString (qualLeaf qfname) == "&&" -> + withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do + letValExp' desc + =<< eIf + (addStms x_stms >> pure (BasicOp $ SubExp x)) + (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) + (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue False]) | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "||", - [(x, _), (y, _)] <- args -> - internaliseExp desc $ - E.AppExp - (E.If x (E.Literal (E.BoolValue True) mempty) y mempty) - (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) + baseString (qualLeaf qfname) == "||" -> + withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do + letValExp' desc + =<< eIf + (addStms x_stms >> pure (BasicOp $ SubExp x)) + (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue True]) + (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) -- Overloaded and intrinsic functions never take array -- arguments (except equality, but those cannot be -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - -- let prepareArg (arg, _) = - -- (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg - -- internalise =<< mapM prepareArg args - -- withAutoMap_ ams arg_desc res_t args $ \args' -> do let prepareArg (arg, _, am) arg' = (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') @@ -392,26 +390,12 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - -- Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do - -- let tag ses = [(se, I.Observe) | se <- ses] - -- args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) - -- let args'' = concatMap tag args' - -- letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) - -- Just (rettype, _) <- M.lookup fname I.builtInFunctions -> withAutoMap_ ams arg_desc res_t args $ \args' -> do let tag ses = [(se, I.Observe) | se <- ses] let args'' = concatMap tag args' letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - traceM $ - unlines - [ "## qfname", - prettyString qfname - ] - -- args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) - -- funcall desc qfname args' loc - withAutoMap_ ams arg_desc res_t args $ \args' -> funcall desc qfname (concat args') loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = From 870f97f4acafc16212fc1b2974b5c3c36bc1b4fa Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 06:08:56 -0800 Subject: [PATCH 087/258] Use `Int`s instead of `Double`s and print out rank maps. --- src/Language/Futhark/TypeChecker/Rank.hs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 247f0555f7..5989c27593 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -15,11 +15,11 @@ import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe -type LSum = LP.LSum VName Double +type LSum = LP.LSum VName Int -type Constraint = LP.Constraint VName Double +type Constraint = LP.Constraint VName Int -type LinearProg = LP.LinearProg VName Double +type LinearProg = LP.LinearProg VName Int type ScalarType = ScalarTypeBase SComp NoUniqueness @@ -181,7 +181,14 @@ solveRankILP loc prog = do ] case enumerateRankSols prog of [] -> typeError loc mempty "Rank ILP cannot be solved." - rs -> pure rs + rs -> do + traceM "## rank maps" + forM (zip [0 :: Int ..] rs) $ \(i, r) -> + traceM $ + unlines $ + "\n## rank map " <> prettyString i + : map prettyString (M.toList r) + pure rs rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> m [([Ct], TyVars)] rankAnalysis _ [] tyVars = pure [([], tyVars)] From 87da34a300b0acf49e951a75f7215d72f56a7b35 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 06:14:59 -0800 Subject: [PATCH 088/258] Apparently powers of 2 are better. 10^6 also somehow gives incorrect results. --- src/Futhark/Solve/LP.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index ca2fbe73f0..044f6efe63 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -201,7 +201,7 @@ linearProgToPulp prog = rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" bigM :: (Num a) => a -bigM = 10 ^ 6 +bigM = 2 ^ 10 oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = From 043ff227d4692f18b09465763c12e21521a68ecb Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 08:11:27 -0800 Subject: [PATCH 089/258] Support auto replicates in internalization. --- src/Futhark/Internalise/Exps.hs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index f8d20ee005..6f25174a91 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -938,8 +938,24 @@ withAutoMap ams arg_desc res_t args_e innerM = do ts | otherwise = pure $ Right $ zip ses ts + internaliseShape = + fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims + + addReplicates = + zipWithM + ( \am arg -> do + rep_shape <- + internaliseShape $ + autoRep am `E.shapePrefix` autoFrame am + if I.shapeRank rep_shape > 0 + then concat <$> mapM (letValExp' "autoRep" . BasicOp . Replicate rep_shape) arg + else pure arg + ) + expand args stms argts ams' level - | level <= 0 = innerM $ zip args stms + | level <= 0 = do + args' <- addReplicates ams' args + innerM $ zip args' stms | otherwise = do let ds' = map autoMapRank ams' arg_params <- mapM (mkLambdaParams level) $ zip4 args argts stms ds' From 674bf01aa19e9727fea49b2adb51046f8b41770c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 21 Feb 2024 15:26:54 +0100 Subject: [PATCH 090/258] Fiddle with Ident type checking. --- src/Language/Futhark/TypeChecker/Terms.hs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index f80407b378..45f6c820fb 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -343,6 +343,10 @@ unscopeType :: unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp +checkIdent :: Ident StructType -> TermTypeM (Ident StructType) +checkIdent (Ident v t loc) = + Ident v <$> traverse (replaceTyVars loc) t <*> pure loc + checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc @@ -601,18 +605,20 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, _, e) body loc) _ ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do + dest' <- checkIdent dest + src' <- checkIdent src slice' <- checkSlice slice - (t, _) <- newArrayType (mkUsage src "type of source array") "src" $ sliceDims slice' - unify (mkUsage loc "type of target array") t $ unInfo $ identType src + (t, _) <- newArrayType (mkUsage src' "type of source array") "src" $ sliceDims slice' + unify (mkUsage loc "type of target array") t $ unInfo $ identType src' (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t ve' <- unifies "type of target array" elemt =<< checkExp ve - bindingIdent dest $ do + bindingIdent dest' $ do body' <- checkExp body - (body_t, ext) <- unscopeType loc [identName dest] =<< expTypeFully body' - pure $ AppExp (LetWith dest src slice' ve' body' loc) (Info $ AppRes body_t ext) + (body_t, ext) <- unscopeType loc [identName dest'] =<< expTypeFully body' + pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t ext) checkExp (Update src slice ve loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' src) "src" $ sliceDims slice' From a10bd3d30615b30fa0f31f8b77f3ba26e18d5107 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 21 Feb 2024 19:52:30 +0100 Subject: [PATCH 091/258] Fix type checking of LetWith. --- src/Language/Futhark/TypeChecker/Terms.hs | 9 +++------ src/Language/Futhark/TypeChecker/Terms/Monad.hs | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 45f6c820fb..9c8ff1ddcb 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -343,10 +343,6 @@ unscopeType :: unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp -checkIdent :: Ident StructType -> TermTypeM (Ident StructType) -checkIdent (Ident v t loc) = - Ident v <$> traverse (replaceTyVars loc) t <*> pure loc - checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc @@ -605,8 +601,9 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, _, e) body loc) _ ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do - dest' <- checkIdent dest - src' <- checkIdent src + src_t <- lookupVar loc (qualName (identName src)) (unInfo (identType src)) + let src' = src {identType = Info src_t} + dest' = dest {identType = Info src_t} slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage src' "type of source array") "src" $ sliceDims slice' unify (mkUsage loc "type of target array") t $ unInfo $ identType src' diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 8bdbb81daf..094f4f8b62 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -358,7 +358,7 @@ replaceTyVars loc orig_t = do f (Scalar (TypeVar u (QualName [] v) [])) | Just t <- M.lookup v tyvars = - fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" (second (const u) t) + fst <$> allDimsFreshInType (mkUsage loc "replaceTyVars") Nonrigid "dv" (second (const u) t) | otherwise = pure $ Scalar (TypeVar u (QualName [] v) []) f (Scalar (TypeVar u qn targs)) = From 2bb4085cf4ea602ebce6b9d739c64839a68f3682 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 21 Feb 2024 21:43:56 +0100 Subject: [PATCH 092/258] OptionPricing now type checks. --- src/Language/Futhark/TypeChecker/Terms.hs | 4 +++- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 12 ++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 9c8ff1ddcb..f4fc11ebdd 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1333,7 +1333,8 @@ checkBinding (fname, maybe_retdecl, tparams, params, body, loc) = verifyFunctionParams (Just fname) params'' (tparams', params''', rettype') <- - letGeneralise (baseName fname) loc tparams params'' =<< unscopeUnknown rettype + letGeneralise (baseName fname) loc tparams params'' + =<< unscopeUnknown rettype when ( null params @@ -1449,6 +1450,7 @@ closeOverTypes defname defloc tparams paramts ret substs = do case M.lookup v substs of Just (_, UnknownSize {}) -> Just v _ -> Nothing + pure ( tparams ++ more_tparams, injectExt (nubOrd $ retext ++ mapMaybe mkExt (S.toList $ fvVars $ freeInType ret)) ret diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index b1a2f59a8d..dbd1d019f0 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -127,7 +127,7 @@ checkPat' _ (Id name (Info t) loc) NoneInferred = do pure $ Id name (Info t') loc checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do t <- replaceTyVars loc t1 - unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) + unify (mkUsage loc "id") (toStruct t) (toStruct t2) pure $ Id name (Info t) loc checkPat' _ (Wildcard (Info t) loc) NoneInferred = do t' <- replaceTyVars loc t diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 467bde81bf..6c0a55ab60 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -488,16 +488,20 @@ checkPat' (RecordPat fs loc) NoneInferred = checkPat' (PatAscription p t loc) maybe_outer_t = do (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp t + -- Uniqueness kung fu to make the Monoid(mempty) instance give what + -- we expect. We should perhaps stop being so implicit. + st' <- asStructType loc $ toType $ resToParam st + case maybe_outer_t of Ascribed outer_t -> do - ctEq (toType st) (toType outer_t) + ctEq (toType st') (toType outer_t) PatAscription - <$> checkPat' p (Ascribed (resToParam st)) + <$> checkPat' p (Ascribed st') <*> pure t' <*> pure loc NoneInferred -> PatAscription - <$> checkPat' p (Ascribed (resToParam st)) + <$> checkPat' p (Ascribed st') <*> pure t' <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do @@ -1180,7 +1184,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) . uncurry solve - forM (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> + forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> traceM $ unlines [ "## constraints:", From 4af653a645d18612fd89867933ab23fe0592da15 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 21 Feb 2024 23:34:23 -0800 Subject: [PATCH 093/258] - Introduce rank representation for `AutoMap` annotations to carry rank information across type checking phases. - Fix AUTOMAP in `Terms.hs`. --- src/Language/Futhark/Syntax.hs | 27 ++++++-- src/Language/Futhark/TypeChecker/Rank.hs | 42 +++++++++++-- src/Language/Futhark/TypeChecker/Terms.hs | 72 +++++++++++----------- src/Language/Futhark/TypeChecker/Terms2.hs | 6 +- 4 files changed, 99 insertions(+), 48 deletions(-) diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index ef7afa4d30..fae5741c8e 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -24,6 +24,9 @@ module Language.Futhark.Syntax shapeRank, stripDims, AutoMap (..), + autoRepRank, + autoMapRank, + autoFrameRank, TypeBase (..), TypeArg (..), SizeExp (..), @@ -261,13 +264,27 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap = AutoMap - { autoRep :: Shape Size, - autoMap :: Shape Size, - autoFrame :: Shape Size - } +data AutoMap + = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size + } + | AutoMapRank Int Int Int deriving (Eq, Show, Ord) +autoRepRank :: AutoMap -> Int +autoRepRank (AutoMapRank r _ _) = r +autoRepRank _ = 0 + +autoMapRank :: AutoMap -> Int +autoMapRank (AutoMapRank _ m _) = m +autoMapRank _ = 0 + +autoFrameRank :: AutoMap -> Int +autoFrameRank (AutoMapRank _ _ f) = f +autoFrameRank _ = 0 + instance Semigroup AutoMap where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 5989c27593..554eba77dd 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -2,6 +2,8 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where import Control.Monad.Reader import Control.Monad.State +import Data.Bifunctor +import Data.Functor.Identity import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M @@ -11,6 +13,7 @@ import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Language.Futhark hiding (ScalarType) +import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe @@ -190,11 +193,13 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> m [([Ct], TyVars)] -rankAnalysis _ [] tyVars = pure [([], tyVars)] -rankAnalysis loc cs tyVars = do - solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) - >>= mapM (substRankInfo cs tyVars) +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] +rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] +rankAnalysis loc cs tyVars body = do + rank_maps <- solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps + let bodys = map (flip updAM body) rank_maps + pure $ zip cts_tyvars' bodys where splitFuncs ( CtEq @@ -321,3 +326,30 @@ instance SubstRanks (TypeBase SComp u) where instance SubstRanks Ct where substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" + +updAM :: Map VName Int -> Exp -> Exp +updAM rank_map e = + case e of + AppExp (Apply f args loc) res -> + let f' = updAM rank_map f + args' = + fmap + ( bimap + (fmap $ bimap id upd) + (updAM rank_map) + ) + args + in AppExp (Apply f' args' loc) res + AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> + AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res + _ -> runIdentity $ astMap m e + where + dimToRank (Var (QualName [] x) _ _) = rank_map M.! x + dimToRank e = error $ prettyString e + shapeToRank = sum . fmap dimToRank + upd (AutoMap r m f) = + AutoMapRank (shapeToRank r) (shapeToRank m) (shapeToRank f) + m = + identityMapper + { mapOnExp = pure . updAM rank_map + } diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index f4fc11ebdd..8c48591696 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -466,33 +466,33 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp (AppExp (Apply fe args loc) _) = do +checkExp e@(AppExp (Apply fe args loc) _) = do fe' <- checkExp fe + let ams = fmap (snd . unInfo . fst) args args' <- mapM (checkExp . snd) args t <- expType fe' let fname = case fe' of Var v _ _ -> Just v _ -> Nothing - ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' + ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) (NE.zip args' ams) pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts where - onArg fname (i, all_exts, t) arg' = do - (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' + onArg fname (i, all_exts, t) (arg', am) = do + (_, rt, argext, exts, am') <- checkApply loc (fname, i) t arg' am pure - ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), - (Info (argext, am), arg') + ( (i + 1, all_exts <> exts, rt), + (Info (argext, am'), arg') ) -checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do +checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, Info (_, xam)) (e2, Info (_, yam)) loc) _) = do ftype <- lookupVar oploc op op_t e1' <- checkExp e1 e2' <- checkExp e2 - -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) (arrayOf (autoFrame am1) rt) e2' + (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' xam + (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) rt e2' yam pure $ AppExp @@ -503,7 +503,7 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do (e2', Info (p2_ext, am2)) loc ) - (Info (AppRes (arrayOf (autoFrame am2) rt') retext)) + (Info (AppRes rt' retext)) checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' @@ -725,10 +725,10 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do checkExp (OpSection op (Info op_t) loc) = do ftype <- lookupVar loc op op_t pure $ OpSection op (Info ftype) loc -checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do +checkExp (OpSectionLeft op (Info op_t) e (Info (_, _, _, am), _) _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' + (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' am case (ftype, rt) of (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 (RetType ds rt2))) -> pure $ @@ -742,7 +742,7 @@ checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do +checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e case ftype of @@ -753,6 +753,7 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do (Just op, 1) (Scalar $ Arrow mempty m2 d2 t2 $ RetType [] $ Scalar $ Arrow Nonunique m1 d1 t1 $ RetType dims2 ret) e' + am case arrow' of Scalar (Arrow _ _ _ t1' (RetType dims2' ret')) -> pure $ @@ -930,19 +931,25 @@ stripToMatch paramt (Array _ (Shape (d : ds)) argt) = first (Shape [d] <>) $ stripToMatch paramt $ arrayOf (Shape ds) (Scalar argt) stripToMatch _ argt = (mempty, argt) +splitArrayAt :: Int -> StructType -> (Shape Size, StructType) +splitArrayAt x t = + (Shape $ take x $ shapeDims $ arrayShape t, stripArray x t) + checkApply :: SrcLoc -> ApplyOp -> StructType -> Exp -> + AutoMap -> TermTypeM (StructType, StructType, Maybe VName, [VName], AutoMap) -checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do +checkApply loc (fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do - (am_map_shape, argtype_automap) <- - stripToMatch <$> normTypeFully tp1 <*> normTypeFully argtype + (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype + (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 + let (am_frame_shape, argtype_automap) = splitArrayAt (autoFrameRank am) argtype_with_frame - unify (mkUsage argexp "use as function argument") tp1 argtype_automap + unify (mkUsage argexp "use as function argument") tp1_with_frame argtype_with_frame -- Perform substitutions of instantiated variables in the types. (tp2', ext) <- instantiateDimsInReturnType loc fname =<< normTypeFully tp2 @@ -986,21 +993,16 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do AutoMap { autoMap = am_map_shape, autoRep = mempty, - autoFrame = am_map_shape + autoFrame = am_map_shape <> am_frame_shape } - pure (tp1, tp2'', argext, ext, am) -checkApply loc fname (Array _ shape t) arg = do - -- This implies the function is the result of an automap. - (t1, rt, argext, retext, am) <- checkApply loc fname (Scalar t) arg - let am' = - am - { autoRep = shape <> autoRep am, - autoFrame = shape <> autoFrame am - } - pure (t1, rt, argext, retext, am') -checkApply _ _ _ _ = - error "checkApply: impossible case" + pure (tp1, distributeFrame (autoMap am) tp2'', argext, ext, am) + where + distributeFrame frame (Scalar (Arrow u p d a (RetType ds b))) = + Scalar $ Arrow u p d (arrayOf frame a) (RetType ds (arrayOfWithAliases (uniqueness b) frame b)) + distributeFrame frame t = arrayOf frame t +checkApply _ _ _ _ _ = + error "checkApply: array" -- | Type-check a single expression in isolation. This expression may -- turn out to be polymorphic, in which case the list of type @@ -1595,11 +1597,11 @@ checkFunDef :: Exp ) checkFunDef (fname, retdecl, tparams, params, body, loc) = do - (maybe_tysubstss, params', retdecl', body') <- + (maybe_tysubstss, params', retdecl', bodys') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - case maybe_tysubstss of - [] -> error "impossible" - [maybe_tysubsts] -> doChecks (maybe_tysubsts, params', retdecl', body') + case (maybe_tysubstss, bodys') of + ([], _) -> error "impossible" + ([maybe_tysubsts], [body']) -> doChecks (maybe_tysubsts, params', retdecl', body') _ -> typeError loc mempty "Rank ILP is ambiguous" where -- TODO: Print out the possibilities. (And also potentially eliminate diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6c0a55ab60..6b1f9cb763 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1157,7 +1157,7 @@ checkValDef :: ( [Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness))], [Pat ParamType], Maybe (TypeExp Exp VName), - Exp + [Exp] ) checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bindParams tparams params $ \params' -> do @@ -1177,7 +1177,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map prettyString cts ] - cts_tyvars' <- rankAnalysis loc cts tyvars + (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' solutions <- forM cts_tyvars' $ @@ -1197,7 +1197,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] - pure (solutions, params', retdecl', body') + pure (solutions, params', retdecl', bodys') checkSingleExp :: ExpBase NoInfo VName -> From c336b16008f1fdba4ae61dba508ad0301db4c0f5 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 09:13:16 +0100 Subject: [PATCH 094/258] Do not rewrite automapped short-circuiting ops. --- src/Futhark/Internalise/FullNormalise.hs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 6ee354ea4f..509d41c7ca 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -299,16 +299,22 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do body' <- transformBody body nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do + -- Rewrite short-circuiting boolean operators on scalars to explicit + -- if-then-else. expr' <- case (isOr, isAnd) of - (True, _) -> do - el' <- naming "or_lhs" $ getOrdering True el - er' <- naming "or_rhs" $ transformBody er - pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) - (_, True) -> do - el' <- naming "and_lhs" $ getOrdering True el - er' <- naming "and_rhs" $ transformBody er - pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) - (False, False) -> do + (True, _) + | elam == mempty, + eram == mempty -> do + el' <- naming "or_lhs" $ getOrdering True el + er' <- naming "or_rhs" $ transformBody er + pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) + (_, True) + | elam == mempty, + eram == mempty -> do + el' <- naming "and_lhs" $ getOrdering True el + er' <- naming "and_rhs" $ transformBody er + pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) + _ -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT From ff353e57ba5924dffd1d6afe2451a6063db95dcf Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 09:15:57 +0100 Subject: [PATCH 095/258] We do not need this. --- src/Futhark/Internalise/Exps.hs | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 6f25174a91..6eca179e3c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -361,23 +361,6 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- Some functions are magical (overloaded) and we handle that here. case () of () - -- Short-circuiting operators are magical. - | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "&&" -> - withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do - letValExp' desc - =<< eIf - (addStms x_stms >> pure (BasicOp $ SubExp x)) - (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) - (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue False]) - | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "||" -> - withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do - letValExp' desc - =<< eIf - (addStms x_stms >> pure (BasicOp $ SubExp x)) - (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue True]) - (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) -- Overloaded and intrinsic functions never take array -- arguments (except equality, but those cannot be -- existential), so we can safely ignore the existential From 986856d32bf2726c66b42a731d46bcd044155139 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 09:40:05 +0100 Subject: [PATCH 096/258] Try to handle logical operators. --- src/Futhark/Internalise/Exps.hs | 5 ++++- src/Futhark/Internalise/FullNormalise.hs | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 6eca179e3c..c17db733ba 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1680,12 +1680,15 @@ isOverloadedFunction qname desc loc = do handle name | Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = Just $ \[(x_t, [x']), (y_t, [y'])] -> - case (x_t, y_t) of + case (arrayElem x_t, arrayElem y_t) of (E.Scalar (E.Prim t1), E.Scalar (E.Prim t2)) -> internaliseBinOp loc desc bop x' y' t1 t2 _ -> error "Futhark.Internalise.internaliseExp: non-primitive type in BinOp." handle _ = Nothing + arrayElem (E.Array _ _ t) = E.Scalar t + arrayElem t = t + -- | Handle intrinsic functions. These are only allowed to be called -- in the prelude, and their internalisation may involve inspecting -- the AST. diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 509d41c7ca..17841c9f53 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -300,7 +300,8 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do -- Rewrite short-circuiting boolean operators on scalars to explicit - -- if-then-else. + -- if-then-else. Automapped cases are turned into applications of + -- intrinsic functions. expr' <- case (isOr, isAnd) of (True, _) | elam == mempty, @@ -308,18 +309,30 @@ getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info el' <- naming "or_lhs" $ getOrdering True el er' <- naming "or_rhs" $ transformBody er pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) + | otherwise -> do + el' <- naming "or_lhs" $ getOrdering False el + er' <- naming "or_rhs" $ getOrdering False er + pure $ mkApply orop [(elp, elam, el'), (erp, eram, er')] resT (_, True) | elam == mempty, eram == mempty -> do el' <- naming "and_lhs" $ getOrdering True el er' <- naming "and_rhs" $ transformBody er pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) + | otherwise -> do + el' <- naming "and_lhs" $ getOrdering False el + er' <- naming "and_rhs" $ getOrdering False er + pure $ mkApply andop [(elp, elam, el'), (erp, eram, er')] resT _ -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT nameExp final expr' where + bool = Scalar $ Prim Bool + opt = foldFunType [bool, bool] $ RetType [] bool + andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty + orop = Var (qualName (intrinsicVar "||")) (Info opt) mempty isOr = baseName (qualLeaf op) == "||" isAnd = baseName (qualLeaf op) == "&&" getOrdering final (AppExp (LetWith (Ident dest dty dloc) (Ident src sty sloc) slice e body loc) _) = do From 263ec737c8dbbac6a3d88e227592926d6d6a88f9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 00:47:16 -0800 Subject: [PATCH 097/258] Remove `autoMapRank`. --- src/Language/Futhark/Syntax.hs | 21 ++++++++------------- src/Language/Futhark/TypeChecker/Rank.hs | 7 ++++--- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index fae5741c8e..b39a82cd0b 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -264,26 +264,21 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap - = AutoMap - { autoRep :: Shape Size, - autoMap :: Shape Size, - autoFrame :: Shape Size - } - | AutoMapRank Int Int Int +data AutoMap = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size + } deriving (Eq, Show, Ord) autoRepRank :: AutoMap -> Int -autoRepRank (AutoMapRank r _ _) = r -autoRepRank _ = 0 +autoRepRank = shapeRank . autoRep autoMapRank :: AutoMap -> Int -autoMapRank (AutoMapRank _ m _) = m -autoMapRank _ = 0 +autoMapRank = shapeRank . autoMap autoFrameRank :: AutoMap -> Int -autoFrameRank (AutoMapRank _ _ f) = f -autoFrameRank _ = 0 +autoFrameRank = shapeRank . autoFrame instance Semigroup AutoMap where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 554eba77dd..1eef479dcf 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -344,11 +344,12 @@ updAM rank_map e = AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res _ -> runIdentity $ astMap m e where - dimToRank (Var (QualName [] x) _ _) = rank_map M.! x + dimToRank (Var (QualName [] x) _ _) = + replicate (rank_map M.! x) (TupLit mempty mempty) dimToRank e = error $ prettyString e - shapeToRank = sum . fmap dimToRank + shapeToRank = Shape . foldMap dimToRank upd (AutoMap r m f) = - AutoMapRank (shapeToRank r) (shapeToRank m) (shapeToRank f) + AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) m = identityMapper { mapOnExp = pure . updAM rank_map From 285604e1a1dca422fc7f20fc25e682a9a1c4ba4c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 14:28:23 +0100 Subject: [PATCH 098/258] Another hash. --- src/Language/Futhark/TypeChecker.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 6ce5977c9e..d6b5baf8fc 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -707,7 +707,7 @@ checkValBind vb = do let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - traceM $ unlines ["Inferred:", prettyString vb'] + traceM $ unlines ["# Inferred:", prettyString vb'] pure ( mempty From 2334dfd11e1209a4cc2d409a494a9c801d69f087 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 14:28:56 +0100 Subject: [PATCH 099/258] Make uniqueness explicit. --- src/Language/Futhark/TypeChecker/Terms2.hs | 122 +++++++++++---------- 1 file changed, 65 insertions(+), 57 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6b1f9cb763..d9e5930b71 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -216,8 +216,8 @@ incCounter = do put s {termCounter = termCounter s + 1} pure $ termCounter s -tyVarType :: (Monoid u) => TyVar -> TypeBase dim u -tyVarType v = Scalar $ TypeVar mempty (qualName v) [] +tyVarType :: u -> TyVar -> TypeBase dim u +tyVarType u v = Scalar $ TypeVar u (qualName v) [] newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar newTyVarWith _loc desc info = do @@ -230,38 +230,46 @@ newTyVarWith _loc desc info = do newTyVar :: (Located loc) => loc -> Name -> TermM TyVar newTyVar loc desc = newTyVarWith loc desc TyVarFree -newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) -newType loc desc = tyVarType <$> newTyVar loc desc +newType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) +newType loc desc u = tyVarType u <$> newTyVar loc desc -newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> Type -> TermM (TypeBase d u) +newTypeWithField :: SrcLoc -> Name -> Name -> Type -> TermM Type newTypeWithField loc desc k t = - tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) + tyVarType NoUniqueness <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) -newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) -newTypeWithConstr loc desc k ts = - tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') +newTypeWithConstr :: SrcLoc -> Name -> u -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) +newTypeWithConstr loc desc u k ts = + tyVarType u <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') where ts' = map (`setUniqueness` NoUniqueness) ts -newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d u) +newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniqueness) newTypeOverloaded loc name pts = - tyVarType <$> newTyVarWith loc name (TyVarPrim pts) + tyVarType NoUniqueness <$> newTyVarWith loc name (TyVarPrim pts) newSVar :: (Located loc) => loc -> Name -> TermM SVar newSVar _loc desc = do i <- incCounter newID $ mkTypeVarName desc i -asStructType :: (Monoid u) => SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) +asStructType :: SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do t1' <- asStructType loc t1 t2' <- asStructType loc t2 pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' -asStructType loc t = do - t' <- newType loc "artificial" - ctEq (toType t' `setUniqueness` NoUniqueness) (t `setUniqueness` NoUniqueness) +asStructType loc (Scalar (Record fs)) = + Scalar . Record <$> traverse (asStructType loc) fs +asStructType loc (Scalar (Sum cs)) = + Scalar . Sum <$> traverse (mapM (asStructType loc)) cs +asStructType loc t@(Scalar (TypeVar u _ _)) = do + t' <- newType loc "artificial" u + ctEq (toType t') t + pure t' +asStructType loc t@(Array u _ _) = do + t' <- newType loc "artificial" u + ctEq (toType t') t pure t' addCt :: Ct -> TermM () @@ -370,7 +378,7 @@ instTypeScheme _qn loc tparams t = do case tparam of TypeParamType _ v _ -> do v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType NoUniqueness v')) TypeParamDim {} -> pure Nothing let t' = applySubst (`lookup` substs) t @@ -398,11 +406,11 @@ lookupVar loc qn@(QualName qs name) = do asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do argtype <- newTypeOverloaded loc "t" anyPrimType - pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool + pure $ foldFunType [toParam Observe argtype, toParam Observe argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType pts' $ RetType [] $ toRes Nonunique rt' + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -426,9 +434,9 @@ bind idents = localScope (`bindVars` idents) -- literals in patterns. patLitMkType :: PatLit -> SrcLoc -> TermM ParamType patLitMkType (PatLitInt _) loc = - newTypeOverloaded loc "t" anyNumberType + toParam Observe <$> newTypeOverloaded loc "t" anyNumberType patLitMkType (PatLitFloat _) loc = - newTypeOverloaded loc "t" anyFloatType + toParam Observe <$> newTypeOverloaded loc "t" anyFloatType patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v @@ -449,12 +457,12 @@ checkPat' (PatAttr attr p loc) t = checkPat' (Id name NoInfo loc) (Ascribed t) = pure $ Id name (Info t) loc checkPat' (Id name NoInfo loc) NoneInferred = do - t <- newType loc "t" + t <- newType loc "t" Observe pure $ Id name (Info t) loc checkPat' (Wildcard _ loc) (Ascribed t) = pure $ Wildcard (Info t) loc checkPat' (Wildcard NoInfo loc) NoneInferred = do - t <- newType loc "t" + t <- newType loc "t" Observe pure $ Wildcard (Info t) loc checkPat' (TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, @@ -463,7 +471,7 @@ checkPat' (TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") + ps_t <- replicateM (length ps) (newType loc "t" Observe) ctEq (toType (Scalar (tupleRecord ps_t))) (toType t) TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = @@ -473,7 +481,7 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) L.sort (map fst p_fs) == L.sort (M.keys t_fs) = RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do - p_fs' :: M.Map Name Type <- traverse (const $ newType loc "t") $ M.fromList p_fs + p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs ctEq (Scalar (Record p_fs')) $ toType t st <- asStructType loc $ Scalar (Record p_fs') checkPat' p $ Ascribed $ toParam Observe st @@ -526,15 +534,15 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do - p_t <- newType (srclocOf p) "t" + p_t <- newType (srclocOf p) "t" Observe checkPat' p $ Ascribed p_t - t' <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' + t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' ctEq t' (toType t) t'' <- asStructType loc t' pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps - t <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' + t <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' t' <- asStructType loc t pure $ PatConstr n (Info $ toParam Observe t') ps' loc @@ -649,10 +657,10 @@ checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) split ftype' = do - a <- newType loc "arg" - b <- newTyVar loc "res" - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b - pure (a, tyVarType b) + a <- newType loc "arg" NoUniqueness + b <- newType loc "res" Nonunique + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + pure (a, b `setUniqueness` NoUniqueness) -- To be removed (probably) checkApply_ :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) @@ -685,10 +693,10 @@ checkApply_ loc _ ftype fframe arg = do split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) split ftype' = do - a <- newType loc "arg" - b <- newTyVar loc "res" - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b - pure (a, tyVarType b) + a <- newType loc "arg" NoUniqueness + b <- newType loc "res" Nonunique + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + pure (a, b `setUniqueness` NoUniqueness) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex @@ -714,8 +722,8 @@ mustHaveFields loc t [f] ve_t = do rt :: Type <- newTypeWithField loc "ft" f ve_t ctEq t rt mustHaveFields loc t (f : fs) ve_t = do - ft :: Type <- newType loc "ft" - rt :: Type <- newTypeWithField loc "rt" f ft + ft <- newType loc "ft" NoUniqueness + rt <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t ctEq t rt @@ -796,7 +804,7 @@ checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc checkExp (Hole NoInfo loc) = - Hole <$> (Info <$> newType loc "hole") <*> pure loc + Hole <$> (Info <$> newType loc "hole" NoUniqueness) <*> pure loc checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (TupLit es loc) = @@ -826,7 +834,7 @@ checkExp (ArrayLit es _ loc) = do -- type variables for pathologically large arrays with -- type-unsuffixed integers. Add some special case that handles that -- more efficiently. - et <- newType loc "et" + et <- newType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e ctEq (expType e') (toType et) @@ -868,7 +876,7 @@ checkExp (Assert e1 e2 NoInfo loc) = do -- checkExp (Constr name es NoInfo loc) = do es' <- mapM checkExp es - t <- newTypeWithConstr loc "t" name $ map expType es' + t <- newTypeWithConstr loc "t" NoUniqueness name $ map expType es' pure $ Constr name es' (Info t) loc -- checkExp (AppExp (Apply fe args loc) NoInfo) = do @@ -910,7 +918,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - t2 <- newType loc "t" + t2 <- newType loc "t" NoUniqueness t2' <- asStructType loc t2 let t1 = typeOf e' f1 = frameOf e' @@ -928,7 +936,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e - t1 <- newType loc "t" + t1 <- newType loc "t" NoUniqueness t1' <- asStructType loc t1 let t2 = typeOf e' f2 = frameOf e' @@ -946,8 +954,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do loc -- checkExp (ProjectSection fields NoInfo loc) = do - a <- newType loc "a" - b <- newType loc "b" + a <- newType loc "a" NoUniqueness + b <- newType loc "b" NoUniqueness mustHaveFields loc (toType a) fields (toType b) let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc @@ -1002,14 +1010,14 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end - range_t <- newTyVar loc "range" - ctEq (tyVarType range_t :: Type) (arrayOfRank 1 (expType start')) - pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes (tyVarType range_t) [] + range_t <- newType loc "range" NoUniqueness + ctEq (toType range_t) (arrayOfRank 1 (expType start')) + pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e - kt <- newType loc "kt" - t :: Type <- newTypeWithField loc "t" k kt + kt <- newType loc "kt" NoUniqueness + t <- newTypeWithField loc "t" k kt ctEq (expType e') t kt' <- asStructType loc kt pure $ Project k e' (Info kt') loc @@ -1022,9 +1030,9 @@ checkExp (RecordUpdate src fields ve NoInfo loc) = do -- checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice - index_arg_t <- newType loc "index" - index_elem_t <- newType loc "index_elem" - index_res_t :: Type <- newType loc "index_res" + index_arg_t <- newType loc "index" NoUniqueness + index_elem_t <- newType loc "index_elem" NoUniqueness + index_res_t <- newType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (toType index_arg_t) $ arrayOfRank num_slices index_elem_t ctEq index_res_t $ arrayOfRank (length slice) index_elem_t @@ -1035,8 +1043,8 @@ checkExp (IndexSection slice NoInfo loc) = do checkExp (AppExp (Index e slice loc) _) = do e' <- checkExp e slice' <- checkSlice slice - index_t <- newType loc "index" - index_elem_t <- newType loc "index_elem" + index_t <- newType loc "index" NoUniqueness + index_elem_t <- newType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (toType index_t) $ arrayOfRank num_slices index_elem_t ctEq (expType e') $ arrayOfRank (length slice) index_elem_t @@ -1047,7 +1055,7 @@ checkExp (Update src slice ve loc) = do slice' <- checkSlice slice ve' <- checkExp ve let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" + update_elem_t <- newType loc "update_elem" NoUniqueness ctEq (expType src') $ arrayOfRank (length slice) update_elem_t ctEq (expType ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc @@ -1059,7 +1067,7 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do slice' <- checkSlice slice ve' <- checkExp ve let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" + update_elem_t <- newType loc "update_elem" NoUniqueness ctEq (toType src_t) $ arrayOfRank (length slice) update_elem_t ctEq (expType ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do @@ -1100,7 +1108,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr - elem_t <- newType elemp "elem" + elem_t <- newType elemp "elem" NoUniqueness ctEq (expType arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body From 3212b972d4dd569cce24cec3e732f226569a456d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 14:35:10 -0800 Subject: [PATCH 100/258] Hack to detect no integer solutions from `glpk`. --- futhark.cabal | 1 + src/Futhark/Solve/GLPK.hs | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index 8cf87483f3..8c00ce32bc 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -497,6 +497,7 @@ library -- remove me later , process , glpk-hs + , silently executable futhark import: common diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index fe7ac5d129..b2d340d683 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -1,12 +1,14 @@ module Futhark.Solve.GLPK (glpk) where +import Control.Monad import Data.LinearProgram import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Futhark.Solve.LP qualified as F +import System.IO.Silently -linearProgToGLPK :: (Show v, Ord v, Eq a, Num a, Group a) => F.LinearProg v a -> (LP v a) +linearProgToGLPK :: (Ord v, Eq a, Num a) => F.LinearProg v a -> (LP v a) linearProgToGLPK prog = LP { direction = cOptType $ F.optType prog, @@ -38,12 +40,16 @@ linearProgToGLPK prog = varList = S.toList $ F.vars prog -glpk :: - (Show v, Ord v, Show a, Eq a, Real a, Group a) => - F.LinearProg v a -> - IO (Maybe (Int, M.Map v Int)) +glpk :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) glpk lp = do + (output, res) <- capture $ glpk' lp + pure $ do + guard $ "PROBLEM HAS NO INTEGER FEASIBLE SOLUTION" `notElem` lines output + res + +glpk' :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) +glpk' lp = do (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres where - opts = mipDefaults {msgLev = MsgOff} + opts = mipDefaults {msgLev = MsgAll} From eb49178500a27cf98e5c9cab4707d6a7ec67eaa9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 16:45:17 -0800 Subject: [PATCH 101/258] Add some notes to clarify how this will actually work. --- src/Futhark/Internalise/Exps.hs | 92 +++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index c17db733ba..09285aacb2 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -888,6 +888,98 @@ withAutoMap_ ams arg_desc res_t args_e innerM = mapM_ addStms $ reverse stms innerM args +-- | Internalization of 'AutoMap'-annotated applications. +-- +-- Each application @f x@ has an annotation with @AutoMap R M F@ where +-- @R, M, F@ are the autorep, automap, and frame shapes, +-- respectively. +-- +-- The application @f x@ will have type @F t@ for some @t@, i.e. @(f +-- x) : F t@. The frame @F@ is a prefix of the type of @f x@; namely +-- it is the total accumulated shape that is due to implicit maps. +-- Another way of thinking about that is that @|F|@ is is the level +-- of the automap-nest that @f x@ is in. For example, if @|F| = 2@ +-- then we know that @f x@ implicitly stands for +-- +-- > map (\x' -> map (\x'' -> f x'') x') x +-- +-- For an application with a non-empty autorep annotation, the frame +-- tells about how many dimensions of the replicate can be eliminated. +-- For example, @[[1,2],[3,4]] + 5@ will yield the following annotations: +-- +-- > ([[1,2],[3,4]] +) -- AutoMap {R = mempty, M = [2][2], F = [2][2]} +-- > (([[1,2],[3,4]] +) 5) -- AutoMap {R = [2][2], M = mempty, F = [2][2]} +-- +-- All replicated arguments are pushed down the auto-map nest. Each +-- time a replicated argument is pushed down a level of an +-- automap-nest, one fewer replicates is needed (i.e., the outermost +-- dimension of @R@ can be dropped). Replicated arguments are pushed +-- down the nest until either 1) the bottom of the nest is encountered +-- or 2) no replicate dimensions remain. For example, in the second +-- application above @R@ = @F@, so we can push the replicated argument +-- down two levels. Since each level effectively removes a dimension +-- of the replicate, no replicates will be required: +-- +-- > map (\xs -> map (\x -> f x'' 5) xs) [[1,2],[3,4]] +-- +-- The number of replicates that are actually required is given by +-- max(|R| - |F|, 0). +-- +-- An expression's "true level" is the level at which that expression +-- will appear in the automap-nest. The bottom of a mapnest is level 0. +-- +-- * For annotations with @R = mempty@, the true level is @|F|@. +-- * For annotations with @M = mempty@, the true level is @|F| - |R|@. +-- +-- If @|R| > |F|@ then actual replicates (namely @|R| - |F|@ of them) +-- will be required at the bottom of the mapnest. +-- +-- Note that replicates can only appear at the bottom of a mapnest; any +-- expression of the form +-- +-- > map (\ls x' rs -> e) (replicate x) +-- +-- can always be written as +-- +-- > map (\ls rs -> e[x' -> x]) +-- +-- Let's look at another example. Consider (with exact sizes omitted for brevity) +-- +-- > f : a -> a -> a -> []a -> [][][]a -> a +-- > xss : [][]a +-- > ys : []a +-- > zsss : [][][]a +-- > w : a +-- > vss : [][]a +-- +-- and the application +-- +-- > f xss ys zsss w vss +-- +-- which will have the following annotations +-- +-- > (f xss) -- AutoMap {R = mempty, M = [][], F = [][]} (1) +-- > ((f xss) ys) -- AutoMap {R = [], M = mempty, F = [][]} (2) +-- > (((f xss) ys) zsss) -- AutoMap {R = mempty, M = [], F = [][][]} (3) +-- > ((((f xss) ys) zsss) w) -- AutoMap {R = [][][][], M = mempty, F = [][][]} (4) +-- > (((((f xss) ys) zsss) w) vss) -- AutoMap {R = [], M = mempty, F = [][][]} (5) +-- +-- This will yield the following mapnest. +-- +-- > map (\zss -> +-- > map (\xs zs vs -> +-- > map (\x y z v -> f x y z (replicate w) v) xs ys zs v) xss zss vss) zsss +-- +-- Let's see how we'd construct this mapnest from the annotations. We construct +-- the nest bottom-up. We have: +-- +-- Application | True level +-- --------------------------- +-- (1) | |[][]| = 2 +-- (2) | |[][]| - |[]| = 1 +-- (3) | |[][][]| = 3 +-- (4) | |[][][]| - |[][][][]| = -1 +-- (5) | |[][][]| - |[]| = 2 withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap ams arg_desc res_t args_e innerM = do (args, stms) <- From 94af0909b06bbe9bb324154aa59ee8f5e06dc90d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 20:28:09 -0800 Subject: [PATCH 102/258] Bit more. --- src/Futhark/Internalise/Exps.hs | 38 +++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 09285aacb2..21620aa993 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -975,11 +975,41 @@ withAutoMap_ ams arg_desc res_t args_e innerM = -- -- Application | True level -- --------------------------- --- (1) | |[][]| = 2 --- (2) | |[][]| - |[]| = 1 --- (3) | |[][][]| = 3 +-- (1) | |[][]| = 2 +-- (2) | |[][]| - |[]| = 1 +-- (3) | |[][][]| = 3 -- (4) | |[][][]| - |[][][][]| = -1 --- (5) | |[][][]| - |[]| = 2 +-- (5) | |[][][]| - |[]| = 2 +-- +-- We start at level 0. +-- * Any argument with a negative true level of @-n@ will be replicated @n@ times; +-- the exact shapes can be found by removing the @F@ postfix from @R@, +-- i.e. @R = shapes_to_rep_by <> F@. +-- * Any argument with a 0 true level will be included. +-- * For any argument @arg@ with a positive true level, we construct a new parameter +-- whose type is @arg@ with the leading @n@ dimensions (where @n@ is the true level) +-- removed. +-- +-- Following the rules above, @w@ will be replicated once. For the remaining arguments, +-- we create new parameters @x : a, y : a, z : a , v : a@. Hence, level 0 becomes +-- +-- > f x y z (replicate w) v +-- +-- At level l > 0: +-- * There are no replicates. +-- * Any argument with l true level will be included verbatim. +-- * Any argument with true level > l will have a new parameter constructed for it, +-- whose type has the leading @n - l@ dimensions (where @n@ is the true level) removed. +-- * We surround the previous level with a map that binds that levels' new parameters +-- and is passed the current levels' arguments. +-- +-- Following the above recipe for level 1, we create parameters +-- @xs : []a, zs : []a, vs :[]a@ and obtain +-- +-- > map (\x y z v -> f x y z (replicate w) v) xs ys zs vs +-- +-- This process continues until the level is greater than the maximum +-- true level of any application, at which we terminate. withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap ams arg_desc res_t args_e innerM = do (args, stms) <- From a7c8dd9fff8f00e63125b6c40d9de60f8064e80e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 11:34:44 +0100 Subject: [PATCH 103/258] Unnecessary warnings. --- src/Language/Futhark/TypeChecker/Rank.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 1eef479dcf..6c2504ec3e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -186,7 +186,7 @@ solveRankILP loc prog = do [] -> typeError loc mempty "Rank ILP cannot be solved." rs -> do traceM "## rank maps" - forM (zip [0 :: Int ..] rs) $ \(i, r) -> + forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> traceM $ unlines $ "\n## rank map " <> prettyString i @@ -198,7 +198,7 @@ rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] rankAnalysis loc cs tyVars body = do rank_maps <- solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps - let bodys = map (flip updAM body) rank_maps + let bodys = map (`updAM` body) rank_maps pure $ zip cts_tyvars' bodys where splitFuncs @@ -335,7 +335,7 @@ updAM rank_map e = args' = fmap ( bimap - (fmap $ bimap id upd) + (fmap $ second upd) (updAM rank_map) ) args From 86acaed8db1f914d6318a5c79b4f37326248fd0b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 11:38:46 +0100 Subject: [PATCH 104/258] Slices better be i64. --- src/Language/Futhark/TypeChecker/Terms2.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d9e5930b71..024975b100 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -358,6 +358,9 @@ arrayOfRank :: Int -> Type -> Type arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp +require _why [pt] e = do + ctEq (Scalar $ Prim pt) (expType e) + pure e require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts ctEq t $ expType e @@ -702,11 +705,11 @@ checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where checkDimIndex (DimFix i) = - DimFix <$> check i + DimFix <$> (require "use as index" anySignedType =<< checkExp i) checkDimIndex (DimSlice i j s) = DimSlice <$> traverse check i <*> traverse check j <*> traverse check s - check = require "use as index" anySignedType <=< checkExp + check = require "use in slice" [Signed Int64] <=< checkExp isSlice :: DimIndexBase f vn -> Bool isSlice DimSlice {} = True From 92be29dca09a2dd295e084e9b0b9ca369d5c5a6e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 14:28:12 +0100 Subject: [PATCH 105/258] Implement occurs check. --- .../Futhark/TypeChecker/Constraints.hs | 43 ++++++++++++------- src/Language/Futhark/TypeChecker/Rank.hs | 1 + tests/{issue1599.fut => types/occurs.fut} | 1 + 3 files changed, 30 insertions(+), 15 deletions(-) rename tests/{issue1599.fut => types/occurs.fut} (53%) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ddaedd4025..ddb8e26e2c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -16,12 +16,11 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor -import Data.List qualified as L import Data.Map qualified as M import Data.Maybe +import Data.Set qualified as S import Data.Text qualified as T import Debug.Trace -import Futhark.IR.Pretty import Futhark.Util.Pretty import Language.Futhark @@ -38,7 +37,7 @@ data SComp instance Pretty SComp where pretty SDim = "[]" - pretty (SVar x) = brackets $ pretty x + pretty (SVar x) = brackets $ prettyName x instance Pretty (Shape SComp) where pretty = mconcat . map pretty . shapeDims @@ -59,7 +58,7 @@ data Ct instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtAM r m) = pretty r <+> "=" <+> "•" <+> "∨" <+> pretty m <+> "=" <+> "•" + pretty (CtAM r m) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" type Constraints = [Ct] @@ -142,12 +141,25 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) +occursCheck :: VName -> Type -> SolveM () +occursCheck v tp = do + vars <- gets solverTyVars + let tp' = substTyVars vars tp + when (v `S.member` typeVars tp') . throwError . docText $ + "Occurs check: cannot instantiate" + <+> prettyName v + <+> "with" + <+> pretty tp + <> "." + subTyVar :: VName -> Int -> Type -> SolveM () -subTyVar v lvl t = +subTyVar v lvl t = do + occursCheck v t modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} linkTyVar :: VName -> VName -> SolveM () -linkTyVar v t = +linkTyVar v t = do + occursCheck v $ Scalar $ TypeVar NoUniqueness (qualName t) [] modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} -- Unify at the root, emitting new equalities that must hold. @@ -164,11 +176,11 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) Just $ M.elems $ M.intersectionWith (,) fs1 fs2 unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = - fmap concat - . forM (M.elems $ M.intersectionWith (,) cs1 cs2) - $ \(ts1, ts2) -> do - guard $ length ts1 == length ts2 - Just $ zip ts1 ts2 + fmap concat . forM cs' $ \(ts1, ts2) -> do + guard $ length ts1 == length ts2 + Just $ zip ts1 ts2 + where + cs' = M.elems $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = @@ -224,9 +236,10 @@ solveCt ct = solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = - second solution - . runExcept - . flip execStateT (initialState tyvars) - . runSolveM + trace (unlines (map prettyString constraints)) + $ second solution + . runExcept + . flip execStateT (initialState tyvars) + . runSolveM $ mapM solveCt constraints {-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 6c2504ec3e..fd9904998b 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -9,6 +9,7 @@ import Data.Map (Map) import Data.Map qualified as M import Data.Maybe import Debug.Trace +import Futhark.IR.Pretty () import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP diff --git a/tests/issue1599.fut b/tests/types/occurs.fut similarity index 53% rename from tests/issue1599.fut rename to tests/types/occurs.fut index 3ce47c38b1..c37b1448c4 100644 --- a/tests/issue1599.fut +++ b/tests/types/occurs.fut @@ -1,3 +1,4 @@ +-- Simple instance of an occurs check. -- == -- error: Occurs From aec83f333274151c9ee15b10b5948eeda802a1d3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 14:28:34 +0100 Subject: [PATCH 106/258] This is too much. --- src/Language/Futhark/TypeChecker/Constraints.hs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ddb8e26e2c..9565804534 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -236,10 +236,9 @@ solveCt ct = solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = - trace (unlines (map prettyString constraints)) - $ second solution - . runExcept - . flip execStateT (initialState tyvars) - . runSolveM + second solution + . runExcept + . flip execStateT (initialState tyvars) + . runSolveM $ mapM solveCt constraints {-# NOINLINE solve #-} From df35808a604bbb524a3e2f7b8fea0aef826f6c7c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 14:39:47 +0100 Subject: [PATCH 107/258] Add Futhark.Util.debugTraceM. --- src/Futhark/Internalise/Defunctionalise.hs | 11 +++++------ src/Futhark/Internalise/Exps.hs | 1 - src/Futhark/Util.hs | 8 ++++++++ src/Language/Futhark/TypeChecker.hs | 4 ++-- src/Language/Futhark/TypeChecker/Constraints.hs | 1 - src/Language/Futhark/TypeChecker/Rank.hs | 8 ++++---- src/Language/Futhark/TypeChecker/Terms.hs | 1 - src/Language/Futhark/TypeChecker/Terms2.hs | 9 ++++----- 8 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 249616ef67..98b698aeeb 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -14,10 +14,9 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Debug.Trace import Futhark.IR.Pretty () import Futhark.MonadFreshNames -import Futhark.Util (mapAccumLM, nubOrd) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types (Subst (..), applySubst) @@ -956,7 +955,7 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) callret <- unRetType lifted_rettype - traceM $ + debugTraceM $ unlines [ "##defuncApplyArg LambdaSV", "## fname", @@ -985,7 +984,7 @@ defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] apply_e = mkApply f' [(argext, am, arg')] callret - traceM $ + debugTraceM $ unlines [ "##defuncApplyArg DynamicFun", "## f'", @@ -1046,7 +1045,7 @@ defuncApply f args appres loc = do -- ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) am_dims ams = NE.toList $ autoMap . snd . fst <$> args ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) ams - traceM $ + debugTraceM $ unlines [ "## defuncApply", "## f", @@ -1062,7 +1061,7 @@ defuncApply f args appres loc = do "## f type", prettyString $ typeOf f, "## arg types", - prettyString $ (typeOf . snd) <$> args, + prettyString $ typeOf . snd <$> args, "## ret_am", prettyString ret_am ] diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 21620aa993..5f933e659c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -16,7 +16,6 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index 76fce3b4af..6d41b1c7f2 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -51,6 +51,7 @@ module Futhark.Util fixPoint, concatMapM, topologicalSort, + debugTraceM, ) where @@ -77,6 +78,7 @@ import Data.Text.Encoding qualified as T import Data.Text.Encoding.Error qualified as T import Data.Time.Clock (UTCTime, getCurrentTime) import Data.Tuple (swap) +import Debug.Trace import Numeric import System.Directory.Tree qualified as Dir import System.Environment @@ -507,3 +509,9 @@ topologicalSort dep nodes = modify $ second $ IM.insert i True mapM_ sorting $ mapMaybe (depends_of node) nodes_idx modify $ bimap (node :) (IM.insert i False) + +-- | 'traceM', but only if @FUTHARK_COMPILER_DEBUGGING@ is set to at least 1. +debugTraceM :: (Monad m) => String -> m () +debugTraceM + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 = traceM + | otherwise = const $ pure () diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 7ce81ecca5..2f3371d02b 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -26,8 +26,8 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Ord import Data.Set qualified as S -import Debug.Trace import Futhark.FreshNames hiding (newName) +import Futhark.Util (debugTraceM) import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Semantic @@ -715,7 +715,7 @@ checkValBind vb = do let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - traceM $ unlines ["# Inferred:", prettyString vb'] + debugTraceM $ unlines ["# Inferred:", prettyString vb'] pure ( mempty diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 9565804534..cf5a4f2dbc 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -20,7 +20,6 @@ import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.Util.Pretty import Language.Futhark diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index fd9904998b..83815ec70e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -8,11 +8,11 @@ import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M import Data.Maybe -import Debug.Trace import Futhark.IR.Pretty () import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP +import Futhark.Util (debugTraceM) import Language.Futhark hiding (ScalarType) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints @@ -178,7 +178,7 @@ enumerateRankSols prog = solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] solveRankILP loc prog = do - traceM $ + debugTraceM $ unlines [ "## solveRankILP", prettyString prog @@ -186,9 +186,9 @@ solveRankILP loc prog = do case enumerateRankSols prog of [] -> typeError loc mempty "Rank ILP cannot be solved." rs -> do - traceM "## rank maps" + debugTraceM "## rank maps" forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> - traceM $ + debugTraceM $ unlines $ "\n## rank map " <> prettyString i : map prettyString (M.toList r) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 8c48591696..76578151cb 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -27,7 +27,6 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Debug.Trace import Futhark.Util (mapAccumLM, nubOrd, topologicalSort) import Futhark.Util.Pretty hiding (space) import Language.Futhark diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 024975b100..43ef74251c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -55,10 +55,9 @@ import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.FreshNames qualified as FreshNames import Futhark.MonadFreshNames hiding (newName) -import Futhark.Util (mapAccumLM) +import Futhark.Util (debugTraceM, mapAccumLM) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Constraints @@ -1180,9 +1179,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - traceM $ + debugTraceM $ unlines [ "## cts:", unlines $ map prettyString cts @@ -1196,7 +1195,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do . uncurry solve forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> - traceM $ + debugTraceM $ unlines [ "## constraints:", unlines $ map prettyString cts', From 02c71b7c9f500ffac7e5fc2cbaa2f9a6e5499f03 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 06:04:13 -0800 Subject: [PATCH 108/258] Prettyprint AUTOMAP annotations. --- src/Language/Futhark/Pretty.hs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 81ca4a152f..886f69b066 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -311,11 +311,21 @@ prettyAppExp _ (If c t f _) = prettyAppExp p (Apply f args _) = parensIf (p >= 10) $ prettyExp 0 f - <+> hsep (map (prettyExp 10 . snd) $ NE.toList args) + <+> hsep (map prettyArg $ NE.toList args) + where + prettyArg (i, e) = + case unAnnot i of + Just (_, am) + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + parens (prettyExp 10 e <+> "Δ" <+> pretty am) + _ -> prettyExp 10 e instance (Eq vn, IsName vn, Annot f) => Pretty (AppExpBase f vn) where pretty = prettyAppExp (-1) +instance Pretty AutoMap where + pretty (AutoMap r m f) = encloseSep lparen rparen comma $ map pretty [r, m, f] + prettyInst :: (Annot f, Pretty t) => f t -> Doc a prettyInst t = case unAnnot t of From 2a64af32824306c56e3027837edbb50024c8c6db Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 06:15:16 -0800 Subject: [PATCH 109/258] Also prettyprint binops. --- src/Language/Futhark/Pretty.hs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 886f69b066..8519b22060 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -234,7 +234,13 @@ letBody body@(AppExp LetFun {} _) = pretty body letBody body = "in" <+> align (pretty body) prettyAppExp :: (Eq vn, IsName vn, Annot f) => Int -> AppExpBase f vn -> Doc a -prettyAppExp p (BinOp (bop, _) _ (x, _) (y, _) _) = prettyBinOp p bop x y +prettyAppExp p (BinOp (bop, _) _ (x, xi) (y, yi) _) = + case (unAnnot xi, unAnnot yi) of + (Just (_, xam), Just (_, yam)) + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + -- fix + parens (prettyBinOp p bop x y <+> "Δ" <+> pretty xam <+> "Δ" <+> pretty yam) + _ -> prettyBinOp p bop x y prettyAppExp _ (Match e cs _) = "match" <+> pretty e (stack . map pretty) (NE.toList cs) prettyAppExp _ (Loop sizeparams pat initexp form loopbody _) = "loop" From d201507b14c3245500d0ec0a6b1a357e0e338e03 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 16:00:18 +0100 Subject: [PATCH 110/258] Fix checkOneExp. --- src/Language/Futhark/TypeChecker/Terms.hs | 24 +++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 76578151cb..3a1e9fc663 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1007,16 +1007,20 @@ checkApply _ _ _ _ _ = -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM checkExp mempty $ do - e' <- checkExp $ undefined e - let t = typeOf e' - (tparams, _, _) <- - letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t - fixOverloadedTypes $ typeVars t - e'' <- updateTypes e' - localChecks e'' - causalityCheck e'' - pure (tparams, e'') +checkOneExp e = do + (maybe_tysubsts, e') <- Terms2.checkSingleExp e + case maybe_tysubsts of + Left err -> typeError e' mempty $ pretty err + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + e'' <- checkExp e' + let t = typeOf e'' + (tparams, _, _) <- + letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t + fixOverloadedTypes $ typeVars t + e''' <- updateTypes e'' + localChecks e''' + causalityCheck e''' + pure (tparams, e''') -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. From 2119f2971270ac34511ec505006544caac4e2e3d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 07:01:42 -0800 Subject: [PATCH 111/258] Fix frame duplication. --- src/Language/Futhark/TypeChecker/Terms2.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 43ef74251c..40e3ae83a1 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -637,8 +637,9 @@ checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] m) unit_info mempty - lhs = arrayOf (toShape (SVar r) <> (toSComp <$> argframe)) argtype + lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a + ctAM r m ctEq lhs rhs pure @@ -673,7 +674,7 @@ checkApply_ loc _ ftype fframe arg = do let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] m) unit_info mempty - lhs = arrayOf (toShape (SVar r) <> (toSComp <$> frameOf arg)) $ toType $ typeOf arg + lhs = arrayOf (toShape (SVar r)) $ toType $ typeOf arg rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a ctAM r m ctEq lhs rhs From 81d53e5a0b7d768742acab432db6ccfc31a59d46 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 07:46:47 -0800 Subject: [PATCH 112/258] Remove `checkApply_`. --- src/Language/Futhark/TypeChecker/Terms2.hs | 96 +++++++++------------- 1 file changed, 40 insertions(+), 56 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 40e3ae83a1..3f8b0df53b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -614,14 +614,12 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> [(Shape Size, Type)] -> TermM (StructType, [AutoMap]) +checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> NE.NonEmpty (Shape Size, Type) -> TermM (StructType, NE.NonEmpty AutoMap) checkApply loc fname (fframe, ftype) args = do ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args rt' <- asStructType loc rt pure (rt', argts) where - -- pure (asStructType loc rt, argts) - onArg (i, f_f, f_t) (argframe, argtype) = do (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) pure @@ -665,42 +663,6 @@ checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) --- To be removed (probably) -checkApply_ :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) -checkApply_ loc _ ftype fframe arg = do - (a, b) <- split $ stripFrame fframe ftype - r <- newSVar loc "R" - m <- newSVar loc "M" - let unit_info = Info $ Scalar $ Prim Bool - r_var = Var (QualName [] r) unit_info mempty - m_var = Var (QualName [] m) unit_info mempty - lhs = arrayOf (toShape (SVar r)) $ toType $ typeOf arg - rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a - ctAM r m - ctEq lhs rhs - pure - ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, - AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} - ) - where - stripFrame :: Shape Size -> Type -> Type - stripFrame frame (Array u ds t) = - let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) - in case mnew_shape of - Nothing -> Scalar t - Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t - stripFrame _ t = t - toSComp (Var (QualName [] x) _ _) = SVar x - toSComp _ = error "" - toShape = Shape . pure - split (Scalar (Arrow _ _ _ a (RetType _ b))) = - pure (a, b `setUniqueness` NoUniqueness) - split ftype' = do - a <- newType loc "arg" NoUniqueness - b <- newType loc "res" Nonunique - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b - pure (a, b `setUniqueness` NoUniqueness) - checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where @@ -884,22 +846,41 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args - rt' <- asStructType loc rt - pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] + (args', argts') <- + NE.unzip + <$> forM + args + ( \(_, arg) -> do + arg' <- checkExp arg + pure (arg', (frameOf arg', expType arg')) + ) + (rt, ams) <- checkApply loc fname (frameOf fe', expType fe') argts' + pure $ + AppExp (Apply fe' (NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args') loc) $ + Info $ + AppRes rt [] where fname = case fe of Var v _ _ -> Just v _ -> Nothing - - onArg (i, f_t, f_f) (_, arg) = do - arg' <- checkExp arg - (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' - pure - ( (i + 1, rt, autoFrame am), - (Info (Nothing, am), arg') - ) +-- fe' <- checkExp fe +-- ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args +-- rt' <- asStructType loc rt +-- pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] +-- where +-- fname = +-- case fe of +-- Var v _ _ -> Just v +-- _ -> Nothing + +-- onArg (i, f_t, f_f) (_, arg) = do +-- arg' <- checkExp arg +-- (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' +-- pure +-- ( (i + 1, rt, autoFrame am), +-- (Info (Nothing, am), arg') +-- ) -- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op @@ -910,8 +891,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do loc (Just op) (mempty, toType ftype) - [(frameOf e1', toType $ typeOf e1'), (frameOf e2', toType $ typeOf e2')] - let [am1, am2] = ams + ((frameOf e1', toType $ typeOf e1') NE.:| [(frameOf e2', toType $ typeOf e2')]) + let (am1 NE.:| [am2]) = ams pure $ AppExp @@ -925,13 +906,15 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do t2' <- asStructType loc t2 let t1 = typeOf e' f1 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(f1, toType t1), (mempty, t2)] + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((f1, toType t1) NE.:| [(mempty, t2)]) + + let (am1 NE.:| _) = ams pure $ OpSectionLeft op (Info optype) e' - ( Info (Unnamed, toParam Observe t1, Nothing, head ams), -- fix + ( Info (Unnamed, toParam Observe t1, Nothing, am1), Info (Unnamed, toParam Observe t2') ) (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) @@ -943,7 +926,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do t1' <- asStructType loc t1 let t2 = typeOf e' f2 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(mempty, t1), (f2, toType t2)] + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((mempty, t1) NE.:| [(f2, toType t2)]) + let (_ NE.:| [am2]) = ams pure $ OpSectionRight op @@ -951,7 +935,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do e' -- Dummy types. ( Info (Unnamed, toParam Observe t1'), - Info (Unnamed, toParam Observe t2, Nothing, ams !! 1) -- fix + Info (Unnamed, toParam Observe t2, Nothing, am2) ) (Info $ RetType [] (rt `setUniqueness` Nonunique)) loc From 6b171951cdabb1f2026f6b8a313b9f2f9ab59a9e Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 07:48:20 -0800 Subject: [PATCH 113/258] Forgot to remove this too. --- src/Language/Futhark/TypeChecker/Terms2.hs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 3f8b0df53b..b3a3106435 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -864,24 +864,6 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do case fe of Var v _ _ -> Just v _ -> Nothing --- fe' <- checkExp fe --- ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args --- rt' <- asStructType loc rt --- pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] --- where --- fname = --- case fe of --- Var v _ _ -> Just v --- _ -> Nothing - --- onArg (i, f_t, f_f) (_, arg) = do --- arg' <- checkExp arg --- (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' --- pure --- ( (i + 1, rt, autoFrame am), --- (Info (Nothing, am), arg') --- ) --- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op e1' <- checkExp e1 From e09ad671ddaa2e0b7e735c0f1f614c3f2ae45bed Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 17:41:59 -0800 Subject: [PATCH 114/258] Rank fixes. --- .../Futhark/TypeChecker/Constraints.hs | 2 +- src/Language/Futhark/TypeChecker/Rank.hs | 39 +++++++++++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index cf5a4f2dbc..63e2320a47 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -71,7 +71,7 @@ data TyVarInfo TyVarRecord (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum (M.Map Name [Type]) - deriving (Show) + deriving (Show, Eq) instance Pretty TyVarInfo where pretty TyVarFree = "free" diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 83815ec70e..7fffe64107 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -55,12 +55,16 @@ instance Rank Type where class Distribute a where distribute :: a -> a -instance Distribute Type where - distribute = distributeOne - where - distributeOne (Array _ s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s tr) - distributeOne t = t +instance Distribute (TypeBase dim u) where + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute (arrayOfWithAliases Nonunique s tr)) + distribute t = t instance Distribute Ct where distribute (CtEq t1 t2) = distribute t1 `CtEq` distribute t2 @@ -169,6 +173,7 @@ enumerateRankSols prog = run_glpk = unsafePerformIO . glpk next_sol m = do (prog', sol') <- m + guard (fst sol' /= 0) let prog'' = ambigCheckLinearProg prog' sol' sol'' <- run_glpk prog'' pure (prog'', sol'') @@ -197,11 +202,12 @@ solveRankILP loc prog = do rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] rankAnalysis loc cs tyVars body = do - rank_maps <- solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps let bodys = map (`updAM` body) rank_maps pure $ zip cts_tyvars' bodys where + cs' = foldMap (splitFuncs . distribute) cs splitFuncs ( CtEq (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) @@ -288,9 +294,20 @@ rankToShape x = do addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do rs <- asks envRanks - unless (fromMaybe 0 (rs M.!? t) == 0) $ do - new_vars <- gets substNewVars - maybe new_var (const $ pure ()) $ new_vars M.!? t + if (fromMaybe 0 (rs M.!? t) == 0) + then do + old_tyvars <- asks envTyVars + case old_tyvars M.!? t of + -- Probably not needed + -- Just (lvl, TyVarFree) -> + -- -- is anyPrimType right here? + -- modify $ + -- \s -> s {substTyVars = M.insert t (lvl, TyVarPrim anyPrimType) $ substTyVars s} + _ -> do + pure () + else do + new_vars <- gets substNewVars + maybe new_var (const $ pure ()) $ new_vars M.!? t where new_var = do t' <- newTyVar t @@ -336,7 +353,7 @@ updAM rank_map e = args' = fmap ( bimap - (fmap $ second upd) + (fmap $ bimap id upd) (updAM rank_map) ) args From e2505f5acaca043ca4a4004ab20300ca32a350ce Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 17:43:54 -0800 Subject: [PATCH 115/258] Don't peel frames and distribute instead. --- src/Language/Futhark/TypeChecker/Terms.hs | 23 ++++++-- src/Language/Futhark/TypeChecker/Terms2.hs | 68 +++++++++++++++++----- 2 files changed, 73 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 3a1e9fc663..2107f56fd0 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -27,7 +27,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Futhark.Util (mapAccumLM, nubOrd, topologicalSort) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd, topologicalSort) import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Primitive (intByteSize) @@ -941,13 +941,28 @@ checkApply :: Exp -> AutoMap -> TermTypeM (StructType, StructType, Maybe VName, [VName], AutoMap) -checkApply loc (fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do +checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 let (am_frame_shape, argtype_automap) = splitArrayAt (autoFrameRank am) argtype_with_frame + debugTraceM $ + unlines + [ "## checkApply", + "## fn", + prettyString fn, + "## ft", + prettyString ft, + "## tp1_with_frame", + prettyString tp1_with_frame, + "## argtype_with_frame", + prettyString argtype_with_frame, + "## am", + show am + ] + unify (mkUsage argexp "use as function argument") tp1_with_frame argtype_with_frame -- Perform substitutions of instantiated variables in the types. @@ -990,8 +1005,8 @@ checkApply loc (fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let am = AutoMap - { autoMap = am_map_shape, - autoRep = mempty, + { autoRep = mempty, + autoMap = am_map_shape, autoFrame = am_map_shape <> am_frame_shape } diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b3a3106435..59182007b5 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -628,41 +628,77 @@ checkApply loc fname (fframe, ftype) args = do ) checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) -checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do - (a, b) <- split $ stripFrame fframe ftype +checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do + (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r)) argtype - rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a - + rhs = arrayOf (toShape (SVar m)) a ctAM r m ctEq lhs rhs + debugTraceM $ + unlines $ + [ "## checkApplyOne", + "## fname", + prettyString fname, + "## (fframe, ftype)", + prettyString (fframe, ftype), + "## (argframe, argtype)", + prettyString (argframe, argtype), + "## r", + prettyString r, + "## m", + prettyString m, + "## lhs", + prettyString lhs, + "## rhs", + prettyString rhs, + "## ret", + prettyString $ arrayOf (toShape (SVar m)) b + ] pure - ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, + ( arrayOf (toShape (SVar m)) b, AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where - stripFrame :: Shape Size -> Type -> Type - stripFrame frame (Array u ds t) = - let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) - in case mnew_shape of - Nothing -> Scalar t - Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t - stripFrame _ t = t + -- stripFrame :: Shape Size -> Type -> Type + -- stripFrame frame (Array u ds t) = + -- let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) + -- in case mnew_shape of + -- Nothing -> Scalar t + -- Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t + -- stripFrame _ t = t + + isFunType (Scalar Arrow {}) = True + isFunType _ = False -- (fix) toSComp (Var (QualName [] x) _ _) = SVar x toSComp _ = error "" toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) + split (Array u s t) = do + (a, b) <- split $ Scalar t + pure (arrayOf s a, arrayOf s b) split ftype' = do a <- newType loc "arg" NoUniqueness b <- newType loc "res" Nonunique ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) +distribute :: TypeBase dim u -> TypeBase dim u +distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute (arrayOfWithAliases (uniqueness tr) s tr)) +distribute t = t + checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where @@ -1151,7 +1187,11 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do debugTraceM $ unlines [ "## cts:", - unlines $ map prettyString cts + unlines $ map prettyString cts, + "## body:", + prettyString body', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' @@ -1166,7 +1206,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines [ "## constraints:", unlines $ map prettyString cts', - "## tyvars:", + "## tyvars':", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (v, t) = prettyNameString v <> " => " <> prettyString t From f25bd9e57cec8fcf49ba48d7e27c83223df2029c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 18:56:41 -0800 Subject: [PATCH 116/258] Change objective from `M_i + R_i` to `M_i + max(0, |R_i| - |F_i|).` --- src/Futhark/Solve/LP.hs | 22 ++++++++++++++++- .../Futhark/TypeChecker/Constraints.hs | 6 ++--- src/Language/Futhark/TypeChecker/Rank.hs | 24 ++++++++++++++----- src/Language/Futhark/TypeChecker/Terms.hs | 8 ++++++- src/Language/Futhark/TypeChecker/Terms2.hs | 6 ++--- 5 files changed, 52 insertions(+), 14 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 044f6efe63..a2224617ea 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -8,6 +8,8 @@ module Futhark.Solve.LP cval, bin, or, + min, + max, oneIsZero, (~+~), (~-~), @@ -45,7 +47,7 @@ import Futhark.Solve.Matrix (Matrix (..)) import Futhark.Solve.Matrix qualified as M import Futhark.Util.Pretty import Language.Futhark.Pretty -import Prelude hiding (or) +import Prelude hiding (max, min, or) import Prelude qualified -- | A linear program. 'LP c a d' represents the program @@ -203,6 +205,24 @@ linearProgToPulp prog = bigM :: (Num a) => a bigM = 2 ^ 10 +-- max{x, y} = z +max :: (Eq a, Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] +max b x y z = + [ z ~>=~ x, + z ~>=~ y, + z ~<=~ x ~+~ bigM ~*~ var b, + z ~<=~ y ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +-- min{x, y} = z +min :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +min b x y z = + [ var z ~<=~ var x, + var z ~<=~ var y, + var z ~>=~ var x ~-~ bigM ~*~ (constant 1 ~-~ var b), + var z ~>=~ var y ~-~ bigM ~*~ var b + ] + oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = mkC b1 x1 diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 63e2320a47..ad11df4729 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -52,12 +52,12 @@ toType = first (const SDim) data Ct = CtEq Type Type - | CtAM SVar SVar + | CtAM SVar SVar (Shape SComp) deriving (Show) instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtAM r m) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" + pretty (CtAM r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" type Constraints = [Ct] @@ -190,7 +190,7 @@ solveCt :: Ct -> SolveM () solveCt ct = case ct of CtEq t1 t2 -> solveCt' (t1, t2) - CtAM _ _ -> pure () -- Good vibes only. + CtAM _ _ _ -> pure () -- Good vibes only. where bad = throwError $ "Unsolvable: " <> prettyText ct solveCt' (t1, t2) = do diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 7fffe64107..c890ddaca5 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -73,7 +73,8 @@ instance Distribute Ct where data RankState = RankState { rankBinVars :: Map VName VName, rankCounter :: !Int, - rankConstraints :: [Constraint] + rankConstraints :: [Constraint], + rankObj :: LSum } newtype RankM a = RankM {runRankM :: State RankState a} @@ -106,12 +107,22 @@ addConstraints cs = addConstraint :: Constraint -> RankM () addConstraint = addConstraints . pure +addObj :: SVar -> RankM () +addObj sv = + modify $ \s -> s {rankObj = rankObj s ~+~ var sv} + addCt :: Ct -> RankM () addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 -addCt (CtAM r m) = do +addCt (CtAM r m f) = do b_r <- binVar r b_m <- binVar m + b_max <- VName "b_max" <$> incCounter + tr <- VName ("T_" <> baseName r) <$> incCounter + addConstraints $ [bin b_max, var b_max ~<=~ var tr] addConstraints $ oneIsZero (b_r, r) (b_m, m) + addConstraints $ LP.max b_max (constant 0) (rank r ~-~ rank f) (var tr) + addObj m + addObj tr addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () addTyVarInfo _ (_, TyVarFree) = pure () @@ -126,9 +137,9 @@ mkLinearProg :: [Ct] -> TyVars -> LinearProg mkLinearProg cs tyVars = LP.LinearProg { optType = Minimize, - objective = - let shape_vars = M.keys $ rankBinVars finalState - in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, + objective = rankObj finalState, + -- let shape_vars = M.keys $ rankBinVars finalState + -- in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, constraints = rankConstraints finalState } where @@ -136,7 +147,8 @@ mkLinearProg cs tyVars = RankState { rankBinVars = mempty, rankCounter = 0, - rankConstraints = mempty + rankConstraints = mempty, + rankObj = constant 0 } buildLP = do mapM_ addCt cs diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2107f56fd0..8ad8b2a366 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1620,7 +1620,13 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case (maybe_tysubstss, bodys') of ([], _) -> error "impossible" ([maybe_tysubsts], [body']) -> doChecks (maybe_tysubsts, params', retdecl', body') - _ -> typeError loc mempty "Rank ILP is ambiguous" + (substs, bodies') -> + typeError loc mempty $ + stack $ + [ "Rank ILP is ambiguous.", + "Choices:" + ] + ++ map pretty bodies' where -- TODO: Print out the possibilities. (And also potentially eliminate --- some of the possibilities to disambiguate). diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 59182007b5..5bf91a47fe 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -283,8 +283,8 @@ ctEq t1 t2 = t1' = t1 `setUniqueness` NoUniqueness t2' = t2 `setUniqueness` NoUniqueness -ctAM :: SVar -> SVar -> TermM () -ctAM r m = addCt $ CtAM r m +ctAM :: SVar -> SVar -> Shape SComp -> TermM () +ctAM r m f = addCt $ CtAM r m f localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -637,7 +637,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m)) a - ctAM r m + ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs debugTraceM $ unlines $ From 93fa76e46bbaa2798a4ad9aa6de8d4bd438618a2 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 19:04:38 -0800 Subject: [PATCH 117/258] Add some tests. --- tests/automap/ambiguous0.fut | 1 + tests/automap/bool1.fut | 6 +++ tests/automap/equality1.fut | 23 ++++++++++ tests/automap/lambda.fut | 6 +++ tests/automap/map0.fut | 8 ++++ tests/automap/mri-q-qr.fut | 2 + tests/automap/mri-q.fut | 40 +++++++++++++++++ tests/automap/operator1.fut | 9 ++++ tests/automap/optionpricing.fut | 78 +++++++++++++++++++++++++++++++++ tests/automap/pagerank.fut | 18 ++++++++ tests/automap/project.fut | 9 ++++ tests/automap/projsec1.fut | 9 ++++ tests/automap/same_typevar.fut | 16 +++++++ tests/automap/sgemm.fut | 32 ++++++++++++++ tests/automap/simple1.fut | 7 +++ tests/automap/simple2.fut | 8 ++++ tests/automap/simple3.fut | 8 ++++ tests/automap/simple4.fut | 8 ++++ tests/automap/simple5.fut | 6 +++ 19 files changed, 294 insertions(+) create mode 100644 tests/automap/ambiguous0.fut create mode 100644 tests/automap/bool1.fut create mode 100644 tests/automap/equality1.fut create mode 100644 tests/automap/lambda.fut create mode 100644 tests/automap/map0.fut create mode 100644 tests/automap/mri-q-qr.fut create mode 100644 tests/automap/mri-q.fut create mode 100644 tests/automap/operator1.fut create mode 100644 tests/automap/optionpricing.fut create mode 100644 tests/automap/pagerank.fut create mode 100644 tests/automap/project.fut create mode 100644 tests/automap/projsec1.fut create mode 100644 tests/automap/same_typevar.fut create mode 100644 tests/automap/sgemm.fut create mode 100644 tests/automap/simple1.fut create mode 100644 tests/automap/simple2.fut create mode 100644 tests/automap/simple3.fut create mode 100644 tests/automap/simple4.fut create mode 100644 tests/automap/simple5.fut diff --git a/tests/automap/ambiguous0.fut b/tests/automap/ambiguous0.fut new file mode 100644 index 0000000000..58a663bf36 --- /dev/null +++ b/tests/automap/ambiguous0.fut @@ -0,0 +1 @@ +def ambig (xss : [][]i32) = i64.sum (length xss) diff --git a/tests/automap/bool1.fut b/tests/automap/bool1.fut new file mode 100644 index 0000000000..f3fe08213e --- /dev/null +++ b/tests/automap/bool1.fut @@ -0,0 +1,6 @@ +-- == +-- entry: f +-- input { [true, true, false] [false, true, true] } +-- output { [true, true, true] } + +def f [m] (xs: [m]bool) (ys: [m]bool) = xs || ys diff --git a/tests/automap/equality1.fut b/tests/automap/equality1.fut new file mode 100644 index 0000000000..1604c49d93 --- /dev/null +++ b/tests/automap/equality1.fut @@ -0,0 +1,23 @@ +-- == +-- entry: bigger_to_smaller +-- input { [[1,2],[3,4]] [1,2] } +-- output { [true, false] } + +-- == +-- entry: smaller_to_bigger +-- input { [[1,2],[3,4]] [1,2] } +-- output { [true, false] } + +-- == +-- entry: smaller_to_bigger2 +-- input { [[1,2],[3,4]] 1 } +-- output { [[true,false],[false,false]]} + +def bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = + xss == ys + +def smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = + ys == xss + +def smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = + z == xss diff --git a/tests/automap/lambda.fut b/tests/automap/lambda.fut new file mode 100644 index 0000000000..1bb7ed26e3 --- /dev/null +++ b/tests/automap/lambda.fut @@ -0,0 +1,6 @@ +-- == +-- entry: main +-- random input { [10]f32 [10]f32 } + +entry main [n](xs: [n]f32) (ys: [n]f32): [n]f32 = + map2 (*) xs ys diff --git a/tests/automap/map0.fut b/tests/automap/map0.fut new file mode 100644 index 0000000000..a5ab0887ae --- /dev/null +++ b/tests/automap/map0.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [0,1,2,3] } +-- output { [1,2,3,4] } + +def automap 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = f as + +entry main (x: []i32) = automap (+1) x diff --git a/tests/automap/mri-q-qr.fut b/tests/automap/mri-q-qr.fut new file mode 100644 index 0000000000..8004f7da5d --- /dev/null +++ b/tests/automap/mri-q-qr.fut @@ -0,0 +1,2 @@ +def qr [numX][numK] (expArgs : [numX][numK]f32) (phiMag : [numK]f32) : [numX]f32 = + f32.sum (f32.cos expArgs * phiMag) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut new file mode 100644 index 0000000000..eaed14333a --- /dev/null +++ b/tests/automap/mri-q.fut @@ -0,0 +1,40 @@ +-- == +-- entry: main +-- random input { [12]f32 [12]f32 [12]f32 [10]f32 [10]f32 [10]f32 [12]f32 [12]f32 } +-- output { true } + +def main_orig [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numX]f32, [numX]f32) = + let phiMag = map2 (\r i -> r*r + i*i) phiR phiI + let expArgs = map3 (\x_e y_e z_e -> + map (2.0f32*f32.pi*) + (map3 (\kx_e ky_e kz_e -> + kx_e * x_e + ky_e * y_e + kz_e * z_e) + kx ky kz)) + x y z + let qr = map1 (map f32.cos >-> map2 (*) phiMag >-> f32.sum) expArgs + let qi = map1 (map f32.sin >-> map2 (*) phiMag >-> f32.sum) expArgs + in (qr, qi) + +def main_am [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numK]f32, [numX][numK]f32) = + let (phiMag : [numK]f32) = phiR * phiR + phiI * phiI + let (expArgs : [numX][numK]f32) = map3 (\(x_e : f32) (y_e : f32) (z_e : f32) -> + 2.0*f32.pi*(kx*x_e + ky*y_e + kz*z_e)) + x y z + in (phiMag, expArgs) + --let (qr : [numX]f32) = f32.sum (f32.cos expArgs * phiMag) -- [numx]f32 + --let (qi : [numX]f32) = f32.sum (f32.sin expArgs * phiMag) -- let (qi_10408: artificial₁₁₄_10524 ~ [M113_10523]f32) + --in (qr, qi) + +--entry main [numK][numX] +-- (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) +-- (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) +-- (phiR: [numK]f32) (phiI: [numK]f32) = +-- main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI diff --git a/tests/automap/operator1.fut b/tests/automap/operator1.fut new file mode 100644 index 0000000000..464a8b79c4 --- /dev/null +++ b/tests/automap/operator1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [10,20] } +-- output { [[11, 22],[13, 24]] } + +def (+^) [n] (xs: [n]i32) (ys: [n]i32) : [n]i32 = xs + ys + +--entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = +-- xss +^ ys diff --git a/tests/automap/optionpricing.fut b/tests/automap/optionpricing.fut new file mode 100644 index 0000000000..c58bc39a0a --- /dev/null +++ b/tests/automap/optionpricing.fut @@ -0,0 +1,78 @@ +-- == +-- entry: sobolIndR +-- random input { [12][10]i32 i32 } +-- output { true } + +-- == +-- entry: sobolRecI +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +-- == +-- entry: sobolReci2 +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +def grayCode(x: i32): i32 = (x >> 1) ^ x + +def testBit(n: i32, ind: i32): bool = + let t = (1 << ind) in (n & t) == t + +def xorInds [num_bits] (n: i32) (dir_vs: [num_bits]i32): i32 = + let reldv_vals = map (\(dv: i32, i): i32 -> + if testBit(grayCode(n),i32.i64 i) + then dv else 0 + ) (zip (dir_vs) (iota(num_bits)) ) in + reduce (^) 0 (reldv_vals ) + + +def sobolIndI [len] (dir_vs: [len][]i32, n: i32 ): [len]i32 = + map (xorInds(n)) (dir_vs ) + +def index_of_least_significant_0(num_bits: i32, n: i32): i32 = + let (goon,k) = (true,0) in + let (_,k,_) = loop ((goon,k,n)) for i < num_bits do + if(goon) + then if (n & 1) == 1 + then (true, k+1, n>>1) + else (false,k, n ) + else (false,k, n ) + in k + +def recM [len][num_bits] (sob_dirs: [len][num_bits]i32, i: i32 ): [len]i32 = + let bit= index_of_least_significant_0(i32.i64 num_bits,i) in + map (\(row: []i32): i32 -> row[bit]) (sob_dirs ) + +def sobolIndR_orig [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = map (xorInds n) dir_vs + in map (\x -> f32.i32(x) / divisor) arri + +def sobolRecI_orig [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in map2 (\vct_row prev -> vct_row[bit] ^ prev) sob_dir_vs prev + +def sobolReci2_orig [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + let col = recM(sob_dirs, i) + in map2 (^) prev col + +def sobolIndR_am [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = xorInds n dir_vs + in f32.i32 arri / divisor + +def sobolRecI_am [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in sob_dir_vs[:,bit] ^ prev + +def sobolReci2_am [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + prev ^ recM(sob_dirs, i) + +entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): []bool = + sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n + +entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): []bool = + sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x) + +entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): []bool = + sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i) diff --git a/tests/automap/pagerank.fut b/tests/automap/pagerank.fut new file mode 100644 index 0000000000..c444932de5 --- /dev/null +++ b/tests/automap/pagerank.fut @@ -0,0 +1,18 @@ +-- == +-- entry: calculate_dangling_ranks +-- random input { [12]f32 [12]i32} +-- output { true } + +def calculate_dangling_ranks_orig [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let zipped = zip sizes ranks + let weights = map (\(size, rank) -> if size == 0 then rank else 0f32) zipped + let total = f32.sum weights / f32.i64 n + in map (+total) ranks + +def calculate_dangling_ranks_am [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let weights = f32.bool (sizes == 0) * ranks + let total = f32.sum weights / f32.i64 n + in ranks + total + +entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): []bool = + calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes diff --git a/tests/automap/project.fut b/tests/automap/project.fut new file mode 100644 index 0000000000..2902d0565a --- /dev/null +++ b/tests/automap/project.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in xsys.0 ++ xsys.1 + diff --git a/tests/automap/projsec1.fut b/tests/automap/projsec1.fut new file mode 100644 index 0000000000..485c977bc5 --- /dev/null +++ b/tests/automap/projsec1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in (.0) xsys ++ (.1) xsys + diff --git a/tests/automap/same_typevar.fut b/tests/automap/same_typevar.fut new file mode 100644 index 0000000000..260a00b785 --- /dev/null +++ b/tests/automap/same_typevar.fut @@ -0,0 +1,16 @@ +-- == +-- tags { no_wasm } +-- entry: big_to_small +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +-- == +-- entry: small_to_big +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +def f 'a (x: a) (y: a) (z: a) = (x, y, z) + +entry big_to_small [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f xss ys z + +entry small_to_big [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f z ys xss diff --git a/tests/automap/sgemm.fut b/tests/automap/sgemm.fut new file mode 100644 index 0000000000..56dc08eb7e --- /dev/null +++ b/tests/automap/sgemm.fut @@ -0,0 +1,32 @@ +-- == +-- entry: main +-- random input { [5][10]f32 [10][3]f32 [5][3]f32 f32 f32 } +-- output { true } + +def mult_orig [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + let dotprod xs ys = f32.sum (map2 (*) xs ys) + in map (\xs -> map (dotprod xs) (transpose yss)) xss + +def add [n][m] (xss: [n][m]f32, yss: [n][m]f32): [n][m]f32 = + map2 (map2 (+)) xss yss + +def scale [n][m] (xss: [n][m]f32, a: f32): [n][m]f32 = + map (map1 (*a)) xss + +def main_orig [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + add(scale(css,beta), scale(mult_orig(ass,bss), alpha)) + + +def mult_am [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + f32.sum ((transpose (replicate p xss)) * (replicate n (transpose yss))) + +def main_am [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + css*beta + mult_am(ass,bss)*alpha + +entry main [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) = + main_orig ass bss css alpha beta == main_am ass bss css alpha beta diff --git a/tests/automap/simple1.fut b/tests/automap/simple1.fut new file mode 100644 index 0000000000..f8833bb3b6 --- /dev/null +++ b/tests/automap/simple1.fut @@ -0,0 +1,7 @@ +-- == +-- entry: main +-- input { [1,2] 10 } +-- output { [11, 12] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + xs + y diff --git a/tests/automap/simple2.fut b/tests/automap/simple2.fut new file mode 100644 index 0000000000..ac57abcbe0 --- /dev/null +++ b/tests/automap/simple2.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + xss + ys + diff --git a/tests/automap/simple3.fut b/tests/automap/simple3.fut new file mode 100644 index 0000000000..adc60bd43f --- /dev/null +++ b/tests/automap/simple3.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + ys + xss + diff --git a/tests/automap/simple4.fut b/tests/automap/simple4.fut new file mode 100644 index 0000000000..d94bbe4a6b --- /dev/null +++ b/tests/automap/simple4.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { 3 [1,1] [[1,2],[3,4]] } +-- output { [[5,6],[7,8]] } + +entry main [n] (x : i32) (ys: [n]i32) (zss : [n][n]i32) : [n][n]i32 = + x + ys + zss + diff --git a/tests/automap/simple5.fut b/tests/automap/simple5.fut new file mode 100644 index 0000000000..46610e6567 --- /dev/null +++ b/tests/automap/simple5.fut @@ -0,0 +1,6 @@ +-- == +-- input { [1,2,3] 4 } +-- output { [5, 6, 7] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + (\x y -> x + y) xs y From b3d3c423e03222c8a488b77f366e7c5a090ad87a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 19:40:25 -0800 Subject: [PATCH 118/258] Renaming hack to not count these when looking for new solutions. --- src/Language/Futhark/TypeChecker/Rank.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c890ddaca5..186f8d0fb3 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -116,7 +116,7 @@ addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 addCt (CtAM r m f) = do b_r <- binVar r b_m <- binVar m - b_max <- VName "b_max" <$> incCounter + b_max <- VName "c_max" <$> incCounter tr <- VName ("T_" <> baseName r) <$> incCounter addConstraints $ [bin b_max, var b_max ~<=~ var tr] addConstraints $ oneIsZero (b_r, r) (b_m, m) From 3cdce28d1a63497ccfba1ea82dd983593ebc603f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 22:23:50 -0800 Subject: [PATCH 119/258] Add leetcode test. --- tests/automap/leetcode.fut | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tests/automap/leetcode.fut diff --git a/tests/automap/leetcode.fut b/tests/automap/leetcode.fut new file mode 100644 index 0000000000..43a50cb2b8 --- /dev/null +++ b/tests/automap/leetcode.fut @@ -0,0 +1,4 @@ +def outerprod f x y = map (f >-> flip map y) x +def bidd A = outerprod (==) (indices A) (indices A) +def xmat A = bidd A || reverse (bidd A) +def check_matrix (A : [][]i32) = xmat A == (A != 0) |> flatten |> and From 6616a07928260326e1029e8759552dfce4f54655 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 24 Feb 2024 21:17:24 +0100 Subject: [PATCH 120/258] Easier to read with some linebreaks. --- src/Language/Futhark/Pretty.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 8519b22060..3318fe11e4 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -239,7 +239,7 @@ prettyAppExp p (BinOp (bop, _) _ (x, xi) (y, yi) _) = (Just (_, xam), Just (_, yam)) | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> -- fix - parens (prettyBinOp p bop x y <+> "Δ" <+> pretty xam <+> "Δ" <+> pretty yam) + parens $ align $ prettyBinOp p bop x y "Δ" <+> pretty xam "Δ" <+> pretty yam _ -> prettyBinOp p bop x y prettyAppExp _ (Match e cs _) = "match" <+> pretty e (stack . map pretty) (NE.toList cs) prettyAppExp _ (Loop sizeparams pat initexp form loopbody _) = From 2de3f1eb7537530bc123897aa17c3eb154936d34 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 25 Feb 2024 10:49:45 +0100 Subject: [PATCH 121/258] Remove map intrinsic. AUTOMAP all the way! --- prelude/soacs.fut | 2 +- prelude/zip.fut | 18 ++++++------------ src/Futhark/Internalise/Exps.hs | 9 +-------- src/Language/Futhark/Interpreter.hs | 16 ---------------- src/Language/Futhark/Prop.hs | 10 ---------- 5 files changed, 8 insertions(+), 47 deletions(-) diff --git a/prelude/soacs.fut b/prelude/soacs.fut index 310fad5421..9cda4d2e69 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -48,7 +48,7 @@ import "zip" -- -- **Span:** *O(S(f))* def map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map f as + f as -- | Apply the given function to each element of a single array. -- diff --git a/prelude/zip.fut b/prelude/zip.fut index 1171820307..48816fe97a 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -7,12 +7,6 @@ -- The main reason this module exists is that we need it to define -- SOACs like `map2`. --- We need a map to define some of the zip variants, but this file is --- depended upon by soacs.fut. So we just define a quick-and-dirty --- internal one here that uses the intrinsic version. -local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map f as - -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = intrinsics.zip as bs @@ -23,15 +17,15 @@ def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = -- | As `zip2`@term, but with one more array. def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c): *[n](a,b,c) = - internal_map (\(a,(b,c)) -> (a,b,c)) (zip as (zip2 bs cs)) + (\(a,(b,c)) -> (a,b,c)) (zip as (zip2 bs cs)) -- | As `zip3`@term, but with one more array. def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d): *[n](a,b,c,d) = - internal_map (\(a,(b,c,d)) -> (a,b,c,d)) (zip as (zip3 bs cs ds)) + (\(a,(b,c,d)) -> (a,b,c,d)) (zip as (zip3 bs cs ds)) -- | As `zip4`@term, but with one more array. def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e): *[n](a,b,c,d,e) = - internal_map (\(a,(b,c,d,e)) -> (a,b,c,d,e)) (zip as (zip4 bs cs ds es)) + (\(a,(b,c,d,e)) -> (a,b,c,d,e)) (zip as (zip4 bs cs ds es)) -- | Turn an array of pairs into two arrays. def unzip [n] 'a 'b (xs: [n](a,b)): ([n]a, [n]b) = @@ -43,18 +37,18 @@ def unzip2 [n] 'a 'b (xs: [n](a,b)): ([n]a, [n]b) = -- | As `unzip2`@term, but with one more array. def unzip3 [n] 'a 'b 'c (xs: [n](a,b,c)): ([n]a, [n]b, [n]c) = - let (as, bcs) = unzip (internal_map (\(a,b,c) -> (a,(b,c))) xs) + let (as, bcs) = unzip ((\(a,b,c) -> (a,(b,c))) xs) let (bs, cs) = unzip bcs in (as, bs, cs) -- | As `unzip3`@term, but with one more array. def unzip4 [n] 'a 'b 'c 'd (xs: [n](a,b,c,d)): ([n]a, [n]b, [n]c, [n]d) = - let (as, bs, cds) = unzip3 (internal_map (\(a,b,c,d) -> (a,b,(c,d))) xs) + let (as, bs, cds) = unzip3 ((\(a,b,c,d) -> (a,b,(c,d))) xs) let (cs, ds) = unzip cds in (as, bs, cs, ds) -- | As `unzip4`@term, but with one more array. def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a,b,c,d,e)): ([n]a, [n]b, [n]c, [n]d, [n]e) = - let (as, bs, cs, des) = unzip4 (internal_map (\(a,b,c,d,e) -> (a,b,c,(d,e))) xs) + let (as, bs, cs, des) = unzip4 ((\(a,b,c,d,e) -> (a,b,c,(d,e))) xs) let (ds, es) = unzip des in (as, bs, cs, ds, es) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 5f933e659c..fab4677f2f 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1024,7 +1024,6 @@ withAutoMap ams arg_desc res_t args_e innerM = do where stripAutoMapDims i am = am {autoMap = E.Shape $ drop i $ E.shapeDims $ autoMap am} - autoMapRank = E.shapeRank . autoMap max_am = maximumBy (\x y -> E.shapeRank x `compare` E.shapeRank y) $ fmap autoMap ams inner_t = E.stripArray (E.shapeRank max_am) res_t ds = map autoMapRank ams @@ -1701,7 +1700,7 @@ findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName, AutoMap)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) - | E.Hole (Info t) loc <- f = + | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where onArg (Info (argext, am), e) = (e, argext, am) @@ -1859,12 +1858,6 @@ isIntrinsicFunction qname args loc = do fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing - handleSOACs [lam, arr] "map" = Just $ \desc -> do - arr' <- internaliseExpToVars "map_arr" arr - arr_ts <- mapM lookupType arr' - lam' <- internaliseLambdaCoerce lam $ map rowType arr_ts - let w = arraysSize 0 arr_ts - letTupExp' desc $ I.Op $ I.Screma w arr' (I.mapSOAC lam') handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 527c734cbf..37358dcda1 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -1577,22 +1577,6 @@ initialCtx = Just $ fun2 stream def s | "reduce_stream" `isPrefixOf` s = Just $ fun3 $ \_ f arg -> stream f arg - def "map" = Just $ - TermPoly Nothing $ \t eval' -> do - t' <- evalType eval' mempty t - pure $ ValueFun $ \f -> pure . ValueFun $ \xs -> - case unfoldFunType t' of - ([_, _], ret_t) - | Just rowshape <- typeRowShape ret_t -> - toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) - | otherwise -> - error $ "Bad return type: " <> prettyString ret_t - _ -> - error $ - "Invalid arguments to map intrinsic:\n" - ++ unlines [prettyString t, show f, show xs] - where - typeRowShape = sequenceA . structTypeShape . stripArray 1 def s | "reduce" `isPrefixOf` s = Just $ fun3 $ \f ne xs -> foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 6c69e1ef3d..280b531286 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -840,16 +840,6 @@ intrinsics = $ array_a Unique $ shape [m, k, l] ), - ( "map", - IntrinsicPolyFun - [tp_a, tp_b, sp_n] - [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), - array_a Observe $ shape [n] - ] - $ RetType [] - $ array_b Unique - $ shape [n] - ), ( "reduce", IntrinsicPolyFun [tp_a, sp_n] From d24831df9f64058670f2a648003694537a6b79a8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 25 Feb 2024 07:13:27 -0800 Subject: [PATCH 122/258] Fix AUTOMAP shapes for the replicate case. --- src/Language/Futhark/TypeChecker/Terms.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 8ad8b2a366..a661507772 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -946,7 +946,10 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d onFailure (CheckingApply fname argexp tp1 argtype) $ do (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 - let (am_frame_shape, argtype_automap) = splitArrayAt (autoFrameRank am) argtype_with_frame + (am_frame_shape, argtype_automap) <- + if autoMapRank am == 0 + then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 + else pure $ splitArrayAt (autoFrameRank am) argtype_with_frame debugTraceM $ unlines @@ -1005,7 +1008,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d let am = AutoMap - { autoRep = mempty, + { autoRep = am_rep_shape, autoMap = am_map_shape, autoFrame = am_map_shape <> am_frame_shape } From 60db26b21d1a82c7ab92b3ebd3b92a35cb39c7bb Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 25 Feb 2024 07:17:34 -0800 Subject: [PATCH 123/258] Remove confusing/wrong name. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index a661507772..67caf7a3c7 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -946,7 +946,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d onFailure (CheckingApply fname argexp tp1 argtype) $ do (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 - (am_frame_shape, argtype_automap) <- + (am_frame_shape, _) <- if autoMapRank am == 0 then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 else pure $ splitArrayAt (autoFrameRank am) argtype_with_frame From 54882ab6b20fbfaf679c57e615fe39044e7610f9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 25 Feb 2024 07:27:02 -0800 Subject: [PATCH 124/258] Better frame computation + clarifying notes. --- src/Language/Futhark/TypeChecker/Terms.hs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 67caf7a3c7..e79fa461a0 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -944,12 +944,24 @@ checkApply :: checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do + -- argtype = arg_frame argtype' + -- tp1 = f_frame tp1' + -- + -- Rep case: + -- R arg_frame argtype' = f_frame tp1' + -- ==> R = (autoRepRank am)-length prefix of tp1 + -- ==> frame = f_frame = (autoFrameRank am)-length prefix of tp1 + -- + -- Map case: + -- arg_frame argtype' = M f_frame tp1' + -- ==> M = (autoMapRank am)-length prefix of argtype + -- ==> frame = M f_frame = (autoFrameRank am)-length prefix of argtype (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 (am_frame_shape, _) <- if autoMapRank am == 0 then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 - else pure $ splitArrayAt (autoFrameRank am) argtype_with_frame + else splitArrayAt (autoFrameRank am) <$> normTypeFully argtype debugTraceM $ unlines @@ -1010,7 +1022,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d AutoMap { autoRep = am_rep_shape, autoMap = am_map_shape, - autoFrame = am_map_shape <> am_frame_shape + autoFrame = am_frame_shape } pure (tp1, distributeFrame (autoMap am) tp2'', argext, ext, am) From c04ade2456e19ae97be412612bbc2399db5034d9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 25 Feb 2024 17:54:35 +0100 Subject: [PATCH 125/258] Bump Nix and cabal. --- cabal.project | 2 +- nix/sources.json | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cabal.project b/cabal.project index f5a57ba11c..d8f795a28e 100644 --- a/cabal.project +++ b/cabal.project @@ -1,5 +1,5 @@ packages: futhark.cabal -index-state: 2024-01-24T22:19:37Z +index-state: 2024-02-25T13:57:21Z package futhark ghc-options: -j -fwrite-ide-info -hiedir=.hie diff --git a/nix/sources.json b/nix/sources.json index 1a95fd6f85..a60e35d52a 100644 --- a/nix/sources.json +++ b/nix/sources.json @@ -17,10 +17,10 @@ "homepage": "", "owner": "NixOS", "repo": "nixpkgs", - "rev": "2bcbada7a108ef5584abda1e36c42109d1f0d374", - "sha256": "12n79sl0nkp3b25ifdz9i8d9046g6dqz8g2jghg8d3836yjih7qj", + "rev": "efeff60fd4a0bc4f639a217a723f9e11df3f5e20", + "sha256": "09gxq604v7r9sl5qgp37n6414z2jivdjipwyrhka0d4rdhdbm31m", "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/2bcbada7a108ef5584abda1e36c42109d1f0d374.tar.gz", + "url": "https://github.com/NixOS/nixpkgs/archive/efeff60fd4a0bc4f639a217a723f9e11df3f5e20.tar.gz", "url_template": "https://github.com///archive/.tar.gz" } } From 6c91e26e392c6cadaa338d8ae37f8937a4ab486b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 25 Feb 2024 23:21:34 +0100 Subject: [PATCH 126/258] Work on supporting AUTOMAP in interpreter. --- src/Language/Futhark/Interpreter.hs | 117 +++++++++++++++++---- src/Language/Futhark/Interpreter/Values.hs | 20 +++- 2 files changed, 116 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 37358dcda1..39c278c4a7 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -414,6 +414,11 @@ fromArray :: Value -> (ValueShape, [Value]) fromArray (ValueArray shape as) = (shape, elems as) fromArray v = error $ "Expected array value, but found: " <> show v +fromArrayR :: Int -> Value -> [Value] +fromArrayR 0 v = [v] +fromArrayR 1 v = snd $ fromArray v +fromArrayR n v = concatMap (fromArrayR (n - 1)) $ snd $ fromArray v + apply :: SrcLoc -> Env -> Value -> Value -> EvalM Value apply loc env (ValueFun f) v = stacking loc env (f v) apply _ _ f _ = error $ "Cannot apply non-function: " <> show f @@ -423,6 +428,54 @@ apply2 loc env f x y = stacking loc env $ do f' <- apply noLoc mempty f x apply noLoc mempty f' y +data AutoMapArg + = -- | Map function across argument of this shape. + AutoMapMap [Int64] + | -- | Replicate argument to array of this shape. + AutoMapRep [Int64] + | AutoMapNone + deriving (Eq, Ord, Show) + +applyAM :: + SrcLoc -> + Env -> + (Value, StructType) -> + AutoMapArg -> + Value -> + EvalM Value +applyAM loc env (ValueArray _ xs, ft) AutoMapNone v = do + t' <- evalType (eval env) mempty ft + undefined +applyAM loc env (f, _) AutoMapNone v = + apply loc env f v +applyAM loc env (f, _) (AutoMapMap []) v = + apply loc env f v +applyAM loc env (f, _) (AutoMapRep []) v = + apply loc env f v +applyAM loc env (f, _) (AutoMapRep shape) v = + apply noLoc mempty f $ repArray shape v +-- The next case essentially implements the "map" primitive. +applyAM loc env (f, ft) (AutoMapMap shape) v = do + t' <- evalType (eval env) mempty ft + let rank = length shape + vs = fromArrayR rank v + case t' of + Scalar (Arrow _ _ _ _ (RetType _ ret_t@(Scalar Arrow {}))) + | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> do + fs <- mapM (apply noLoc mempty f) vs + pure $ ValueFun $ \v' -> + toArrayR shape rowshape + <$> zipWithM (apply loc env) fs (fromArrayR rank v') + Scalar (Arrow _ _ _ _ (RetType _ ret_t)) + | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> + toArrayR shape rowshape <$> mapM (apply noLoc mempty f) vs + | otherwise -> + error $ "Bad return type: " <> prettyString ret_t + _ -> + error $ + "Invalid automap arguments:\n" + ++ unlines [prettyString ft, show f, show v] + matchPat :: Env -> Pat (TypeBase Size u) -> Value -> EvalM Env matchPat env p v = do m <- runMaybeT $ patternMatch env p v @@ -752,13 +805,21 @@ evalFunctionBinding env tparams ps ret fbody = do returned env (retType ret) retext =<< evalFunction env' missing_sizes ps fbody (retType ret) -evalArg :: Env -> Exp -> Maybe VName -> EvalM Value -evalArg env e ext = do +evalArg :: Env -> Exp -> Maybe VName -> AutoMap -> EvalM (Value, AutoMapArg) +evalArg env e ext am = do v <- eval env e case ext of Just ext' -> putExtSize ext' v _ -> pure () - pure v + let evalShape = mapM (fmap asInt64 . eval env) . shapeDims + am' <- + if not $ null $ autoMap am + then AutoMapMap <$> evalShape (autoMap am) + else + if not $ null $ autoRep am + then AutoMapRep <$> evalShape (autoRep am) + else pure AutoMapNone + pure (v, am') returned :: Env -> TypeBase Size als -> [VName] -> Value -> EvalM Value returned _ _ [] v = pure v @@ -828,22 +889,31 @@ evalAppExp env (LetPat sizes p e body _) = do evalAppExp env (LetFun f (tparams, ps, _, Info ret, fbody) body _) = do binding <- evalFunctionBinding env tparams ps ret fbody eval (env {envTerm = M.insert f binding $ envTerm env}) body -evalAppExp env (BinOp (op, _) op_t (x, Info (xext, xam)) (y, Info (yext, yam)) loc) - | baseString (qualLeaf op) == "&&" = do +evalAppExp env (BinOp (op, _) (Info op_t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) + | baseString (qualLeaf op) == "&&", + noAutoMap = do x' <- asBool <$> eval env x if x' then eval env y else pure $ ValuePrim $ BoolValue False - | baseString (qualLeaf op) == "||" = do + | baseString (qualLeaf op) == "||", + noAutoMap = do x' <- asBool <$> eval env x if x' then pure $ ValuePrim $ BoolValue True else eval env y | otherwise = do - x' <- evalArg env x xext - y' <- evalArg env y yext - op' <- eval env $ Var op op_t loc - apply2 loc env op' x' y' + (x', xam') <- evalArg env x xext xam + (y', yam') <- evalArg env y yext yam + op' <- evalTermVar env op op_t + op'' <- applyAM loc env (op', op_t) xam' x' + applyAM loc env (op'', op_ret) yam' y' + where + op_ret = case op_t of + Scalar (Arrow _ _ _ _ (RetType _ t)) -> + toStruct t + _ -> error $ "Nonsensical binop type: " <> prettyString op_t + noAutoMap = xam == mempty && yam == mempty evalAppExp env (If cond e1 e2 _) = do cond' <- asBool <$> eval env cond if cond' then eval env e1 else eval env e2 @@ -853,9 +923,11 @@ evalAppExp env (Apply f args loc) = do -- type of the functions. args' <- reverse <$> mapM evalArg' (reverse $ NE.toList args) f' <- eval env f - foldM (apply loc env) f' args' + foldM apply' f' args' where - evalArg' (Info (ext, _), x) = evalArg env x ext + ft = typeOf f + apply' f' (v', am') = applyAM loc env (f', ft) am' v' + evalArg' (Info (ext, am), x) = evalArg env x ext am evalAppExp env (Index e is loc) = do is' <- mapM (evalDimIndex env) is arr <- eval env e @@ -1047,16 +1119,21 @@ eval env (Lambda ps body _ (Info (RetType _ rt)) _) = evalFunction env [] ps body rt eval env (OpSection qv (Info t) _) = evalTermVar env qv $ toStruct t -eval env (OpSectionLeft qv _ e (Info (_, _, argext, _), _) (Info (RetType _ t), _) loc) = do - v <- evalArg env e argext - f <- evalTermVar env qv (toStruct t) - apply loc env f v -eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext, _)) (Info (RetType _ t)) loc) = do - y <- evalArg env e argext +eval env (OpSectionLeft qv _ e (Info (_, _, argext, am), _) (Info (RetType _ t), _) loc) = do + (v, am') <- evalArg env e argext am + f <- evalTermVar env qv t' + applyAM loc env (f, t') am' v + where + t' = toStruct t +eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext, am)) (Info (RetType _ t)) loc) = do + (y, am') <- evalArg env e argext am pure $ ValueFun $ \x -> do - f <- evalTermVar env qv $ toStruct t - apply2 loc env f x y + f <- evalTermVar env qv t' + f' <- apply loc env f x + applyAM loc env (f', t') am' y + where + t' = toStruct t eval env (IndexSection is _ loc) = do is' <- mapM (evalDimIndex env) is pure $ ValueFun $ evalIndex loc env is' diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index 40f0a8b287..b3fb36ac8c 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -20,7 +20,9 @@ module Language.Futhark.Interpreter.Values prettyEmptyArray, toArray, toArray', + toArrayR, toTuple, + repArray, -- * Conversion fromDataValue, @@ -28,7 +30,7 @@ module Language.Futhark.Interpreter.Values where import Data.Array -import Data.List (genericLength) +import Data.List (genericLength, genericReplicate) import Data.Map qualified as M import Data.Maybe import Data.Monoid hiding (Sum) @@ -206,6 +208,15 @@ toArray' rowshape vs = ValueArray shape (listArray (0, length vs - 1) vs) where shape = ShapeDim (genericLength vs) rowshape +-- | Produce multidimensional array from a flat list of values. +toArrayR :: [Int64] -> ValueShape -> [Value m] -> Value m +toArrayR [] _ = error "toArrayR: empty shape" +toArrayR [_] elemshape = toArray' elemshape +toArrayR (n : ns) elemshape = + toArray (foldr ShapeDim elemshape (n : ns)) + . map (toArrayR ns elemshape) + . chunk (fromIntegral (product ns)) + arrayLength :: (Integral int) => Array Int (Value m) -> int arrayLength = fromIntegral . (+ 1) . snd . bounds @@ -237,6 +248,13 @@ fromDataValueWith f shape vector where shape' = SVec.tail shape +repArray :: [Int64] -> Value m -> Value m +repArray [] v = v +repArray (n : ns) v = + toArray' (valueShape v') (genericReplicate n v') + where + v' = repArray ns v + -- | Convert a Futhark value in the externally observable data format -- to an interpreter value. fromDataValue :: V.Value -> Value m From 9919fec8dfb981695b2d7dca080735202e2ed9da Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 15:45:57 +0100 Subject: [PATCH 127/258] We must also touch automaps here. --- src/Language/Futhark/Traversals.hs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index fc20935c24..94b440b2ff 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -60,6 +60,13 @@ class ASTMappable x where -- into subexpressions. The mapping is done left-to-right. astMap :: (Monad m) => ASTMapper m -> x -> m x +mapOnAutoMap :: (Monad m) => ASTMapper m -> AutoMap -> m AutoMap +mapOnAutoMap tv (AutoMap r m f) = + AutoMap + <$> traverse (mapOnExp tv) r + <*> traverse (mapOnExp tv) m + <*> traverse (mapOnExp tv) f + instance ASTMappable (AppExpBase Info VName) where astMap tv (Range start next end loc) = Range @@ -73,7 +80,7 @@ instance ASTMappable (AppExpBase Info VName) where Match <$> mapOnExp tv e <*> astMap tv cases <*> pure loc astMap tv (Apply f args loc) = do f' <- mapOnExp tv f - args' <- traverse (traverse $ mapOnExp tv) args + args' <- traverse onArg args -- Safe to disregard return type because existentials cannot be -- instantiated here, as the return is necessarily a function. pure $ case f' of @@ -81,6 +88,9 @@ instance ASTMappable (AppExpBase Info VName) where Apply f_inner (args_inner <> args') loc _ -> Apply f' args' loc + where + onArg (Info (ext, am), e) = + (,) <$> (Info . (ext,) <$> mapOnAutoMap tv am) <*> mapOnExp tv e astMap tv (LetPat sizes pat e body loc) = LetPat sizes <$> astMap tv pat <*> mapOnExp tv e <*> mapOnExp tv body <*> pure loc astMap tv (LetFun name (tparams, params, ret, t, e) body loc) = @@ -101,13 +111,16 @@ instance ASTMappable (AppExpBase Info VName) where <*> mapOnExp tv vexp <*> mapOnExp tv body <*> pure loc - astMap tv (BinOp (fname, fname_loc) t (x, xext) (y, yext) loc) = + astMap tv (BinOp (fname, fname_loc) t x y loc) = BinOp <$> ((,) <$> mapOnName tv fname <*> pure fname_loc) <*> traverse (mapOnStructType tv) t - <*> ((,) <$> mapOnExp tv x <*> pure xext) - <*> ((,) <$> mapOnExp tv y <*> pure yext) + <*> onArg x + <*> onArg y <*> pure loc + where + onArg (e, Info (ext, am)) = + (,) <$> mapOnExp tv e <*> (Info . (ext,) <$> mapOnAutoMap tv am) astMap tv (Loop sparams mergepat mergeexp form loopbody loc) = Loop sparams <$> astMap tv mergepat From afa09d3e501d266db9bf8b0c213539908ab2283b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 16:31:24 +0100 Subject: [PATCH 128/258] Add another test program. --- tests/automap/combinations.fut | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/automap/combinations.fut diff --git a/tests/automap/combinations.fut b/tests/automap/combinations.fut new file mode 100644 index 0000000000..5c49fb251f --- /dev/null +++ b/tests/automap/combinations.fut @@ -0,0 +1,33 @@ +-- All the various ways one can imagine automapping a very simple program. + +def plus (x: i32) (y: i32) = x + y + +-- == +-- entry: vecint +-- input { [1,2,3] } output { [3,4,5] } + +entry vecint (x: []i32) = plus x 2 + +-- == +-- entry: vecvec +-- input { [1,2,3] } output { [2,4,6] } + +entry vecvec (x: []i32) = plus x x + +-- == +-- entry: matint +-- input { [[1,2],[3,4]] } output { [[3,4],[5,6]] } + +entry matint (x: [][]i32) = plus x 2 + +-- == +-- entry: matmat +-- input { [[1,2],[3,4]] } output { [[2,4],[6,8]] } + +entry matmat (x: [][]i32) = plus x x + +-- == +-- entry: matvec +-- input { [[1,2],[3,4]] [5,6] } output { [[6,8],[8,10]] } + +entry matvec (x: [][]i32) (y: []i32) = plus x y From 58c8ff33e06a8e91078126ae03763c1d7ddd0859 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 08:01:39 -0800 Subject: [PATCH 129/258] Distribute frames recursively. --- src/Language/Futhark/TypeChecker/Terms.hs | 15 +++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 11 ----------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2b8b9f6c16..25c6dcb0c8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1025,11 +1025,18 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d autoFrame = am_frame_shape } - pure (tp1, distributeFrame (autoMap am) tp2'', argext, ext, am) + pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am) where - distributeFrame frame (Scalar (Arrow u p d a (RetType ds b))) = - Scalar $ Arrow u p d (arrayOf frame a) (RetType ds (arrayOfWithAliases (uniqueness b) frame b)) - distributeFrame frame t = arrayOf frame t + distribute :: TypeBase dim u -> TypeBase dim u + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute (arrayOfWithAliases (uniqueness tr) s tr)) + distribute t = t checkApply _ _ _ _ _ = error "checkApply: array" diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 5bf91a47fe..ccba90ab07 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -688,17 +688,6 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) -distribute :: TypeBase dim u -> TypeBase dim u -distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ - Arrow - u - Unnamed - mempty - (arrayOf s ta) - (RetType rd $ distribute (arrayOfWithAliases (uniqueness tr) s tr)) -distribute t = t - checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where From 99fb122442709e4ab11e4d16caab276eb34cb8d9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 20:14:10 +0100 Subject: [PATCH 130/258] The vindication of Robert. --- src/Language/Futhark/Interpreter.hs | 49 +++++++---------------------- tests/automap/combinations.fut | 5 +++ 2 files changed, 17 insertions(+), 37 deletions(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 39c278c4a7..506a8b715d 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -429,11 +429,7 @@ apply2 loc env f x y = stacking loc env $ do apply noLoc mempty f' y data AutoMapArg - = -- | Map function across argument of this shape. - AutoMapMap [Int64] - | -- | Replicate argument to array of this shape. - AutoMapRep [Int64] - | AutoMapNone + = AutoMapArg [Int64] [Int64] [Int64] deriving (Eq, Ord, Show) applyAM :: @@ -443,38 +439,23 @@ applyAM :: AutoMapArg -> Value -> EvalM Value -applyAM loc env (ValueArray _ xs, ft) AutoMapNone v = do - t' <- evalType (eval env) mempty ft - undefined -applyAM loc env (f, _) AutoMapNone v = - apply loc env f v -applyAM loc env (f, _) (AutoMapMap []) v = - apply loc env f v -applyAM loc env (f, _) (AutoMapRep []) v = +applyAM loc env (f, _) (AutoMapArg [] [] []) v = apply loc env f v -applyAM loc env (f, _) (AutoMapRep shape) v = - apply noLoc mempty f $ repArray shape v --- The next case essentially implements the "map" primitive. -applyAM loc env (f, ft) (AutoMapMap shape) v = do +applyAM loc env (f, ft) am@(AutoMapArg repshape mapshape frame) v = do + let v' = repArray repshape v + f' = repArray mapshape f + rank = length frame + vs = fromArrayR rank v' + fs = fromArrayR rank f' t' <- evalType (eval env) mempty ft - let rank = length shape - vs = fromArrayR rank v case t' of - Scalar (Arrow _ _ _ _ (RetType _ ret_t@(Scalar Arrow {}))) - | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> do - fs <- mapM (apply noLoc mempty f) vs - pure $ ValueFun $ \v' -> - toArrayR shape rowshape - <$> zipWithM (apply loc env) fs (fromArrayR rank v') Scalar (Arrow _ _ _ _ (RetType _ ret_t)) | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> - toArrayR shape rowshape <$> mapM (apply noLoc mempty f) vs - | otherwise -> - error $ "Bad return type: " <> prettyString ret_t + toArrayR frame rowshape <$> zipWithM (apply loc env) fs vs _ -> error $ "Invalid automap arguments:\n" - ++ unlines [prettyString ft, show f, show v] + ++ unlines [prettyString ft, show f, show v, show am] matchPat :: Env -> Pat (TypeBase Size u) -> Value -> EvalM Env matchPat env p v = do @@ -806,19 +787,13 @@ evalFunctionBinding env tparams ps ret fbody = do =<< evalFunction env' missing_sizes ps fbody (retType ret) evalArg :: Env -> Exp -> Maybe VName -> AutoMap -> EvalM (Value, AutoMapArg) -evalArg env e ext am = do +evalArg env e ext (AutoMap rshape mshape frame) = do v <- eval env e case ext of Just ext' -> putExtSize ext' v _ -> pure () let evalShape = mapM (fmap asInt64 . eval env) . shapeDims - am' <- - if not $ null $ autoMap am - then AutoMapMap <$> evalShape (autoMap am) - else - if not $ null $ autoRep am - then AutoMapRep <$> evalShape (autoRep am) - else pure AutoMapNone + am' <- AutoMapArg <$> evalShape rshape <*> evalShape mshape <*> evalShape frame pure (v, am') returned :: Env -> TypeBase Size als -> [VName] -> Value -> EvalM Value diff --git a/tests/automap/combinations.fut b/tests/automap/combinations.fut index 5c49fb251f..7d77e85abb 100644 --- a/tests/automap/combinations.fut +++ b/tests/automap/combinations.fut @@ -31,3 +31,8 @@ entry matmat (x: [][]i32) = plus x x -- input { [[1,2],[3,4]] [5,6] } output { [[6,8],[8,10]] } entry matvec (x: [][]i32) (y: []i32) = plus x y + +-- == +-- entry: vecvecvec +-- input { [1,2,3] } output { [3,6,9] } +entry vecvecvec (x: []i32) = (\x y z -> x + y + z) x x x From e4041695f2ae12c2710f2a2eaca5550fd14eafaf Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 21:13:47 -0800 Subject: [PATCH 131/258] Basic internalization working. --- src/Futhark/Internalise/Exps.hs | 157 +++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index fab4677f2f..3c6b7bb99c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -14,8 +14,10 @@ import Data.List (elemIndex, find, intercalate, intersperse, maximumBy, transpos import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M +import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T +import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings @@ -882,11 +884,16 @@ internalisePatLit l t = withAutoMap_ :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([[SubExp]] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap_ ams arg_desc res_t args_e innerM = - withAutoMap ams arg_desc res_t args_e $ \args_stms -> do + withAutoMapNew (zip3 args_e ams (repeat arg_desc)) $ \args_stms -> do let (args, stms) = unzip args_stms mapM_ addStms $ reverse stms innerM args +-- withAutoMap ams arg_desc res_t args_e $ \args_stms -> do +-- let (args, stms) = unzip args_stms +-- mapM_ addStms $ reverse stms +-- innerM args + -- | Internalization of 'AutoMap'-annotated applications. -- -- Each application @f x@ has an annotation with @AutoMap R M F@ where @@ -1009,6 +1016,154 @@ withAutoMap_ ams arg_desc res_t args_e innerM = -- -- This process continues until the level is greater than the maximum -- true level of any application, at which we terminate. +type Level = Int + +type ArgNum = Int + +type ArgMap = M.Map Level (M.Map ArgNum AutoMapArg) + +data AutoMapArg = AutoMapArg + { amArgs :: [VName], + amArgStms :: Stms SOACS + } + deriving (Show) + +data AutoMapParam = AutoMapParam + { amParams :: [LParam SOACS], + amParamStms :: Stms SOACS, + amMapDim :: SubExp + } + deriving (Show) + +withAutoMapNew :: [((E.Exp, Maybe VName), AutoMap, String)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] +withAutoMapNew args_am func = do + (param_maps, arg_maps) <- + unzip . reverse + <$> mapM buildArgMap (reverse args_am) + let param_map = M.unionsWith (++) $ (fmap . fmap) pure param_maps + arg_map = M.unionsWith (++) $ (fmap . fmap) pure arg_maps + traceM $ + unlines + [ "##param_map", + show param_map, + "##arg_map", + show arg_map + ] + buildMapNest param_map arg_map $ maximum $ M.keys arg_map + where + buildMapNest _ arg_map 0 = + func $ map (\a -> (map I.Var $ amArgs a, amArgStms a)) $ arg_map M.! 0 + buildMapNest param_map arg_map l = + case map amMapDim $ param_map M.! l of + [] -> buildMapNest param_map arg_map (l - 1) + (map_dim : _) -> do + let (params, p_stms) = + unzip $ + map (\p -> (amParams p, amParamStms p)) $ + param_map M.! l + (args, arg_stms) = + unzip $ + map (\a -> (amArgs a, amArgStms a)) $ + arg_map M.! l + letValExp' + "automap" + . Op + . Screma map_dim (concat args) + . mapSOAC + =<< mkLambda + (concat params) + ( do + subExpsRes <$> buildMapNest param_map arg_map (l - 1) + ) + + buildArgMap :: ((E.Exp, Maybe VName), AutoMap, String) -> InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) + buildArgMap (arg, am, arg_desc) = do + ses <- internaliseArg arg_desc arg + arg_vnames <- mapM (letExp "" <=< eSubExp) ses + ts <- mapM subExpType ses + (p_map, a_map) <- + foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ + reverse [0 .. trueLevel am] + traceM $ + unlines + [ "##truelevel am", + show $ trueLevel am, + "## arg", + prettyString arg, + "## am", + show am + ] + + pure (p_map, a_map) + where + mkArgsAndParams arg_vnames ses ts (p_map, a_map) l + | l == 0 = do + let as = + fromMaybe + arg_vnames + ( ( map I.paramName + . amParams + ) + <$> p_map M.!? 1 + ) + (ses, stms) <- mkBottomArgs as ts + pure $ (p_map, M.insert 0 (AutoMapArg ses stms) a_map) + | l == trueLevel am = do + (ps, p_stms) <- mkParams arg_vnames ts l + d <- outerDim am l + pure + ( M.insert l (AutoMapParam ps p_stms d) p_map, + M.insert l (AutoMapArg arg_vnames mempty) a_map + ) + | l < trueLevel am && l > 0 = do + (ps, p_stms) <- mkParams arg_vnames ts l + d <- outerDim am l + let as = + map I.paramName $ + amParams $ + p_map M.! (l + 1) + pure + ( M.insert l (AutoMapParam ps p_stms d) p_map, + M.insert l (AutoMapArg as mempty) a_map + ) + | otherwise = error "" + + mkParams _ ts level = + collectStms $ + forM ts $ \t -> + newParam ("p_" <> arg_desc) $ argType (level - 1) am t + mkBottomArgs arg_vnames ts = + collectStms $ do + rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am + if I.shapeRank rep_shape > 0 + then concat <$> mapM (letValExp "autorep" . BasicOp . Replicate rep_shape . I.Var) arg_vnames + else pure arg_vnames + + argType level am t = I.stripArray (trueLevel am - level) t + + internaliseShape :: E.Shape Size -> InternaliseM I.Shape + internaliseShape = + fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims + + trueLevel :: AutoMap -> Int + trueLevel am + | autoMap am == mempty = max 0 $ E.shapeRank (autoFrame am) - E.shapeRank (autoRep am) + | otherwise = E.shapeRank $ autoFrame am + + outerDim :: AutoMap -> Int -> InternaliseM SubExp + outerDim am level = do + traceM $ + unlines + [ "##outerDim", + "##am", + show am, + "##level", + show level, + "## dff", + show (trueLevel am - level) + ] + internaliseExp1 "" $ (!! (trueLevel am - level)) $ E.shapeDims $ autoFrame am + withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap ams arg_desc res_t args_e innerM = do (args, stms) <- From 27588dc92ac5a4ddbb14a3fb0667377f7d6e9d4f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:05:19 -0800 Subject: [PATCH 132/258] Remove some complexity. --- src/Futhark/Internalise/Exps.hs | 261 +++++++------------------------- 1 file changed, 57 insertions(+), 204 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 3c6b7bb99c..d38e36493c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -17,7 +17,6 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings @@ -354,10 +353,8 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- created by function applications can be brought into scope. let fname = nameFromString $ prettyString $ baseName $ qualLeaf qfname loc = srclocOf e - arg_desc = nameToString fname ++ "_arg" - args = map (\(a, b, _) -> (a, b)) argsam - ams = map (\(_, _, c) -> c) argsam - res_t = et + (args, ams) = unzip argsam + args_am_desc = zip3 args ams (repeat (nameToString fname ++ "_arg")) -- Some functions are magical (overloaded) and we handle that here. case () of @@ -367,20 +364,20 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - withAutoMap_ ams arg_desc res_t args $ \args' -> do - let prepareArg (arg, _, am) arg' = + withAutoMap args_am_desc $ \args' -> do + let prepareArg ((arg, _), am, _) arg' = (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') - internalise $ zipWith prepareArg argsam args' + internalise $ zipWith prepareArg args_am_desc args' | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, Just (rettype, _) <- M.lookup fname I.builtInFunctions -> - withAutoMap_ ams arg_desc res_t args $ \args' -> do + withAutoMap args_am_desc $ \args' -> do let tag ses = [(se, I.Observe) | se <- ses] let args'' = concatMap tag args' letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - withAutoMap_ ams arg_desc res_t args $ \args' -> + withAutoMap args_am_desc $ \args' -> do funcall desc qfname (concat args') loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = internalisePat desc sizes pat e $ internaliseExp desc body @@ -882,18 +879,6 @@ internalisePatLit (E.PatLitFloat x) (E.Scalar (E.Prim (E.FloatType ft))) = internalisePatLit l t = error $ "Nonsensical pattern and type: " ++ show (l, t) -withAutoMap_ :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([[SubExp]] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] -withAutoMap_ ams arg_desc res_t args_e innerM = - withAutoMapNew (zip3 args_e ams (repeat arg_desc)) $ \args_stms -> do - let (args, stms) = unzip args_stms - mapM_ addStms $ reverse stms - innerM args - --- withAutoMap ams arg_desc res_t args_e $ \args_stms -> do --- let (args, stms) = unzip args_stms --- mapM_ addStms $ reverse stms --- innerM args - -- | Internalization of 'AutoMap'-annotated applications. -- -- Each application @f x@ has an annotation with @AutoMap R M F@ where @@ -1018,128 +1003,104 @@ withAutoMap_ ams arg_desc res_t args_e innerM = -- true level of any application, at which we terminate. type Level = Int -type ArgNum = Int - -type ArgMap = M.Map Level (M.Map ArgNum AutoMapArg) - data AutoMapArg = AutoMapArg - { amArgs :: [VName], - amArgStms :: Stms SOACS + { amArgs :: [VName] } deriving (Show) data AutoMapParam = AutoMapParam { amParams :: [LParam SOACS], - amParamStms :: Stms SOACS, amMapDim :: SubExp } deriving (Show) -withAutoMapNew :: [((E.Exp, Maybe VName), AutoMap, String)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] -withAutoMapNew args_am func = do +withAutoMap :: + [((E.Exp, Maybe VName), AutoMap, String)] -> + ([[SubExp]] -> InternaliseM [SubExp]) -> + InternaliseM [SubExp] +withAutoMap args_am func = do (param_maps, arg_maps) <- unzip . reverse <$> mapM buildArgMap (reverse args_am) - let param_map = M.unionsWith (++) $ (fmap . fmap) pure param_maps - arg_map = M.unionsWith (++) $ (fmap . fmap) pure arg_maps - traceM $ - unlines - [ "##param_map", - show param_map, - "##arg_map", - show arg_map - ] + let param_map = M.unionsWith (<>) $ (fmap . fmap) pure param_maps + arg_map = M.unionsWith (<>) $ (fmap . fmap) pure arg_maps buildMapNest param_map arg_map $ maximum $ M.keys arg_map where buildMapNest _ arg_map 0 = - func $ map (\a -> (map I.Var $ amArgs a, amArgStms a)) $ arg_map M.! 0 + func $ map (map I.Var . amArgs) $ arg_map M.! 0 buildMapNest param_map arg_map l = case map amMapDim $ param_map M.! l of [] -> buildMapNest param_map arg_map (l - 1) (map_dim : _) -> do - let (params, p_stms) = - unzip $ - map (\p -> (amParams p, amParamStms p)) $ - param_map M.! l - (args, arg_stms) = - unzip $ - map (\a -> (amArgs a, amArgStms a)) $ - arg_map M.! l + let params = map amParams $ param_map M.! l + args = map amArgs $ arg_map M.! l letValExp' "automap" . Op . Screma map_dim (concat args) . mapSOAC =<< mkLambda - (concat params) - ( do - subExpsRes <$> buildMapNest param_map arg_map (l - 1) - ) + (concat params) + ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) + ) - buildArgMap :: ((E.Exp, Maybe VName), AutoMap, String) -> InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) + buildArgMap :: + ((E.Exp, Maybe VName), AutoMap, String) -> + InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) buildArgMap (arg, am, arg_desc) = do ses <- internaliseArg arg_desc arg arg_vnames <- mapM (letExp "" <=< eSubExp) ses ts <- mapM subExpType ses - (p_map, a_map) <- - foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ - reverse [0 .. trueLevel am] - traceM $ - unlines - [ "##truelevel am", - show $ trueLevel am, - "## arg", - prettyString arg, - "## am", - show am - ] - - pure (p_map, a_map) + foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ + reverse [0 .. trueLevel am] where mkArgsAndParams arg_vnames ses ts (p_map, a_map) l | l == 0 = do let as = - fromMaybe + maybe arg_vnames - ( ( map I.paramName - . amParams - ) - <$> p_map M.!? 1 + ( map I.paramName + . amParams ) - (ses, stms) <- mkBottomArgs as ts - pure $ (p_map, M.insert 0 (AutoMapArg ses stms) a_map) + (p_map M.!? 1) + ses <- mkBottomArgs as ts + pure (p_map, M.insert 0 (AutoMapArg ses) a_map) | l == trueLevel am = do - (ps, p_stms) <- mkParams arg_vnames ts l + ps <- mkParams arg_vnames ts l d <- outerDim am l pure - ( M.insert l (AutoMapParam ps p_stms d) p_map, - M.insert l (AutoMapArg arg_vnames mempty) a_map + ( M.insert l (AutoMapParam ps d) p_map, + M.insert l (AutoMapArg arg_vnames) a_map ) | l < trueLevel am && l > 0 = do - (ps, p_stms) <- mkParams arg_vnames ts l + ps <- mkParams arg_vnames ts l d <- outerDim am l let as = map I.paramName $ amParams $ p_map M.! (l + 1) pure - ( M.insert l (AutoMapParam ps p_stms d) p_map, - M.insert l (AutoMapArg as mempty) a_map + ( M.insert l (AutoMapParam ps d) p_map, + M.insert l (AutoMapArg as) a_map ) | otherwise = error "" mkParams _ ts level = - collectStms $ - forM ts $ \t -> - newParam ("p_" <> arg_desc) $ argType (level - 1) am t - mkBottomArgs arg_vnames ts = - collectStms $ do - rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am - if I.shapeRank rep_shape > 0 - then concat <$> mapM (letValExp "autorep" . BasicOp . Replicate rep_shape . I.Var) arg_vnames - else pure arg_vnames - - argType level am t = I.stripArray (trueLevel am - level) t + forM ts $ \t -> + newParam ("p_" <> arg_desc) $ argType (level - 1) am t + mkBottomArgs arg_vnames ts = do + rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am + if I.shapeRank rep_shape > 0 + then + concat + <$> mapM + ( letValExp "autorep" + . BasicOp + . Replicate rep_shape + . I.Var + ) + arg_vnames + else pure arg_vnames internaliseShape :: E.Shape Size -> InternaliseM I.Shape internaliseShape = @@ -1151,118 +1112,10 @@ withAutoMapNew args_am func = do | otherwise = E.shapeRank $ autoFrame am outerDim :: AutoMap -> Int -> InternaliseM SubExp - outerDim am level = do - traceM $ - unlines - [ "##outerDim", - "##am", - show am, - "##level", - show level, - "## dff", - show (trueLevel am - level) - ] + outerDim am level = internaliseExp1 "" $ (!! (trueLevel am - level)) $ E.shapeDims $ autoFrame am -withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] -withAutoMap ams arg_desc res_t args_e innerM = do - (args, stms) <- - foldM - ( \(args, stms) arg -> do - (arg', stms') <- inScopeOf (reverse stms) $ collectStms $ internaliseArg arg_desc arg - pure (arg' : args, stms' : stms) - ) - (mempty, mempty) - (reverse args_e) - argts <- inScopeOf (reverse stms) $ (mapM . mapM) subExpType args - expand args stms argts ams (maximum ds) - where - stripAutoMapDims i am = - am {autoMap = E.Shape $ drop i $ E.shapeDims $ autoMap am} - max_am = maximumBy (\x y -> E.shapeRank x `compare` E.shapeRank y) $ fmap autoMap ams - inner_t = E.stripArray (E.shapeRank max_am) res_t - ds = map autoMapRank ams - mkLambdaParams level (ses, ts, stm, d) - | d == level = - Left - <$> zipWithM - ( \se t -> do - let t' = I.stripArray 1 t - p <- newParam "x" t' - addStms stm - pure ((se, p), t') - ) - ses - ts - | otherwise = pure $ Right $ zip ses ts - - internaliseShape = - fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims - - addReplicates = - zipWithM - ( \am arg -> do - rep_shape <- - internaliseShape $ - autoRep am `E.shapePrefix` autoFrame am - if I.shapeRank rep_shape > 0 - then concat <$> mapM (letValExp' "autoRep" . BasicOp . Replicate rep_shape) arg - else pure arg - ) - - expand args stms argts ams' level - | level <= 0 = do - args' <- addReplicates ams' args - innerM $ zip args' stms - | otherwise = do - let ds' = map autoMapRank ams' - arg_params <- mapM (mkLambdaParams level) $ zip4 args argts stms ds' - let argts' = map (either (map snd) (map snd)) arg_params - (ams'', stms') = - unzip $ - zipWith - ( \am stm -> - if autoMapRank am == level - then (stripAutoMapDims 1 am, mempty) - else (am, stm) - ) - ams' - stms - args' = map (either (map (I.Var . I.paramName . snd . fst)) (map fst)) arg_params - (map_ses, params) = unzip $ (concatMap . map) fst $ lefts arg_params - - ((ses, ses_ts), lam_stms) <- collectStms $ localScope (scopeOfLParams params) $ do - ses <- expand args' stms' argts' ams'' (level - 1) - ses_ts <- internaliseLambdaReturnType (E.toRes Nonunique inner_t) =<< mapM subExpType ses - pure (ses, ses_ts) - - case map_ses of - [] -> pure mempty - (map_se : _) -> do - outer_shape <- I.takeDims 1 . I.arrayShape <$> subExpType map_se - let I.Shape [outer_shape_se] = outer_shape - map_args <- forM map_ses $ \se -> do - se_t <- subExpType se - se_name <- letExp "map_arg" =<< toExp se - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter outer_shape 1 $ I.arrayShape se_t) - se_name - - letValExp' "automap" - . Op - . Screma outer_shape_se map_args - . mapSOAC - =<< mkLambda - params - ( ensureResultShape - (ErrorMsg [ErrorString "AutoMap: unexpected lambda result size"]) - mempty - ses_ts - =<< (addStms lam_stms >> pure (subExpsRes ses)) - ) + argType level am = I.stripArray (trueLevel am - level) generateCond :: E.Pat StructType -> @@ -1851,14 +1704,14 @@ data Function | FunctionHole SrcLoc deriving (Show) -findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName, AutoMap)]) +findFuncall :: E.AppExp -> (Function, [((E.Exp, Maybe VName), AutoMap)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info (argext, am), e) = (e, argext, am) + onArg (Info (argext, am), e) = ((e, argext), am) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e From 8af019ebd6aea8715bb6095e1f7b27d54a2f3879 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:05:42 -0800 Subject: [PATCH 133/258] Fixes. --- tests/automap/equality1.fut | 10 +++++----- tests/automap/pagerank.fut | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/automap/equality1.fut b/tests/automap/equality1.fut index 1604c49d93..b2a173f30d 100644 --- a/tests/automap/equality1.fut +++ b/tests/automap/equality1.fut @@ -1,23 +1,23 @@ -- == -- entry: bigger_to_smaller -- input { [[1,2],[3,4]] [1,2] } --- output { [true, false] } +-- output { [[true, true], [false, false]] } -- == -- entry: smaller_to_bigger -- input { [[1,2],[3,4]] [1,2] } --- output { [true, false] } +-- output { [[true, true], [false, false]] } -- == -- entry: smaller_to_bigger2 -- input { [[1,2],[3,4]] 1 } -- output { [[true,false],[false,false]]} -def bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = +entry bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = xss == ys -def smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = +entry smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = ys == xss -def smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = +entry smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = z == xss diff --git a/tests/automap/pagerank.fut b/tests/automap/pagerank.fut index c444932de5..3552990144 100644 --- a/tests/automap/pagerank.fut +++ b/tests/automap/pagerank.fut @@ -14,5 +14,5 @@ def calculate_dangling_ranks_am [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = let total = f32.sum weights / f32.i64 n in ranks + total -entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): []bool = - calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes +entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): bool = + and (calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes) From 321b5245c1bbbcc824153869adf3e2a8a260503f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:33:36 -0800 Subject: [PATCH 134/258] Looks like we actually do need some reshaping. --- src/Futhark/Internalise/Exps.hs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index d38e36493c..bdef5c071e 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1034,15 +1034,21 @@ withAutoMap args_am func = do (map_dim : _) -> do let params = map amParams $ param_map M.! l args = map amArgs $ arg_map M.! l + + reshaped_args <- + forM (concat args) $ \argvn -> + letExp "reshaped" $ + shapeCoerce [map_dim] argvn + letValExp' "automap" . Op - . Screma map_dim (concat args) + . Screma map_dim reshaped_args . mapSOAC =<< mkLambda - (concat params) - ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) - ) + (concat params) + ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) + ) buildArgMap :: ((E.Exp, Maybe VName), AutoMap, String) -> From bb037270053e99ce46c8012a72dad862cd372e72 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:41:23 -0800 Subject: [PATCH 135/258] Oops. Fix reshaping. --- src/Futhark/Internalise/Exps.hs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index bdef5c071e..dd6f16a8a1 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1036,9 +1036,14 @@ withAutoMap args_am func = do args = map amArgs $ arg_map M.! l reshaped_args <- - forM (concat args) $ \argvn -> + forM (concat args) $ \argvn -> do + arg_t <- subExpType $ I.Var argvn letExp "reshaped" $ - shapeCoerce [map_dim] argvn + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) + argvn letValExp' "automap" From 10cfde83bb80d4e8f9187b79f345374091437dc8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:46:11 -0800 Subject: [PATCH 136/258] Better to do the reshaping here, I think. --- src/Futhark/Internalise/Exps.hs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index dd6f16a8a1..e1180e274e 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1035,20 +1035,10 @@ withAutoMap args_am func = do let params = map amParams $ param_map M.! l args = map amArgs $ arg_map M.! l - reshaped_args <- - forM (concat args) $ \argvn -> do - arg_t <- subExpType $ I.Var argvn - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) - argvn - letValExp' "automap" . Op - . Screma map_dim reshaped_args + . Screma map_dim (concat args) . mapSOAC =<< mkLambda (concat params) @@ -1079,9 +1069,20 @@ withAutoMap args_am func = do | l == trueLevel am = do ps <- mkParams arg_vnames ts l d <- outerDim am l + + reshaped_args <- + forM arg_vnames $ \argvn -> do + arg_t <- subExpType $ I.Var argvn + letExp "reshaped" $ + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter (I.Shape [d]) 1 $ I.arrayShape arg_t) + argvn + pure ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg arg_vnames) a_map + M.insert l (AutoMapArg reshaped_args) a_map ) | l < trueLevel am && l > 0 = do ps <- mkParams arg_vnames ts l From 5adef9022786be1ffe856ff9221fab9de6846a1d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 20:46:51 +0100 Subject: [PATCH 137/258] Begin handling overloaded type variables. --- src/Language/Futhark/TypeChecker/Constraints.hs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ad11df4729..0725d50418 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -159,7 +159,15 @@ subTyVar v lvl t = do linkTyVar :: VName -> VName -> SolveM () linkTyVar v t = do occursCheck v $ Scalar $ TypeVar NoUniqueness (qualName t) [] + tyvars <- gets solverTyVars modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} + tyvars' <- + case (M.lookup v tyvars, M.lookup t tyvars) of + (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree)) -> + pure $ M.insert t (TyVarUnsol lvl info) tyvars + -- TODO: handle more cases. + _ -> pure tyvars + modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) tyvars'} -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -190,7 +198,7 @@ solveCt :: Ct -> SolveM () solveCt ct = case ct of CtEq t1 t2 -> solveCt' (t1, t2) - CtAM _ _ _ -> pure () -- Good vibes only. + CtAM {} -> pure () -- Good vibes only. where bad = throwError $ "Unsolvable: " <> prettyText ct solveCt' (t1, t2) = do From 1d9d25d897133feb0b99bd1562d259736e682b0d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 23:27:51 -0800 Subject: [PATCH 138/258] These should be frames. --- src/Futhark/Internalise/Defunctionalise.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 98b698aeeb..375fe18577 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -972,7 +972,7 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a pure ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, - autoMapSV (autoMap am) sv + autoMapSV (autoFrame am) sv -- sv ) -- If 'f' is a dynamic function, we just leave the application in @@ -996,7 +996,7 @@ defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do "## ret sv", show $ autoMapSV (autoMap am) sv ] - pure (apply_e, autoMapSV (autoMap am) sv) + pure (apply_e, autoMapSV (autoFrame am) sv) -- pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = From 266368a9058e233928b2805ba485c9c4ee01f21a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 23:34:46 -0800 Subject: [PATCH 139/258] Revert "Better to do the reshaping here, I think." I was wrong. --- src/Futhark/Internalise/Exps.hs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index e1180e274e..dd6f16a8a1 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1035,10 +1035,20 @@ withAutoMap args_am func = do let params = map amParams $ param_map M.! l args = map amArgs $ arg_map M.! l + reshaped_args <- + forM (concat args) $ \argvn -> do + arg_t <- subExpType $ I.Var argvn + letExp "reshaped" $ + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) + argvn + letValExp' "automap" . Op - . Screma map_dim (concat args) + . Screma map_dim reshaped_args . mapSOAC =<< mkLambda (concat params) @@ -1069,20 +1079,9 @@ withAutoMap args_am func = do | l == trueLevel am = do ps <- mkParams arg_vnames ts l d <- outerDim am l - - reshaped_args <- - forM arg_vnames $ \argvn -> do - arg_t <- subExpType $ I.Var argvn - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter (I.Shape [d]) 1 $ I.arrayShape arg_t) - argvn - pure ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg reshaped_args) a_map + M.insert l (AutoMapArg arg_vnames) a_map ) | l < trueLevel am && l > 0 = do ps <- mkParams arg_vnames ts l From d190923710fd74e5831f8f26cd4a5d276f27c02f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 27 Feb 2024 13:04:36 -0800 Subject: [PATCH 140/258] Strip off automapped shapes from arg static values. Fixes `optionpricing.fut` bug. --- src/Futhark/Internalise/Defunctionalise.hs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 375fe18577..2f9ea6d787 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -907,14 +907,18 @@ defuncApplyArg :: DefM (Exp, StaticVal) defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, am), arg), _) = do (arg', arg_sv) <- defuncExp arg - let env' = alwaysMatchPatSV pat arg_sv + let arg_sv' = + case arg_sv of + (Dynamic ty@(Array {})) -> Dynamic $ stripArray (shapeRank $ autoFrame am) ty + _ -> arg_sv dims = mempty + env' = alwaysMatchPatSV pat arg_sv' (lam_e', sv) <- localNewEnv (env' <> closure_env) $ defuncExp lam_e let closure_pat = buildEnvPat dims closure_env - pat' = updatePat pat arg_sv + pat' = updatePat pat arg_sv' globals <- asks fst From 56c18c05389095a987645dbeaa1feac2ab51b1ea Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 27 Feb 2024 13:07:15 -0800 Subject: [PATCH 141/258] Prevent loops. --- src/Futhark/Internalise/Defunctionalise.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 2f9ea6d787..2819f444c3 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -1223,7 +1223,7 @@ matchPatSV (PatConstr c1 _ ps _) (Dynamic (Scalar (Sum fs))) else Nothing | otherwise = error $ "matchPatSV: missing constructor in type: " ++ prettyString c1 -matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t +matchPatSV pat (Dynamic t@(Scalar Record {})) = matchPatSV pat $ svFromType t matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType $ toParam Observe t matchPatSV pat sv = error $ From 8e1acffa1b28c5c88e7e54d507193a407680ca21 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 28 Feb 2024 10:49:19 -0800 Subject: [PATCH 142/258] Don't need this. --- src/Futhark/Internalise/Exps.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index dd6f16a8a1..21f559c00c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1062,10 +1062,10 @@ withAutoMap args_am func = do ses <- internaliseArg arg_desc arg arg_vnames <- mapM (letExp "" <=< eSubExp) ses ts <- mapM subExpType ses - foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ + foldM (mkArgsAndParams arg_vnames ts) (mempty, mempty) $ reverse [0 .. trueLevel am] where - mkArgsAndParams arg_vnames ses ts (p_map, a_map) l + mkArgsAndParams arg_vnames ts (p_map, a_map) l | l == 0 = do let as = maybe From 9e5d59a6ac82848cf7883b1da83fc2e899fdc741 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 28 Feb 2024 11:54:41 -0800 Subject: [PATCH 143/258] Undo AUTOMAP-handling in the later phases of internalization. --- src/Futhark/Internalise/Defunctionalise.hs | 95 +------ src/Futhark/Internalise/Exps.hs | 283 ++------------------- src/Futhark/Internalise/Monomorphise.hs | 28 +- 3 files changed, 43 insertions(+), 363 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 2819f444c3..82cc845d69 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -16,7 +16,7 @@ import Data.Maybe import Data.Set qualified as S import Futhark.IR.Pretty () import Futhark.MonadFreshNames -import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) +import Futhark.Util (mapAccumLM, nubOrd) import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types (Subst (..), applySubst) @@ -905,20 +905,16 @@ defuncApplyArg :: (Exp, StaticVal) -> (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) -defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, am), arg), _) = do +defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg - let arg_sv' = - case arg_sv of - (Dynamic ty@(Array {})) -> Dynamic $ stripArray (shapeRank $ autoFrame am) ty - _ -> arg_sv + let env' = alwaysMatchPatSV pat arg_sv dims = mempty - env' = alwaysMatchPatSV pat arg_sv' (lam_e', sv) <- localNewEnv (env' <> closure_env) $ defuncExp lam_e let closure_pat = buildEnvPat dims closure_env - pat' = updatePat pat arg_sv' + pat' = updatePat pat arg_sv globals <- asks fst @@ -959,49 +955,20 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) callret <- unRetType lifted_rettype - debugTraceM $ - unlines - [ "##defuncApplyArg LambdaSV", - "## fname", - fname_s, - "## f'", - prettyString f', - "## arg", - prettyString arg, - "## sv", - show sv, - "## ret sv", - show $ autoMapSV (autoMap am) sv - ] - pure - ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, - autoMapSV (autoFrame am) sv - -- sv + ( mkApply fname' [(Nothing, mempty, f'), (argext, mempty, arg')] callret, + sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. -defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do +defuncApplyArg _ (f', DynamicFun _ sv) (((argext, _), arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] - apply_e = mkApply f' [(argext, am, arg')] callret - debugTraceM $ - unlines - [ "##defuncApplyArg DynamicFun", - "## f'", - prettyString f', - "## arg", - prettyString arg, - "## sv", - show sv, - "## ret sv", - show $ autoMapSV (autoMap am) sv - ] - pure (apply_e, autoMapSV (autoFrame am) sv) --- pure (apply_e, sv) + apply_e = mkApply f' [(argext, mempty, arg')] callret + pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = error $ @@ -1017,11 +984,6 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -autoMapSV :: Shape Size -> StaticVal -> StaticVal -autoMapSV shape (Dynamic t) = - Dynamic $ arrayOfWithAliases (diet t) shape t -autoMapSV _ sv = sv - defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) @@ -1037,39 +999,10 @@ defuncApply f args appres loc = do _ -> do let fname = liftedName 0 f (argtypes, _) = unfoldFunType $ typeOf f - (app, app_sv) <- - fmap (first $ updateReturn appres) $ - foldM (defuncApplyArg fname) (f', f_sv) $ - NE.zip args $ - NE.tails argtypes - - let (p_ts, _) = unfoldFunType $ typeOf f - arg_ts = typeOf . snd <$> args - -- am_dims = zipWith typeShapePrefix (NE.toList arg_ts) p_ts - -- ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) am_dims - ams = NE.toList $ autoMap . snd . fst <$> args - ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) ams - debugTraceM $ - unlines - [ "## defuncApply", - "## f", - prettyString f, - "## args", - prettyString $ snd <$> args, - "## appres", - show appres, - "## app", - prettyString app, - "## app_sv", - show app_sv, - "## f type", - prettyString $ typeOf f, - "## arg types", - prettyString $ typeOf . snd <$> args, - "## ret_am", - prettyString ret_am - ] - pure (app, app_sv) + fmap (first $ updateReturn appres) $ + foldM (defuncApplyArg fname) (f', f_sv) $ + NE.zip args $ + NE.tails argtypes where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. @@ -1223,7 +1156,7 @@ matchPatSV (PatConstr c1 _ ps _) (Dynamic (Scalar (Sum fs))) else Nothing | otherwise = error $ "matchPatSV: missing constructor in type: " ++ prettyString c1 -matchPatSV pat (Dynamic t@(Scalar Record {})) = matchPatSV pat $ svFromType t +matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType $ toParam Observe t matchPatSV pat sv = error $ diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 21f559c00c..b5684552e4 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -8,13 +8,11 @@ module Futhark.Internalise.Exps (transformProg) where import Control.Monad import Control.Monad.Reader import Data.Bifunctor -import Data.Either import Data.Foldable (toList) -import Data.List (elemIndex, find, intercalate, intersperse, maximumBy, transpose, zip4) +import Data.List (elemIndex, find, intercalate, intersperse, transpose) import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M -import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T import Futhark.IR.SOACS as I hiding (stmPat) @@ -348,13 +346,12 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = let subst = map (,E.ExpSubst (E.sizeFromInteger 0 mempty)) ext et' = E.applySubst (`lookup` subst) et internaliseExp desc (E.Hole (Info et') loc) - (FunctionName qfname, argsam) -> do + (FunctionName qfname, args) -> do -- Argument evaluation is outermost-in so that any existential sizes -- created by function applications can be brought into scope. let fname = nameFromString $ prettyString $ baseName $ qualLeaf qfname loc = srclocOf e - (args, ams) = unzip argsam - args_am_desc = zip3 args ams (repeat (nameToString fname ++ "_arg")) + arg_desc = nameToString fname ++ "_arg" -- Some functions are magical (overloaded) and we handle that here. case () of @@ -364,21 +361,20 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - withAutoMap args_am_desc $ \args' -> do - let prepareArg ((arg, _), am, _) arg' = - (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') - internalise $ zipWith prepareArg args_am_desc args' + let prepareArg (arg, _) = + (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg + internalise =<< mapM prepareArg args | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - Just (rettype, _) <- M.lookup fname I.builtInFunctions -> - withAutoMap args_am_desc $ \args' -> do - let tag ses = [(se, I.Observe) | se <- ses] - let args'' = concatMap tag args' - letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) + Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do + let tag ses = [(se, I.Observe) | se <- ses] + args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) + let args'' = concatMap tag args' + letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - withAutoMap args_am_desc $ \args' -> do - funcall desc qfname (concat args') loc + args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) + funcall desc qfname args' loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = internalisePat desc sizes pat e $ internaliseExp desc body internaliseAppExp _ _ (E.LetFun ofname _ _ _) = @@ -879,255 +875,6 @@ internalisePatLit (E.PatLitFloat x) (E.Scalar (E.Prim (E.FloatType ft))) = internalisePatLit l t = error $ "Nonsensical pattern and type: " ++ show (l, t) --- | Internalization of 'AutoMap'-annotated applications. --- --- Each application @f x@ has an annotation with @AutoMap R M F@ where --- @R, M, F@ are the autorep, automap, and frame shapes, --- respectively. --- --- The application @f x@ will have type @F t@ for some @t@, i.e. @(f --- x) : F t@. The frame @F@ is a prefix of the type of @f x@; namely --- it is the total accumulated shape that is due to implicit maps. --- Another way of thinking about that is that @|F|@ is is the level --- of the automap-nest that @f x@ is in. For example, if @|F| = 2@ --- then we know that @f x@ implicitly stands for --- --- > map (\x' -> map (\x'' -> f x'') x') x --- --- For an application with a non-empty autorep annotation, the frame --- tells about how many dimensions of the replicate can be eliminated. --- For example, @[[1,2],[3,4]] + 5@ will yield the following annotations: --- --- > ([[1,2],[3,4]] +) -- AutoMap {R = mempty, M = [2][2], F = [2][2]} --- > (([[1,2],[3,4]] +) 5) -- AutoMap {R = [2][2], M = mempty, F = [2][2]} --- --- All replicated arguments are pushed down the auto-map nest. Each --- time a replicated argument is pushed down a level of an --- automap-nest, one fewer replicates is needed (i.e., the outermost --- dimension of @R@ can be dropped). Replicated arguments are pushed --- down the nest until either 1) the bottom of the nest is encountered --- or 2) no replicate dimensions remain. For example, in the second --- application above @R@ = @F@, so we can push the replicated argument --- down two levels. Since each level effectively removes a dimension --- of the replicate, no replicates will be required: --- --- > map (\xs -> map (\x -> f x'' 5) xs) [[1,2],[3,4]] --- --- The number of replicates that are actually required is given by --- max(|R| - |F|, 0). --- --- An expression's "true level" is the level at which that expression --- will appear in the automap-nest. The bottom of a mapnest is level 0. --- --- * For annotations with @R = mempty@, the true level is @|F|@. --- * For annotations with @M = mempty@, the true level is @|F| - |R|@. --- --- If @|R| > |F|@ then actual replicates (namely @|R| - |F|@ of them) --- will be required at the bottom of the mapnest. --- --- Note that replicates can only appear at the bottom of a mapnest; any --- expression of the form --- --- > map (\ls x' rs -> e) (replicate x) --- --- can always be written as --- --- > map (\ls rs -> e[x' -> x]) --- --- Let's look at another example. Consider (with exact sizes omitted for brevity) --- --- > f : a -> a -> a -> []a -> [][][]a -> a --- > xss : [][]a --- > ys : []a --- > zsss : [][][]a --- > w : a --- > vss : [][]a --- --- and the application --- --- > f xss ys zsss w vss --- --- which will have the following annotations --- --- > (f xss) -- AutoMap {R = mempty, M = [][], F = [][]} (1) --- > ((f xss) ys) -- AutoMap {R = [], M = mempty, F = [][]} (2) --- > (((f xss) ys) zsss) -- AutoMap {R = mempty, M = [], F = [][][]} (3) --- > ((((f xss) ys) zsss) w) -- AutoMap {R = [][][][], M = mempty, F = [][][]} (4) --- > (((((f xss) ys) zsss) w) vss) -- AutoMap {R = [], M = mempty, F = [][][]} (5) --- --- This will yield the following mapnest. --- --- > map (\zss -> --- > map (\xs zs vs -> --- > map (\x y z v -> f x y z (replicate w) v) xs ys zs v) xss zss vss) zsss --- --- Let's see how we'd construct this mapnest from the annotations. We construct --- the nest bottom-up. We have: --- --- Application | True level --- --------------------------- --- (1) | |[][]| = 2 --- (2) | |[][]| - |[]| = 1 --- (3) | |[][][]| = 3 --- (4) | |[][][]| - |[][][][]| = -1 --- (5) | |[][][]| - |[]| = 2 --- --- We start at level 0. --- * Any argument with a negative true level of @-n@ will be replicated @n@ times; --- the exact shapes can be found by removing the @F@ postfix from @R@, --- i.e. @R = shapes_to_rep_by <> F@. --- * Any argument with a 0 true level will be included. --- * For any argument @arg@ with a positive true level, we construct a new parameter --- whose type is @arg@ with the leading @n@ dimensions (where @n@ is the true level) --- removed. --- --- Following the rules above, @w@ will be replicated once. For the remaining arguments, --- we create new parameters @x : a, y : a, z : a , v : a@. Hence, level 0 becomes --- --- > f x y z (replicate w) v --- --- At level l > 0: --- * There are no replicates. --- * Any argument with l true level will be included verbatim. --- * Any argument with true level > l will have a new parameter constructed for it, --- whose type has the leading @n - l@ dimensions (where @n@ is the true level) removed. --- * We surround the previous level with a map that binds that levels' new parameters --- and is passed the current levels' arguments. --- --- Following the above recipe for level 1, we create parameters --- @xs : []a, zs : []a, vs :[]a@ and obtain --- --- > map (\x y z v -> f x y z (replicate w) v) xs ys zs vs --- --- This process continues until the level is greater than the maximum --- true level of any application, at which we terminate. -type Level = Int - -data AutoMapArg = AutoMapArg - { amArgs :: [VName] - } - deriving (Show) - -data AutoMapParam = AutoMapParam - { amParams :: [LParam SOACS], - amMapDim :: SubExp - } - deriving (Show) - -withAutoMap :: - [((E.Exp, Maybe VName), AutoMap, String)] -> - ([[SubExp]] -> InternaliseM [SubExp]) -> - InternaliseM [SubExp] -withAutoMap args_am func = do - (param_maps, arg_maps) <- - unzip . reverse - <$> mapM buildArgMap (reverse args_am) - let param_map = M.unionsWith (<>) $ (fmap . fmap) pure param_maps - arg_map = M.unionsWith (<>) $ (fmap . fmap) pure arg_maps - buildMapNest param_map arg_map $ maximum $ M.keys arg_map - where - buildMapNest _ arg_map 0 = - func $ map (map I.Var . amArgs) $ arg_map M.! 0 - buildMapNest param_map arg_map l = - case map amMapDim $ param_map M.! l of - [] -> buildMapNest param_map arg_map (l - 1) - (map_dim : _) -> do - let params = map amParams $ param_map M.! l - args = map amArgs $ arg_map M.! l - - reshaped_args <- - forM (concat args) $ \argvn -> do - arg_t <- subExpType $ I.Var argvn - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) - argvn - - letValExp' - "automap" - . Op - . Screma map_dim reshaped_args - . mapSOAC - =<< mkLambda - (concat params) - ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) - ) - - buildArgMap :: - ((E.Exp, Maybe VName), AutoMap, String) -> - InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) - buildArgMap (arg, am, arg_desc) = do - ses <- internaliseArg arg_desc arg - arg_vnames <- mapM (letExp "" <=< eSubExp) ses - ts <- mapM subExpType ses - foldM (mkArgsAndParams arg_vnames ts) (mempty, mempty) $ - reverse [0 .. trueLevel am] - where - mkArgsAndParams arg_vnames ts (p_map, a_map) l - | l == 0 = do - let as = - maybe - arg_vnames - ( map I.paramName - . amParams - ) - (p_map M.!? 1) - ses <- mkBottomArgs as ts - pure (p_map, M.insert 0 (AutoMapArg ses) a_map) - | l == trueLevel am = do - ps <- mkParams arg_vnames ts l - d <- outerDim am l - pure - ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg arg_vnames) a_map - ) - | l < trueLevel am && l > 0 = do - ps <- mkParams arg_vnames ts l - d <- outerDim am l - let as = - map I.paramName $ - amParams $ - p_map M.! (l + 1) - pure - ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg as) a_map - ) - | otherwise = error "" - - mkParams _ ts level = - forM ts $ \t -> - newParam ("p_" <> arg_desc) $ argType (level - 1) am t - mkBottomArgs arg_vnames ts = do - rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am - if I.shapeRank rep_shape > 0 - then - concat - <$> mapM - ( letValExp "autorep" - . BasicOp - . Replicate rep_shape - . I.Var - ) - arg_vnames - else pure arg_vnames - - internaliseShape :: E.Shape Size -> InternaliseM I.Shape - internaliseShape = - fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims - - trueLevel :: AutoMap -> Int - trueLevel am - | autoMap am == mempty = max 0 $ E.shapeRank (autoFrame am) - E.shapeRank (autoRep am) - | otherwise = E.shapeRank $ autoFrame am - - outerDim :: AutoMap -> Int -> InternaliseM SubExp - outerDim am level = - internaliseExp1 "" $ (!! (trueLevel am - level)) $ E.shapeDims $ autoFrame am - - argType level am = I.stripArray (trueLevel am - level) - generateCond :: E.Pat StructType -> [I.SubExp] -> @@ -1715,14 +1462,14 @@ data Function | FunctionHole SrcLoc deriving (Show) -findFuncall :: E.AppExp -> (Function, [((E.Exp, Maybe VName), AutoMap)]) +findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info (argext, am), e) = ((e, argext), am) + onArg (Info (argext, _), e) = (e, argext) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index aada3924c0..2617f95b8b 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -535,7 +535,7 @@ transformAppExp (Loop sparams pat e1 form body loc) res = do (pat_sizes, pat'') <- sizesForPat pat' res' <- transformAppRes res pure $ AppExp (Loop (sparams' ++ pat_sizes) pat'' e1' form' body' loc) (Info res') -transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, am1)) (e2, Info (d2, am2)) loc) res = do +transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, _)) (e2, Info (d2, _)) loc) res = do (AppRes ret ext) <- transformAppRes res fname' <- transformFName loc fname (toStruct t) e1' <- transformExp e1 @@ -570,8 +570,8 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, am1)) (e2, Info (d2, a where applyOp ret ext fname' x y = mkApply - (mkApply fname' [(d1, am1, x)] (AppRes ret mempty)) - [(d2, am2, y)] + (mkApply fname' [(d1, mempty, x)] (AppRes ret mempty)) + [(d2, mempty, y)] (AppRes ret ext) makeVarParam arg = do @@ -664,27 +664,27 @@ transformExp (Lambda params e0 decl tp loc) = do transformExp (OpSection qn t loc) = transformExp $ Var qn t loc transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc) = do - let (Info (xp, xtype, xargext, xam), Info (yp, ytype)) = arg + let (Info (xp, xtype, xargext, _), Info (yp, ytype)) = arg e' <- transformExp e desugarBinOpSection fname (Just e') Nothing t - (xp, xtype, xargext, xam) - (yp, ytype, Nothing, mempty) + (xp, xtype, xargext) + (yp, ytype, Nothing) (rettype, retext) loc transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do - let (Info (xp, xtype), Info (yp, ytype, yargext, yam)) = arg + let (Info (xp, xtype), Info (yp, ytype, yargext, _)) = arg e' <- transformExp e desugarBinOpSection fname Nothing (Just e') t - (xp, xtype, Nothing, mempty) - (yp, ytype, yargext, yam) + (xp, xtype, Nothing) + (yp, ytype, yargext) (rettype, []) loc transformExp (ProjectSection fields (Info t) loc) = do @@ -735,12 +735,12 @@ desugarBinOpSection :: Maybe Exp -> Maybe Exp -> StructType -> - (PName, ParamType, Maybe VName, AutoMap) -> - (PName, ParamType, Maybe VName, AutoMap) -> + (PName, ParamType, Maybe VName) -> + (PName, ParamType, Maybe VName) -> (ResRetType, [VName]) -> SrcLoc -> MonoM Exp -desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, yext, yam) (RetType dims rettype, retext) loc = do +desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (RetType dims rettype, retext) loc = do t' <- transformType t op <- transformFName loc fname $ toStruct t (v1, wrap_left, e1, p1) <- makeVarParam e_left =<< transformType xtype @@ -748,7 +748,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, ye let apply_left = mkApply op - [(xext, xam, e1)] + [(xext, mempty, e1)] (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) []) onDim (Var d typ _) | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc @@ -757,7 +757,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, ye rettype' = first onDim rettype body <- scoping (S.fromList [v1, v2]) $ - mkApply apply_left [(yext, yam, e2)] + mkApply apply_left [(yext, mempty, e2)] <$> transformAppRes (AppRes (toStruct rettype') retext) rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype' pure . wrap_left . wrap_right $ From 8d88c9582154ac4a22c83e268e8aec2f8acd6c45 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 28 Feb 2024 12:08:44 -0800 Subject: [PATCH 144/258] Undo AUTOMAP-handling in `FullNormalise.hs`. --- src/Futhark/Internalise/FullNormalise.hs | 46 ++++++++---------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 17841c9f53..a3dca1f8bd 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -203,13 +203,13 @@ getOrdering final (Lambda params body mte ret loc) = do nameExp final $ Lambda params body' mte ret loc getOrdering _ (OpSection qn ty loc) = pure $ Var qn ty loc -getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, xam), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do +getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, _), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do x <- getOrdering False e yn <- newNameFromString "y" let y = Var (qualName yn) (Info $ toStruct yty) mempty ret' = applySubst (pSubst x y) ret body = - mkApply (Var op ty mempty) [(xext, xam, x), (Nothing, mempty, y)] $ + mkApply (Var op ty mempty) [(xext, mempty, x), (Nothing, mempty, y)] $ AppRes (toStruct ret') exts nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc where @@ -217,12 +217,12 @@ getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, xam), Info (yp, yty | Named p <- xp, p == vn = Just $ ExpSubst x | Named p <- yp, p == vn = Just $ ExpSubst y | otherwise = Nothing -getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext, yam)) (Info (RetType dims ret)) loc) = do +getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext, _)) (Info (RetType dims ret)) loc) = do xn <- newNameFromString "x" y <- getOrdering False e let x = Var (qualName xn) (Info $ toStruct xty) mempty ret' = applySubst (pSubst x y) ret - body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, yam, y)] $ AppRes (toStruct ret') [] + body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, mempty, y)] $ AppRes (toStruct ret') [] nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc where pSubst x y vn @@ -298,41 +298,25 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do While e -> While <$> transformBody e body' <- transformBody body nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT -getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do +getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, _)) (er, Info (erp, _)) loc) (Info resT)) = do -- Rewrite short-circuiting boolean operators on scalars to explicit -- if-then-else. Automapped cases are turned into applications of -- intrinsic functions. expr' <- case (isOr, isAnd) of - (True, _) - | elam == mempty, - eram == mempty -> do - el' <- naming "or_lhs" $ getOrdering True el - er' <- naming "or_rhs" $ transformBody er - pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) - | otherwise -> do - el' <- naming "or_lhs" $ getOrdering False el - er' <- naming "or_rhs" $ getOrdering False er - pure $ mkApply orop [(elp, elam, el'), (erp, eram, er')] resT - (_, True) - | elam == mempty, - eram == mempty -> do - el' <- naming "and_lhs" $ getOrdering True el - er' <- naming "and_rhs" $ transformBody er - pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) - | otherwise -> do - el' <- naming "and_lhs" $ getOrdering False el - er' <- naming "and_rhs" $ getOrdering False er - pure $ mkApply andop [(elp, elam, el'), (erp, eram, er')] resT - _ -> do + (True, _) -> do + el' <- naming "or_lhs" $ getOrdering True el + er' <- naming "or_rhs" $ transformBody er + pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) + (_, True) -> do + el' <- naming "and_lhs" $ getOrdering True el + er' <- naming "and_rhs" $ transformBody er + pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) + (False, False) -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er - pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT + pure $ mkApply (Var op opT oloc) [(elp, mempty, el'), (erp, mempty, er')] resT nameExp final expr' where - bool = Scalar $ Prim Bool - opt = foldFunType [bool, bool] $ RetType [] bool - andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty - orop = Var (qualName (intrinsicVar "||")) (Info opt) mempty isOr = baseName (qualLeaf op) == "||" isAnd = baseName (qualLeaf op) == "&&" getOrdering final (AppExp (LetWith (Ident dest dty dloc) (Ident src sty sloc) slice e body loc) _) = do From de5cf38fa8e36e7145db8d080346158ec3292b9e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 14:54:36 +0100 Subject: [PATCH 145/258] Use proper type here. --- src/Language/Futhark/TypeChecker/Terms/Loop.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index 7cba8af7e8..51d1c8ceba 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -221,14 +221,16 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do case form of For i uboundexp -> do uboundexp' <- checkExp uboundexp - bindingIdent i . bindingPat [] mergepat merge_t $ + it <- expType uboundexp' + let i' = i {identType = Info it} + bindingIdent i' . bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do loopbody' <- checkExp loopbody (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' pure ( sparams, mergepat'', - For i uboundexp', + For i' uboundexp', loopbody' ) ForIn xpat e -> do From dcfdf1e5b373e0a8625862499c864e74f57c1161 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 14:56:09 +0100 Subject: [PATCH 146/258] Also update type here. --- src/Language/Futhark/TypeChecker/Terms.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 25c6dcb0c8..68547cb024 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -363,8 +363,8 @@ checkExp (RecordLit fs loc) = where checkField (RecordFieldExplicit f e rloc) = RecordFieldExplicit f <$> checkExp e <*> pure rloc - checkField (RecordFieldImplicit name info rloc) = - pure $ RecordFieldImplicit name info rloc + checkField (RecordFieldImplicit name (Info t) rloc) = + RecordFieldImplicit name <$> (Info <$> replaceTyVars rloc t) <*> pure rloc checkExp (ArrayLit all_es _ loc) = -- Construct the result type and unify all elements with it. We -- only create a type variable for empty arrays; otherwise we use From 9c869ab0b9d48477deda0271e0a4589532dca4ff Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 16:38:13 +0100 Subject: [PATCH 147/258] Fix handling of overloaded type variables. --- .../Futhark/TypeChecker/Constraints.hs | 25 +++++++------------ src/Language/Futhark/TypeChecker/Terms2.hs | 25 +++++++++++++------ src/Language/Futhark/TypeChecker/Types.hs | 18 +++++++++++++ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 0725d50418..10206b0bb8 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -22,6 +22,7 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Types (substTyVars) type SVar = VName @@ -99,21 +100,13 @@ newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} initialState :: TyVars -> SolverState initialState tyvars = SolverState $ M.map (uncurry TyVarUnsol) tyvars -substTyVars :: (Monoid u) => M.Map TyVar TyVarSol -> TypeBase SComp u -> TypeBase SComp u -substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = +substTyVar :: (Monoid u) => M.Map TyVar TyVarSol -> VName -> Maybe (TypeBase SComp u) +substTyVar m v = case M.lookup v m of - Just (TyVarLink v') -> - substTyVars m $ Scalar $ TypeVar u (QualName qs v') args - Just (TyVarSol _ t') -> second (const mempty) $ substTyVars m t' - Just (TyVarUnsol {}) -> t - Nothing -> t -substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt -substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs -substTyVars m (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars m) cs -substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = - Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 `setUniqueness` uniqueness t2 -substTyVars m (Array u shape elemt) = - arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt + Just (TyVarLink v') -> substTyVar m v' + Just (TyVarSol _ t') -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (TyVarUnsol {}) -> Nothing + Nothing -> Nothing -- | A solution maps a type variable to its substitution. This -- substitution is complete, in the sense there are no right-hand @@ -127,7 +120,7 @@ solution s = ) where mkSubst (TyVarSol _lvl t) = - Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t + Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t mkSubst (TyVarLink v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) @@ -143,7 +136,7 @@ newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} occursCheck :: VName -> Type -> SolveM () occursCheck v tp = do vars <- gets solverTyVars - let tp' = substTyVars vars tp + let tp' = substTyVars (substTyVar vars) tp when (v `S.member` typeVars tp') . throwError . docText $ "Occurs check: cannot instantiate" <+> prettyName v diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index ccba90ab07..f65bbb1811 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -640,7 +640,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs debugTraceM $ - unlines $ + unlines [ "## checkApplyOne", "## fname", prettyString fname, @@ -1123,12 +1123,12 @@ checkExp (Coerce e te NoInfo loc) = do ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc -doDefaults :: +doDefault :: S.Set VName -> VName -> Either [PrimType] (TypeBase () NoUniqueness) -> TermM (TypeBase () NoUniqueness) -doDefaults tyvars_at_toplevel v (Left pts) +doDefault tyvars_at_toplevel v (Left pts) | Signed Int32 `elem` pts = do when (v `S.member` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to i32." @@ -1145,7 +1145,18 @@ doDefaults tyvars_at_toplevel v (Left pts) "Add a type annotation to disambiguate the type." where usage = mkUsage NoLoc "overload" -doDefaults _ _ (Right t) = pure t +doDefault _ _ (Right t) = pure t + +-- | Apply defaults on otherwise ambiguous types. This may result in +-- some type variables becoming known, so we have to perform +-- substitutions on the RHS of the substitutions afterwards. +doDefaults :: + S.Set VName -> + M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -> + TermM (M.Map TyVar (TypeBase () NoUniqueness)) +doDefaults tyvars_at_toplevel substs = do + substs' <- M.traverseWithKey (doDefault tyvars_at_toplevel) substs + pure $ M.map (substTyVars (`M.lookup` substs')) substs' checkValDef :: ( VName, @@ -1187,8 +1198,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do solutions <- forM cts_tyvars' $ - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) - . uncurry solve + bitraverse pure (traverse (doDefaults mempty)) . uncurry solve forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> debugTraceM $ @@ -1213,6 +1223,5 @@ checkSingleExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars solution <- - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ - solve cts tyvars + bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars pure (solution, e') diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 22fa6fc5a1..b89802c5e0 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -8,6 +8,7 @@ module Language.Futhark.TypeChecker.Types TypeSubs, Substitutable (..), substTypesAny, + substTyVars, -- * Witnesses mustBeExplicitInType, @@ -531,6 +532,23 @@ substTypesAny lookupSubst ot = toAny d = d in first toAny ot' +-- | Substitution without caring about sizes. +substTyVars :: (Monoid u) => (VName -> Maybe (TypeBase d NoUniqueness)) -> TypeBase d u -> TypeBase d u +substTyVars f t@(Scalar (TypeVar u (QualName qs v) args)) = + case f v of + Just t' -> second (const mempty) $ substTyVars f t' + Nothing -> t +substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt +substTyVars f (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars f) fs +substTyVars f (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars f) cs +substTyVars f (Scalar (Arrow u pname d t1 (RetType ext t2))) = + Scalar $ + Arrow u pname d (substTyVars f t1) $ + RetType ext $ + substTyVars f t2 `setUniqueness` uniqueness t2 +substTyVars f (Array u shape elemt) = + arrayOfWithAliases u shape $ substTyVars f $ Scalar elemt + -- Note [AnySize] -- -- Consider a program: From 837f3cd57121906ddfaf428b7c66fe57c92ff153 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 17:02:16 +0100 Subject: [PATCH 148/258] Add checkSizeExp to Terms2. --- src/Language/Futhark/TypeChecker/Terms.hs | 5 ++--- src/Language/Futhark/TypeChecker/Terms2.hs | 24 ++++++++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 68547cb024..e953ae3c00 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1063,16 +1063,14 @@ checkOneExp e = do -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp checkSizeExp e = do - (maybe_tysubsts, e') <- Terms2.checkSingleExp e + (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' - let t = typeOf e'' when (hasBinding e'') $ typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ "Size expression with binding is forbidden." - unify (mkUsage e'' "Size expression") t (Scalar (Prim (Signed Int64))) normTypeFully e'' -- Verify that all sum type constructors and empty array literals have @@ -1657,6 +1655,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + debugTraceM $ unlines [unlines $ map show $ M.toList tysubsts, prettyString body'] let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index f65bbb1811..b56f534184 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -38,6 +38,7 @@ module Language.Futhark.TypeChecker.Terms2 ( checkValDef, checkSingleExp, + checkSizeExp, Solution, ) where @@ -442,8 +443,8 @@ patLitMkType (PatLitFloat _) loc = patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v -checkSizeExp :: ExpBase NoInfo VName -> TermM Exp -checkSizeExp e = do +checkSizeExp' :: ExpBase NoInfo VName -> TermM Exp +checkSizeExp' e = do e' <- checkExp e ctEq (expType e') (Scalar (Prim (Signed Int64))) pure e' @@ -496,7 +497,7 @@ checkPat' (RecordPat fs loc) NoneInferred = <$> traverse (`checkPat'` NoneInferred) (M.fromList fs) <*> pure loc checkPat' (PatAscription p t loc) maybe_outer_t = do - (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp t + (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp' t -- Uniqueness kung fu to make the Monoid(mempty) instance give what -- we expect. We should perhaps stop being so implicit. @@ -775,7 +776,7 @@ checkRetDecl :: TermM (Maybe (TypeExp Exp VName)) checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do - (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te ctEq (expType body) (toType st) pure $ Just te' @@ -1114,12 +1115,12 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do -- checkExp (Ascript e te loc) = do e' <- checkExp e - (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te ctEq (expType e') (toType st) pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e - (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc @@ -1225,3 +1226,14 @@ checkSingleExp e = runTermM $ do solution <- bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars pure (solution, e') + +-- | Type-check a single size expression in isolation. This expression may +-- turn out to be polymorphic, in which case it is unified with i64. +checkSizeExp :: ExpBase NoInfo VName -> TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) +checkSizeExp e = runTermM $ do + e' <- checkSizeExp' e + cts <- gets termConstraints + tyvars <- gets termTyVars + solution <- + bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars + pure (solution, e') From 12d21a47e5e6ad63b6784611df89acd6f360f6c1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 22:47:09 +0100 Subject: [PATCH 149/258] Fixes to updating of types. --- src/Language/Futhark/TypeChecker/Terms.hs | 1 - .../Futhark/TypeChecker/Terms/Monad.hs | 1 + src/Language/Futhark/TypeChecker/Terms/Pat.hs | 30 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index e953ae3c00..5eae32ac6b 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1655,7 +1655,6 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do - debugTraceM $ unlines [unlines $ map show $ M.toList tysubsts, prettyString body'] let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 3e23ceec2e..2153eb2751 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -25,6 +25,7 @@ module Language.Futhark.TypeChecker.Terms.Monad constrain, newArrayType, allDimsFreshInType, + instTyVars, replaceTyVars, updateTypes, Names, diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 17aef0e1a2..4485c20f4c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -15,6 +15,7 @@ import Data.List (find, isPrefixOf, sort) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Futhark.Util import Futhark.Util.Pretty hiding (group, space) import Language.Futhark import Language.Futhark.TypeChecker.Monad hiding (BoundV) @@ -104,9 +105,9 @@ bindingIdent ident = binding [ident] checkPat' :: [(SizeBinder VName, QualName VName)] -> - Pat (TypeBase Size u) -> + Pat ParamType -> Inferred ParamType -> - TermTypeM (Pat (TypeBase Size u)) + TermTypeM (Pat ParamType) checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = @@ -115,16 +116,14 @@ checkPat' _ (Id name (Info t) loc) NoneInferred = do t' <- replaceTyVars loc t pure $ Id name (Info t') loc checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc t1 - unify (mkUsage loc "id") (toStruct t) (toStruct t2) - pure $ Id name (Info t) loc + t' <- instTyVars loc [] (first (const ()) t1) t2 + pure $ Id name (Info t') loc checkPat' _ (Wildcard (Info t) loc) NoneInferred = do t' <- replaceTyVars loc t pure $ Wildcard (Info t') loc checkPat' _ (Wildcard (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc t1 - unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) - pure $ Wildcard (Info t) loc + t' <- instTyVars loc [] (first (const ()) t1) t2 + pure $ Wildcard (Info t') loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, length ts == length ps = @@ -185,9 +184,9 @@ checkPat' sizes (PatConstr n info ps loc) _ = do checkPat :: [(SizeBinder VName, QualName VName)] -> - Pat (TypeBase Size u) -> + Pat ParamType -> Inferred StructType -> - (Pat (TypeBase Size u) -> TermTypeM a) -> + (Pat ParamType -> TermTypeM a) -> TermTypeM a checkPat sizes p t m = do p' <- @@ -210,14 +209,15 @@ bindingPat :: [SizeBinder VName] -> Pat (TypeBase Size u) -> StructType -> - (Pat (TypeBase Size u) -> TermTypeM a) -> + (Pat ParamType -> TermTypeM a) -> TermTypeM a bindingPat sizes p t m = do substs <- mapM mkSizeSubst sizes - checkPat substs p (Ascribed t) $ \p' -> binding (patIdents (fmap toStruct p')) $ - case filter ((`S.notMember` fvVars (freeInPat p')) . sizeName) sizes of - [] -> m p' - size : _ -> unusedSize size + checkPat substs (fmap (toParam Observe) p) (Ascribed t) $ \p' -> + binding (patIdents (fmap toStruct p')) $ + case filter ((`S.notMember` fvVars (freeInPat p')) . sizeName) sizes of + [] -> m p' + size : _ -> unusedSize size where mkSizeSubst v = do v' <- newID $ baseName $ sizeName v From c6aaf9ca507368589e5ff904668e06483c8e5e83 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 15:36:00 +0100 Subject: [PATCH 150/258] Fix unification of abstract types. --- src/Language/Futhark/TypeChecker/Constraints.hs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 10206b0bb8..fc4f1fb5f8 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -166,6 +166,14 @@ linkTyVar v t = do unify :: Type -> Type -> Maybe [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Just [] +unify + (Scalar (TypeVar _ (QualName _ v1) targs1)) + (Scalar (TypeVar _ (QualName _ v2) targs2)) + | v1 == v2 = + Just $ mapMaybe f $ zip targs1 targs2 + where + f (TypeArgType t1, TypeArgType t2) = Just (t1, t2) + f _ = Nothing unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = Just [(t1a, t2a), (t1r', t2r')] where @@ -223,13 +231,9 @@ solveCt ct = (Scalar (TypeVar _ (QualName [] v1) []), t2') | Just lvl <- flexible v1 -> subTyVar v1 lvl t2' - | otherwise -> - bad (t1', Scalar (TypeVar _ (QualName [] v2) [])) | Just lvl <- flexible v2 -> subTyVar v2 lvl t1' - | otherwise -> - bad (t1', t2') -> case unify t1' t2' of Nothing -> bad Just eqs -> mapM_ solveCt' eqs From b6d2e1b383b0cf4cafc23db411430c64e2eb10fe Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 15:44:29 +0100 Subject: [PATCH 151/258] Also do the AUTOMAP on size expressions. --- src/Language/Futhark/TypeChecker/Terms2.hs | 83 ++++++++++++---------- 1 file changed, 47 insertions(+), 36 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b56f534184..cc1f611129 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1174,47 +1174,48 @@ checkValDef :: [Exp] ) checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do - bindParams tparams params $ \params' -> do - body' <- checkExp body + (params', body', retdecl') <- + bindParams tparams params $ \params' -> do + body' <- checkExp body + retdecl' <- checkRetDecl body' retdecl + pure (params', body', retdecl') - retdecl' <- checkRetDecl body' retdecl + cts <- gets termConstraints - cts <- gets termConstraints + tyvars <- gets termTyVars - tyvars <- gets termTyVars + debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + debugTraceM $ + unlines + [ "## cts:", + unlines $ map prettyString cts, + "## body:", + prettyString body', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + ] + + (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' + solutions <- + forM cts_tyvars' $ + bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + + forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> debugTraceM $ unlines - [ "## cts:", - unlines $ map prettyString cts, - "## body:", - prettyString body', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars':", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] - (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' - - solutions <- - forM cts_tyvars' $ - bitraverse pure (traverse (doDefaults mempty)) . uncurry solve - - forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> - debugTraceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars':", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - - pure (solutions, params', retdecl', bodys') + pure (solutions, params', retdecl', bodys') checkSingleExp :: ExpBase NoInfo VName -> @@ -1229,11 +1230,21 @@ checkSingleExp e = runTermM $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. -checkSizeExp :: ExpBase NoInfo VName -> TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) +checkSizeExp :: + ExpBase NoInfo VName -> + TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints tyvars <- gets termTyVars - solution <- - bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars - pure (solution, e') + + (cts_tyvars', es') <- unzip <$> rankAnalysis (srclocOf e) cts tyvars e' + + solutions <- + forM cts_tyvars' $ + bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + + case (solutions, es') of + ([solution], [e'']) -> + pure (solution, e'') + _ -> pure (Left "Ambiguous size expression", e') From 45a5c448578c5a1f03935de811ae400d89b202e8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 16:45:21 +0100 Subject: [PATCH 152/258] Fix instantiation of parametric abstract types. --- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 2153eb2751..086580745d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -411,6 +411,14 @@ instTyVars loc names orig_t1 orig_t2 = do (Array _ (Shape (d : ds2)) t2) = arrayOfWithAliases u (Shape [d]) <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) + f + (Scalar (TypeVar u v1 targs1)) + (Scalar (TypeVar _ _ targs2)) = + Scalar . TypeVar u v1 <$> zipWithM g targs1 targs2 + where + g (TypeArgType t1) (TypeArgType t2) = + TypeArgType <$> f t1 t2 + g _ targ = pure targ f t1 t2 = do let mkNew = fst <$> lift (allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1) @@ -538,11 +546,6 @@ lookupVar loc qn@(QualName qs name) inst_t = do replaceTyVars loc inst_t Just OverloadedF {} -> replaceTyVars loc inst_t - where - instOverloaded argtype pts rt = - ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, - maybe (toStruct argtype) (Scalar . Prim) rt - ) onFailure :: Checking -> TermTypeM a -> TermTypeM a onFailure c = local $ \env -> env {termChecking = Just c} From a3a50945e72882c5de16bc2ed7374f0496f05844 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 17:06:56 +0100 Subject: [PATCH 153/258] Must match. --- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 086580745d..018a3e920a 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -60,6 +60,7 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified +import Futhark.Util import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals @@ -413,8 +414,9 @@ instTyVars loc names orig_t1 orig_t2 = do <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) f (Scalar (TypeVar u v1 targs1)) - (Scalar (TypeVar _ _ targs2)) = - Scalar . TypeVar u v1 <$> zipWithM g targs1 targs2 + (Scalar (TypeVar _ _ targs2)) + | length targs1 == length targs2 = + Scalar . TypeVar u v1 <$> zipWithM g targs1 targs2 where g (TypeArgType t1) (TypeArgType t2) = TypeArgType <$> f t1 t2 @@ -461,7 +463,6 @@ instTypeScheme qn loc tparams scheme_t inferred = do let tp_names = map typeParamName $ filter isTypeParam tparams t' <- instTyVars loc tp_names inferred $ applySubst (`lookup` substs) scheme_t - pure (names, t') lookupQualNameEnv :: QualName VName -> TermTypeM TermScope From 989c75778003feb0a2509d4827d5a3425fae6ff7 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 1 Mar 2024 17:49:27 -0800 Subject: [PATCH 154/258] Fix `mri-q.fut`. --- tests/automap/mri-q.fut | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index eaed14333a..8fe26aded6 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -23,18 +23,17 @@ def main_am [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) (phiR: [numK]f32) (phiI: [numK]f32) - : ([numK]f32, [numX][numK]f32) = - let (phiMag : [numK]f32) = phiR * phiR + phiI * phiI - let (expArgs : [numX][numK]f32) = map3 (\(x_e : f32) (y_e : f32) (z_e : f32) -> + : ([numX]f32, [numX]f32) = + let phiMag = phiR * phiR + phiI * phiI + let expArgs = map3 (\x_e y_e z_e -> 2.0*f32.pi*(kx*x_e + ky*y_e + kz*z_e)) x y z - in (phiMag, expArgs) - --let (qr : [numX]f32) = f32.sum (f32.cos expArgs * phiMag) -- [numx]f32 - --let (qi : [numX]f32) = f32.sum (f32.sin expArgs * phiMag) -- let (qi_10408: artificial₁₁₄_10524 ~ [M113_10523]f32) - --in (qr, qi) + let qr = f32.sum (f32.cos expArgs * phiMag) + let qi = f32.sum (f32.sin expArgs * phiMag) + in (qr, qi) ---entry main [numK][numX] --- (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) --- (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) --- (phiR: [numK]f32) (phiI: [numK]f32) = --- main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI +entry main [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) : bool = + main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI From 101285a0a4d907225786a55378e94fe4756bb1c2 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 1 Mar 2024 17:38:11 -0800 Subject: [PATCH 155/258] `debugTraceM` now takes a level. --- src/Futhark/Util.hs | 8 ++++---- src/Language/Futhark/Pretty.hs | 4 ++-- src/Language/Futhark/TypeChecker.hs | 2 +- src/Language/Futhark/TypeChecker/Rank.hs | 6 +++--- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 10 +++++----- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index 6d41b1c7f2..6a97f25e0e 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -510,8 +510,8 @@ topologicalSort dep nodes = mapM_ sorting $ mapMaybe (depends_of node) nodes_idx modify $ bimap (node :) (IM.insert i False) --- | 'traceM', but only if @FUTHARK_COMPILER_DEBUGGING@ is set to at least 1. -debugTraceM :: (Monad m) => String -> m () -debugTraceM - | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 = traceM +-- | 'traceM', but only if @FUTHARK_COMPILER_DEBUGGING@ is set to to the appropriate level. +debugTraceM :: (Monad m) => Int -> String -> m () +debugTraceM level + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" level = traceM | otherwise = const $ pure () diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 3318fe11e4..8e0b2619d9 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -237,7 +237,7 @@ prettyAppExp :: (Eq vn, IsName vn, Annot f) => Int -> AppExpBase f vn -> Doc a prettyAppExp p (BinOp (bop, _) _ (x, xi) (y, yi) _) = case (unAnnot xi, unAnnot yi) of (Just (_, xam), Just (_, yam)) - | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 3 -> -- fix parens $ align $ prettyBinOp p bop x y "Δ" <+> pretty xam "Δ" <+> pretty yam _ -> prettyBinOp p bop x y @@ -322,7 +322,7 @@ prettyAppExp p (Apply f args _) = prettyArg (i, e) = case unAnnot i of Just (_, am) - | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 3 -> parens (prettyExp 10 e <+> "Δ" <+> pretty am) _ -> prettyExp 10 e diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 65fd30f220..70824ad148 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -714,7 +714,7 @@ checkValBind vb = do let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - debugTraceM $ unlines ["# Inferred:", prettyString vb'] + debugTraceM 3 $ unlines ["# Inferred:", prettyString vb'] pure ( mempty diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 186f8d0fb3..9672e9aa60 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -195,7 +195,7 @@ enumerateRankSols prog = solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] solveRankILP loc prog = do - debugTraceM $ + debugTraceM 3 $ unlines [ "## solveRankILP", prettyString prog @@ -203,9 +203,9 @@ solveRankILP loc prog = do case enumerateRankSols prog of [] -> typeError loc mempty "Rank ILP cannot be solved." rs -> do - debugTraceM "## rank maps" + debugTraceM 3 "## rank maps" forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> - debugTraceM $ + debugTraceM 3 $ unlines $ "\n## rank map " <> prettyString i : map prettyString (M.toList r) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 5eae32ac6b..ec30e59689 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -963,7 +963,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 else splitArrayAt (autoFrameRank am) <$> normTypeFully argtype - debugTraceM $ + debugTraceM 3 $ unlines [ "## checkApply", "## fn", diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cc1f611129..1e8dbcefbb 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -640,8 +640,8 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do rhs = arrayOf (toShape (SVar m)) a ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs - debugTraceM $ - unlines + debugTraceM 3 $ + unlines $ [ "## checkApplyOne", "## fname", prettyString fname, @@ -1184,9 +1184,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + debugTraceM 3 $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - debugTraceM $ + debugTraceM 3 $ unlines [ "## cts:", unlines $ map prettyString cts, @@ -1203,7 +1203,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bitraverse pure (traverse (doDefaults mempty)) . uncurry solve forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> - debugTraceM $ + debugTraceM 3 $ unlines [ "## constraints:", unlines $ map prettyString cts', From 178a7a09e30b262be1b8263ca9295cd9af007d10 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 1 Mar 2024 10:57:46 -0800 Subject: [PATCH 156/258] Expand AUTOMAP annotations in normalization. --- src/Futhark/Internalise/Exps.hs | 13 + src/Futhark/Internalise/FullNormalise.hs | 344 ++++++++++++++++++++++- 2 files changed, 356 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index b5684552e4..11b261e425 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1624,6 +1624,19 @@ isIntrinsicFunction qname args loc = do fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing + handleSOACs (lam : args) "map" = Just $ \desc -> do + let internaliseVName x = do + es <- map (BasicOp . SubExp) <$> internaliseExp "arg" x + concat <$> mapM (letValExp "arg") es + args' <- concat <$> mapM internaliseVName args + param_ts <- mapM (fmap (I.stripArray 1) . lookupType) args' + map_dim <- (head . I.shapeDims . I.arrayShape) <$> lookupType (head args') + lambda <- internaliseLambdaCoerce lam param_ts + letTupExp' + desc + $ Op + $ Screma map_dim args' + $ mapSOAC lambda handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index a3dca1f8bd..fe197a842c 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -23,11 +23,18 @@ module Futhark.Internalise.FullNormalise (transformProg) where import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor +import Data.Functor.Identity +import Data.List (zip4) import Data.List.NonEmpty qualified as NE import Data.Map qualified as M +import Data.Maybe import Data.Text qualified as T +import Debug.Trace import Futhark.MonadFreshNames +import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.Pretty +import Language.Futhark.Primitive (intValue) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types @@ -347,7 +354,7 @@ getOrdering final (AppExp (Match expr cs loc) resT) = do -- a complete separtion of states. transformBody :: (MonadFreshNames m) => Exp -> m Exp transformBody e = do - (e', pre_eval) <- runOrdering (getOrdering True e) + (e', pre_eval) <- runOrdering . getOrdering True =<< expandAMAnnotations e pure $ foldl f e' pre_eval where appRes = case e of @@ -366,3 +373,338 @@ transformValBind valbind = do transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind] transformProg = mapM transformValBind + +--- | Expansion of 'AutoMap'-annotated applications. +--- +--- Each application @f x@ has an annotation with @AutoMap R M F@ where +--- @R, M, F@ are the autorep, automap, and frame shapes, +--- respectively. +--- +--- The application @f x@ will have type @F t@ for some @t@, i.e. @(f +--- x) : F t@. The frame @F@ is a prefix of the type of @f x@; namely +--- it is the total accumulated shape that is due to implicit maps. +--- Another way of thinking about that is that @|F|@ is is the level +--- of the automap-nest that @f x@ is in. For example, if @|F| = 2@ +--- then we know that @f x@ implicitly stands for +--- +--- > map (\x' -> map (\x'' -> f x'') x') x +--- +--- For an application with a non-empty autorep annotation, the frame +--- tells about how many dimensions of the replicate can be eliminated. +--- For example, @[[1,2],[3,4]] + 5@ will yield the following annotations: +--- +--- > ([[1,2],[3,4]] +) -- AutoMap {R = mempty, M = [2][2], F = [2][2]} +--- > (([[1,2],[3,4]] +) 5) -- AutoMap {R = [2][2], M = mempty, F = [2][2]} +--- +--- All replicated arguments are pushed down the auto-map nest. Each +--- time a replicated argument is pushed down a level of an +--- automap-nest, one fewer replicates is needed (i.e., the outermost +--- dimension of @R@ can be dropped). Replicated arguments are pushed +--- down the nest until either 1) the bottom of the nest is encountered +--- or 2) no replicate dimensions remain. For example, in the second +--- application above @R@ = @F@, so we can push the replicated argument +--- down two levels. Since each level effectively removes a dimension +--- of the replicate, no replicates will be required: +--- +--- > map (\xs -> map (\x -> f x'' 5) xs) [[1,2],[3,4]] +--- +--- The number of replicates that are actually required is given by +--- max(|R| - |F|, 0). +--- +--- An expression's "true level" is the level at which that expression +--- will appear in the automap-nest. The bottom of a mapnest is level 0. +--- +--- * For annotations with @R = mempty@, the true level is @|F|@. +--- * For annotations with @M = mempty@, the true level is @|F| - |R|@. +--- +--- If @|R| > |F|@ then actual replicates (namely @|R| - |F|@ of them) +--- will be required at the bottom of the mapnest. +--- +--- Note that replicates can only appear at the bottom of a mapnest; any +--- expression of the form +--- +--- > map (\ls x' rs -> e) (replicate x) +--- +--- can always be written as +--- +--- > map (\ls rs -> e[x' -> x]) +--- +--- Let's look at another example. Consider (with exact sizes omitted for brevity) +--- +--- > f : a -> a -> a -> []a -> [][][]a -> a +--- > xss : [][]a +--- > ys : []a +--- > zsss : [][][]a +--- > w : a +--- > vss : [][]a +--- +--- and the application +--- +--- > f xss ys zsss w vss +--- +--- which will have the following annotations +--- +--- > (f xss) -- AutoMap {R = mempty, M = [][], F = [][]} (1) +--- > ((f xss) ys) -- AutoMap {R = [], M = mempty, F = [][]} (2) +--- > (((f xss) ys) zsss) -- AutoMap {R = mempty, M = [], F = [][][]} (3) +--- > ((((f xss) ys) zsss) w) -- AutoMap {R = [][][][], M = mempty, F = [][][]} (4) +--- > (((((f xss) ys) zsss) w) vss) -- AutoMap {R = [], M = mempty, F = [][][]} (5) +--- +--- This will yield the following mapnest. +--- +--- > map (\zss -> +--- > map (\xs zs vs -> +--- > map (\x y z v -> f x y z (replicate w) v) xs ys zs v) xss zss vss) zsss +--- +--- Let's see how we'd construct this mapnest from the annotations. We construct +--- the nest bottom-up. We have: +--- +--- Application | True level +--- --------------------------- +--- (1) | |[][]| = 2 +--- (2) | |[][]| - |[]| = 1 +--- (3) | |[][][]| = 3 +--- (4) | |[][][]| - |[][][][]| = -1 +--- (5) | |[][][]| - |[]| = 2 +--- +--- We start at level 0. +--- * Any argument with a negative true level of @-n@ will be replicated @n@ times; +--- the exact shapes can be found by removing the @F@ postfix from @R@, +--- i.e. @R = shapes_to_rep_by <> F@. +--- * Any argument with a 0 true level will be included. +--- * For any argument @arg@ with a positive true level, we construct a new parameter +--- whose type is @arg@ with the leading @n@ dimensions (where @n@ is the true level) +--- removed. +--- +--- Following the rules above, @w@ will be replicated once. For the remaining arguments, +--- we create new parameters @x : a, y : a, z : a , v : a@. Hence, level 0 becomes +--- +--- > f x y z (replicate w) v +--- +--- At level l > 0: +--- * There are no replicates. +--- * Any argument with l true level will be included verbatim. +--- * Any argument with true level > l will have a new parameter constructed for it, +--- whose type has the leading @n - l@ dimensions (where @n@ is the true level) removed. +--- * We surround the previous level with a map that binds that levels' new parameters +--- and is passed the current levels' arguments. +--- +--- Following the above recipe for level 1, we create parameters +--- @xs : []a, zs : []a, vs :[]a@ and obtain +--- +--- > map (\x y z v -> f x y z (replicate w) v) xs ys zs vs +--- +--- This process continues until the level is greater than the maximum +--- true level of any application, at which we terminate. + +-- | Expands 'AutoMap' annotations into explicit @map@s and @replicates@. +expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp +expandAMAnnotations e = do + case e of + (AppExp (Apply f args loc) (Info res)) -> do + let ((exts, ams), arg_es) = first unzip $ unzip $ map (first unInfo) $ NE.toList args + f' <- expandAMAnnotations f + arg_es' <- mapM expandAMAnnotations arg_es + let diets = funDiets $ typeOf f + withMapNest loc (zip4 exts ams arg_es diets) $ \args' -> do + inner_f <- setNewType f' $ innerFType (typeOf f') ams + let (_, ret) = unfoldFunType $ typeOf inner_f + + -- when (any (/= mempty) ams) $ + -- traceM $ + -- unlines $ + -- [ "##f'", + -- prettyString $ typeOf f', + -- "##inner_f", + -- prettyString $ typeOf inner_f, + -- "##e", + -- prettyString e, + -- "##ams", + -- show ams + -- ] + pure $ + mkApply inner_f (zip3 exts (repeat mempty) args') $ + res {appResType = snd $ unfoldFunType $ typeOf inner_f} + (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do + x' <- expandAMAnnotations x + y' <- expandAMAnnotations y + withMapNest loc [(xext, xam, x', Observe), (yext, yam, y', Observe)] $ \[x'', y''] -> + pure $ + AppExp + ( BinOp + op + (Info t) + (x'', Info (xext, mempty)) + (y'', Info (yext, mempty)) + loc + ) + (Info res {appResType = stripArray (shapeRank $ autoFrame yam) (appResType res)}) + _ -> astMap identityMapper {mapOnExp = expandAMAnnotations} e + where + setNewType e t = astMap identityMapper {mapOnStructType = const $ pure t} e + + funDiets :: TypeBase dim as -> [Diet] + funDiets (Scalar (Arrow _ _ d _ (RetType _ t2))) = d : funDiets t2 + funDiets _ = [] + + dropDims :: Int -> TypeBase dim as -> TypeBase dim as + dropDims n (Scalar (Arrow u p diet t1 (RetType ds t2))) = + Scalar (Arrow u p diet (stripArray n t1) (RetType ds (dropDims n t2))) + dropDims n t = stripArray n t + + innerFType :: TypeBase dim as -> [AutoMap] -> TypeBase dim as + innerFType (Scalar (Arrow u p diet t1 (RetType ds t2))) ams = + Scalar $ Arrow u p diet t1 $ RetType ds $ innerFType' t2 ams + where + innerFType' t [] = t + innerFType' (Scalar (Arrow u p diet t1 (RetType ds t2))) (am : ams) = + Scalar $ Arrow u p diet (dropDims (shapeRank (autoMap am)) t1) $ RetType ds $ innerFType' t2 ams + innerFType' t [am] = dropDims (shapeRank (autoMap am)) t + innerFType' _ _ = error "" + innerFType _ _ = error "" + +type Level = Int + +data AutoMapArg = AutoMapArg + { amArg :: Exp + } + deriving (Show) + +data AutoMapParam = AutoMapParam + { amParam :: Pat ParamType, + amMapDim :: Size, + amDiet :: Diet + } + deriving (Show) + +-- | Builds a map-nest based on the 'AutoMap' annotations. +withMapNest :: + forall m. + (MonadFreshNames m) => + SrcLoc -> + [(Maybe VName, AutoMap, Exp, Diet)] -> + ([Exp] -> m Exp) -> + m Exp +withMapNest loc args f = do + (param_map, arg_map) <- + bimap combineMaps combineMaps . unzip <$> mapM buildArgMap args + buildMapNest param_map arg_map $ maximum $ M.keys arg_map + where + combineMaps :: (Ord k) => [M.Map k v] -> M.Map k [v] + combineMaps = M.unionsWith (<>) . (fmap . fmap) pure + + buildMapNest :: + M.Map Level [AutoMapParam] -> + M.Map Level [AutoMapArg] -> + Level -> + m Exp + buildMapNest _ arg_map 0 = + f $ map amArg $ arg_map M.! 0 + buildMapNest param_map arg_map l = + case map amMapDim $ param_map M.! l of + [] -> error "Malformed param map." + (map_dim : _) -> do + let params = map (\p -> (amDiet p, amParam p)) $ param_map M.! l + args = map amArg $ arg_map M.! l + body <- buildMapNest param_map arg_map (l - 1) + pure $ + mkMap map_dim params body args $ + RetType [] $ + arrayOfWithAliases Unique (Shape [map_dim]) (typeOf body) + + buildArgMap :: + (Maybe VName, AutoMap, Exp, Diet) -> + m (M.Map Level AutoMapParam, M.Map Level AutoMapArg) + buildArgMap (ext, am, arg, diet) = + foldM (mkArgsAndParams arg) mempty $ reverse [0 .. trueLevel am] + where + mkArgsAndParams arg (p_map, a_map) l + | l == 0 = do + let arg' = maybe arg (paramToExp . amParam) (p_map M.!? 1) + rarg <- mkReplicateShape (autoRep am `shapePrefix` autoFrame am) arg' + pure (p_map, M.insert 0 (AutoMapArg rarg) a_map) + | l == trueLevel am = do + p <- mkAMParam (typeOf arg) l + let d = outerDim am l + pure + ( M.insert l (AutoMapParam p d diet) p_map, + M.insert l (AutoMapArg arg) a_map + ) + | l < trueLevel am && l > 0 = do + p <- mkAMParam (typeOf arg) l + let d = outerDim am l + let arg' = + paramToExp $ + amParam $ + p_map M.! (l + 1) + pure + ( M.insert l (AutoMapParam p d diet) p_map, + M.insert l (AutoMapArg arg') a_map + ) + | otherwise = error "Impossible." + + mkAMParam t level = + mkParam ("p_" <> show level) $ argType (level - 1) am t + + trueLevel :: AutoMap -> Int + trueLevel am + | autoMap am == mempty = + max 0 $ shapeRank (autoFrame am) - shapeRank (autoRep am) + | otherwise = + shapeRank $ autoFrame am + + outerDim :: AutoMap -> Int -> Size + outerDim am level = + (!! (trueLevel am - level)) $ shapeDims $ autoFrame am + + argType level am = stripArray (trueLevel am - level) + +mkParam :: (MonadFreshNames m) => String -> TypeBase Size u -> m (Pat ParamType) +mkParam desc t = do + x <- newVName desc + pure $ Id x (Info $ toParam Observe t) mempty + +mkReplicateShape :: (MonadFreshNames m) => Shape Size -> Exp -> m Exp +mkReplicateShape s e = foldM (flip mkReplicate) e s + +mkReplicate :: (MonadFreshNames m) => Exp -> Exp -> m Exp +mkReplicate dim e = do + x <- mkParam "x" (Scalar $ Prim $ Unsigned Int64) + pure $ + mkMap dim [(Observe, x)] e [xs] $ + RetType mempty (arrayOfWithAliases Unique (Shape [dim]) (typeOf e)) + where + xs = + AppExp + ( Range + (Literal (UnsignedValue $ intValue Int64 0) mempty) + Nothing + (UpToExclusive dim) + mempty + ) + ( Info $ AppRes (arrayOf (Shape [dim]) (Scalar $ Prim $ Unsigned Int64)) [] + ) + +mkMap :: Exp -> [(Diet, Pat ParamType)] -> Exp -> [Exp] -> ResRetType -> Exp +mkMap dim params body arrs rettype = + mkApply mapN args (AppRes (toStruct $ retType rettype) []) + where + args = map (Nothing,mempty,) $ lambda : arrs + mapt = foldFunType (zipWith toParam (Observe : map fst params) (typeOf lambda : map typeOf arrs)) rettype + mapN = Var (QualName [] $ VName "map" 0) (Info mapt) mempty + lambda = + Lambda + (map snd params) + body + Nothing + ( Info $ + RetType + (retDims rettype) + (typeOf body `setUniqueness` uniqueness (retType rettype)) + ) + mempty + +paramToExp :: Pat ParamType -> Exp +paramToExp (Id vn (Info t) loc) = + Var (QualName [] vn) (Info $ toStruct t) loc +paramToExp p = error $ prettyString p From e971922f16054fe4ac13b7c8b9e67475326d5798 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 3 Mar 2024 14:13:17 -0800 Subject: [PATCH 157/258] Apostrophes are important, man. --- src/Futhark/Internalise/FullNormalise.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index fe197a842c..c615e49c3d 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -506,7 +506,7 @@ expandAMAnnotations e = do f' <- expandAMAnnotations f arg_es' <- mapM expandAMAnnotations arg_es let diets = funDiets $ typeOf f - withMapNest loc (zip4 exts ams arg_es diets) $ \args' -> do + withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do inner_f <- setNewType f' $ innerFType (typeOf f') ams let (_, ret) = unfoldFunType $ typeOf inner_f From d62614f44df55571c63a988b9a7175f73f83df03 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 3 Mar 2024 18:00:58 -0800 Subject: [PATCH 158/258] Fixes/clean-up. --- src/Futhark/Internalise/Exps.hs | 24 +++++++++++++++++------- src/Futhark/Internalise/FullNormalise.hs | 15 +-------------- tests/automap/mri-q.fut | 6 ++++-- tests/automap/optionpricing.fut | 12 ++++++------ tests/automap/sgemm.fut | 2 +- 5 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 11b261e425..c5fe0a8fcd 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1625,17 +1625,27 @@ isIntrinsicFunction qname args loc = do handleOps _ _ = Nothing handleSOACs (lam : args) "map" = Just $ \desc -> do - let internaliseVName x = do - es <- map (BasicOp . SubExp) <$> internaliseExp "arg" x - concat <$> mapM (letValExp "arg") es - args' <- concat <$> mapM internaliseVName args - param_ts <- mapM (fmap (I.stripArray 1) . lookupType) args' - map_dim <- (head . I.shapeDims . I.arrayShape) <$> lookupType (head args') + arg_ses <- concat <$> mapM (internaliseExp "arg") args + arg_ts <- mapM subExpType arg_ses + let param_ts = map rowType arg_ts + map_dim = head $ I.shapeDims $ I.arrayShape $ head arg_ts + + arg_ses' <- + zipWithM + ( \p a -> + ensureShape "" mempty (arrayOfRow p map_dim) "" a + ) + param_ts + arg_ses + + args_v'' <- mapM (letExp "" . BasicOp . SubExp) arg_ses' + lambda <- internaliseLambdaCoerce lam param_ts + letTupExp' desc $ Op - $ Screma map_dim args' + $ Screma map_dim args_v'' $ mapSOAC lambda handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index c615e49c3d..52dc0fb0f5 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -509,19 +509,6 @@ expandAMAnnotations e = do withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do inner_f <- setNewType f' $ innerFType (typeOf f') ams let (_, ret) = unfoldFunType $ typeOf inner_f - - -- when (any (/= mempty) ams) $ - -- traceM $ - -- unlines $ - -- [ "##f'", - -- prettyString $ typeOf f', - -- "##inner_f", - -- prettyString $ typeOf inner_f, - -- "##e", - -- prettyString e, - -- "##ams", - -- show ams - -- ] pure $ mkApply inner_f (zip3 exts (repeat mempty) args') $ res {appResType = snd $ unfoldFunType $ typeOf inner_f} @@ -610,7 +597,7 @@ withMapNest loc args f = do pure $ mkMap map_dim params body args $ RetType [] $ - arrayOfWithAliases Unique (Shape [map_dim]) (typeOf body) + arrayOfWithAliases Nonunique (Shape [map_dim]) (typeOf body) buildArgMap :: (Maybe VName, AutoMap, Exp, Diet) -> diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index 8fe26aded6..3a4648c7b9 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -35,5 +35,7 @@ def main_am [numK][numX] entry main [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) - (phiR: [numK]f32) (phiI: [numK]f32) : bool = - main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI + (phiR: [numK]f32) (phiI: [numK]f32) = + let (qr, qi) = main_orig kx ky kz x y z phiR phiI + let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI + in and (map2 (==) qr qr_am && qi == qi_am) diff --git a/tests/automap/optionpricing.fut b/tests/automap/optionpricing.fut index c58bc39a0a..c4c916521f 100644 --- a/tests/automap/optionpricing.fut +++ b/tests/automap/optionpricing.fut @@ -68,11 +68,11 @@ def sobolRecI_am [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i def sobolReci2_am [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= prev ^ recM(sob_dirs, i) -entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): []bool = - sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n +entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): bool = + and (sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n) -entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): []bool = - sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x) +entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): bool = + and (sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x)) -entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): []bool = - sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i) +entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): bool = + and (sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i)) diff --git a/tests/automap/sgemm.fut b/tests/automap/sgemm.fut index 56dc08eb7e..a31ce0188e 100644 --- a/tests/automap/sgemm.fut +++ b/tests/automap/sgemm.fut @@ -29,4 +29,4 @@ def main_am [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) entry main [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) (alpha: f32) (beta: f32) = - main_orig ass bss css alpha beta == main_am ass bss css alpha beta + and (and (main_orig ass bss css alpha beta == main_am ass bss css alpha beta)) From 51c5393d515c8a3ec8f4ab25cfc3ecbd030e471a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 4 Mar 2024 10:47:33 +0100 Subject: [PATCH 159/258] Tuples must have more than one field. --- src/Language/Futhark/Tuple.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Tuple.hs b/src/Language/Futhark/Tuple.hs index a410ae0a5a..63cf7c1188 100644 --- a/src/Language/Futhark/Tuple.hs +++ b/src/Language/Futhark/Tuple.hs @@ -17,7 +17,8 @@ import Language.Futhark.Core (Name, nameFromString, nameToText) areTupleFields :: M.Map Name a -> Maybe [a] areTupleFields fs = let fs' = sortFields fs - in if and $ zipWith (==) (map fst fs') tupleFieldNames + in if length fs' > 1 + && and (zipWith (==) (map fst fs') tupleFieldNames) then Just $ map snd fs' else Nothing From 8260a4da843a25570a9fd463ff2cf49951fe953d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 4 Mar 2024 13:44:37 +0100 Subject: [PATCH 160/258] Start handling overloaded tyvars. --- .../Futhark/TypeChecker/Constraints.hs | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index fc4f1fb5f8..ebac367dfb 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -108,6 +108,16 @@ substTyVar m v = Just (TyVarUnsol {}) -> Nothing Nothing -> Nothing +lookupTyVar :: TyVar -> SolveM (Maybe Type) +lookupTyVar orig = do + tyvars <- gets solverTyVars + let f v = case M.lookup v tyvars of + Nothing -> error $ "Unknown tyvar: " <> prettyNameString v + Just (TyVarSol _ t) -> pure $ Just t + Just (TyVarLink v') -> f v' + Just (TyVarUnsol {}) -> pure Nothing + f orig + -- | A solution maps a type variable to its substitution. This -- substitution is complete, in the sense there are no right-hand -- sides that contain a type variable. @@ -238,11 +248,46 @@ solveCt ct = Nothing -> bad Just eqs -> mapM_ solveCt' eqs +solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () +solveTyVar (tv, (_, TyVarFree {})) = pure () +solveTyVar (tv, (_, TyVarPrim pts)) = do + t <- lookupTyVar tv + case t of + Nothing -> pure () + Just t' + | t' `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + throwError $ + "Type variable " + <> prettyNameText tv + <> " must be one of\n" + <> prettyText pts + <> "\nbut inferred to be\n" + <> prettyText t' +solveTyVar (tv, (_, TyVarRecord fs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Nothing -> pure () + Just (Scalar (Record fs2)) + | all (`M.member` fs2) (M.keys fs1) -> + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(k, (t1, t2)) -> + solveCt $ CtEq t1 t2 + Just tv_t' -> + throwError $ + "Type variable " + <> prettyNameText tv + <> " must be record with fields\n" + <> prettyText (Scalar (Record fs1)) + <> " but inferred to be\n" + <> prettyText tv_t' + solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = second solution . runExcept . flip execStateT (initialState tyvars) . runSolveM - $ mapM solveCt constraints + $ do + mapM_ solveCt constraints + mapM_ solveTyVar (M.toList tyvars) {-# NOINLINE solve #-} From 5bab301710d0c33293940ed9064f9600b19a14e3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 4 Mar 2024 13:45:21 +0100 Subject: [PATCH 161/258] This order is better. --- src/Language/Futhark/TypeChecker.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 70824ad148..a69722ee2b 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -708,14 +708,14 @@ checkValBind vb = do checkFunDef (fname, maybe_tdecl, tparams, params, body, loc) let entry' = Info (entryPoint params' maybe_tdecl' rettype) <$ entry + vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc + + debugTraceM 3 $ unlines ["# Inferred:", prettyString vb'] + case entry' of Just _ -> checkEntryPoint loc tparams' params' maybe_tdecl' rettype _ -> pure () - let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - - debugTraceM 3 $ unlines ["# Inferred:", prettyString vb'] - pure ( mempty { envVtable = From 9e2c1ee9b77ab60538687e2261d6202eb58839f3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 5 Mar 2024 14:02:50 +0100 Subject: [PATCH 162/258] Refactor to return one list with everything. --- src/Language/Futhark/TypeChecker/Terms.hs | 13 ++++++----- src/Language/Futhark/TypeChecker/Terms2.hs | 26 ++++++++++------------ 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index ec30e59689..cdc25b99d6 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -22,7 +22,7 @@ import Data.Bifunctor import Data.Bitraversable import Data.Char (isAscii) import Data.Either -import Data.List (delete, find, genericLength, partition) +import Data.List (delete, find, genericLength, partition, unzip4) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe @@ -1635,12 +1635,13 @@ checkFunDef :: Exp ) checkFunDef (fname, retdecl, tparams, params, body, loc) = do - (maybe_tysubstss, params', retdecl', bodys') <- + solutions <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - case (maybe_tysubstss, bodys') of - ([], _) -> error "impossible" - ([maybe_tysubsts], [body']) -> doChecks (maybe_tysubsts, params', retdecl', body') - (substs, bodies') -> + case solutions of + [(maybe_tysubsts, params', retdecl', body')] -> + doChecks (maybe_tysubsts, params', retdecl', body') + ls -> do + let (_, _, _, bodies') = unzip4 ls typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 1e8dbcefbb..b6953d810f 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -641,7 +641,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs debugTraceM 3 $ - unlines $ + unlines [ "## checkApplyOne", "## fname", prettyString fname, @@ -1168,11 +1168,12 @@ checkValDef :: SrcLoc ) -> TypeM - ( [Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness))], - [Pat ParamType], - Maybe (TypeExp Exp VName), - [Exp] - ) + [ ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), + [Pat ParamType], + Maybe (TypeExp Exp VName), + Exp + ) + ] checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do (params', body', retdecl') <- bindParams tparams params $ \params' -> do @@ -1196,13 +1197,11 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' - - solutions <- - forM cts_tyvars' $ - bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + ranks <- rankAnalysis loc cts tyvars body' - forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> + forM ranks $ \((cts', tyvars'), body'') -> do + solution <- + bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' debugTraceM 3 $ unlines [ "## constraints:", @@ -1214,8 +1213,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do in either T.unpack (unlines . map p . M.toList . snd) solution, either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] - - pure (solutions, params', retdecl', bodys') + pure (solution, params', retdecl', body'') checkSingleExp :: ExpBase NoInfo VName -> From a6676efdc7f59f9e3a4848dbf84e6cfbfc360abc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 5 Mar 2024 14:48:18 +0100 Subject: [PATCH 163/258] Let-generalise in Terms2. --- src/Language/Futhark/TypeChecker/Terms.hs | 82 +++++++++++----------- src/Language/Futhark/TypeChecker/Terms2.hs | 76 ++++++++++++++------ 2 files changed, 96 insertions(+), 62 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index cdc25b99d6..22ecc9ebc0 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1048,7 +1048,7 @@ checkOneExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' (tparams, _, _) <- @@ -1066,7 +1066,7 @@ checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' when (hasBinding e'') $ typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ @@ -1492,14 +1492,16 @@ closeOverTypes defname defloc tparams paramts ret substs = do _ -> Nothing pure - ( tparams ++ more_tparams, + ( tparams + ++ more_tparams, injectExt (nubOrd $ retext ++ mapMaybe mkExt (S.toList $ fvVars $ freeInType ret)) ret ) where -- Diet does not matter here. t = foldFunType (map (toParam Observe) paramts) $ RetType [] ret - to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs visible = typeVars t <> fvVars (freeInType t) + to_close_over = + M.filterWithKey (\k _ -> k `S.member` visible) substs (produced_sizes, param_sizes) = dimUses t @@ -1655,39 +1657,39 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do doChecks (maybe_tysubsts, params', retdecl', body') = case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do - let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained - (tparams', params'', retdecl'', RetType dims rettype', body'') <- - checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) - - -- Since this is a top-level function, we also resolve overloaded - -- types, using either defaults or complaining about ambiguities. - fixOverloadedTypes $ - typeVars rettype' <> foldMap (typeVars . patternType) params'' - - -- Then replace all inferred types in the body and parameters. - body''' <- normTypeFully body'' - params''' <- mapM normTypeFully params'' - retdecl''' <- traverse updateTypes retdecl'' - rettype'' <- normTypeFully rettype' - - -- Check if the function body can actually be evaluated. - causalityCheck body''' - - -- Check for various problems. - mapM_ (mustBeIrrefutable . fmap toStruct) params'' - localChecks body''' - - let ((body'''', updated_ret), errors) = - Consumption.checkValDef - ( fname, - params''', - body''', - RetType dims rettype'', - retdecl''', - loc - ) - - mapM_ throwError errors - - pure (tparams', params''', retdecl''', updated_ret, body'''') + Right (generalised, tysubsts) -> + runTermTypeM checkExp tysubsts $ do + (tparams', params'', retdecl'', RetType dims rettype', body'') <- + checkBinding (fname, retdecl', generalised <> tparams, params', body', loc) + + -- Since this is a top-level function, we also resolve overloaded + -- types, using either defaults or complaining about ambiguities. + fixOverloadedTypes $ + typeVars rettype' <> foldMap (typeVars . patternType) params'' + + -- Then replace all inferred types in the body and parameters. + body''' <- normTypeFully body'' + params''' <- mapM normTypeFully params'' + retdecl''' <- traverse updateTypes retdecl'' + rettype'' <- normTypeFully rettype' + + -- Check if the function body can actually be evaluated. + causalityCheck body''' + + -- Check for various problems. + mapM_ (mustBeIrrefutable . fmap toStruct) params'' + localChecks body''' + + let ((body'''', updated_ret), errors) = + Consumption.checkValDef + ( fname, + params''', + body''', + RetType dims rettype'', + retdecl''', + loc + ) + + mapM_ throwError errors + + pure (tparams', params''', retdecl''', updated_ret, body'''') diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b6953d810f..d9fc4f82d0 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -49,6 +49,7 @@ import Control.Monad.State import Data.Bifunctor import Data.Bitraversable import Data.Char (isAscii) +import Data.Either (partitionEithers) import Data.List qualified as L import Data.List.NonEmpty qualified as NE import Data.Loc (Loc (NoLoc)) @@ -1125,17 +1126,17 @@ checkExp (Coerce e te NoInfo loc) = do pure $ Coerce e' te' (Info (toStruct st)) loc doDefault :: - S.Set VName -> + [VName] -> VName -> Either [PrimType] (TypeBase () NoUniqueness) -> TermM (TypeBase () NoUniqueness) doDefault tyvars_at_toplevel v (Left pts) | Signed Int32 `elem` pts = do - when (v `S.member` tyvars_at_toplevel) $ + when (v `elem` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to i32." pure $ Scalar $ Prim $ Signed Int32 | FloatType Float64 `elem` pts = do - when (v `S.member` tyvars_at_toplevel) $ + when (v `elem` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to f64." pure $ Scalar $ Prim $ FloatType Float64 | otherwise = @@ -1152,13 +1153,29 @@ doDefault _ _ (Right t) = pure t -- some type variables becoming known, so we have to perform -- substitutions on the RHS of the substitutions afterwards. doDefaults :: - S.Set VName -> + [VName] -> M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -> TermM (M.Map TyVar (TypeBase () NoUniqueness)) doDefaults tyvars_at_toplevel substs = do substs' <- M.traverseWithKey (doDefault tyvars_at_toplevel) substs pure $ M.map (substTyVars (`M.lookup` substs')) substs' +generalise :: + StructType -> [VName] -> Solution -> ([TypeParam], [VName]) +generalise fun_t unconstrained solution = + -- Candidates for let-generalisation are those type variables that + -- are used in fun_t. + let visible = foldMap expandTyVars $ typeVars fun_t + onTyVar v + | v `S.member` visible = Left $ TypeParamType Unlifted v mempty + | otherwise = Right v + in partitionEithers $ map onTyVar unconstrained + where + expandTyVars v = + case M.lookup v solution of + Just (Right t) -> foldMap expandTyVars $ typeVars t + _ -> S.singleton v + checkValDef :: ( VName, Maybe (TypeExp (ExpBase NoInfo VName) VName), @@ -1168,7 +1185,7 @@ checkValDef :: SrcLoc ) -> TypeM - [ ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), + [ ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1197,23 +1214,38 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - ranks <- rankAnalysis loc cts tyvars body' - - forM ranks $ \((cts', tyvars'), body'') -> do - solution <- - bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' - debugTraceM 3 $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars':", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - pure (solution, params', retdecl', body'') + mapM (onRankSolution params' retdecl') =<< rankAnalysis loc cts tyvars body' + where + onRankSolution params' retdecl' ((cts', tyvars'), body'') = do + solution <- + bitraverse pure (onTySolution params' body'') $ solve cts' tyvars' + debugTraceM 3 $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars':", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## generalised:" :) . map prettyString . fst) solution + ] + pure (solution, params', retdecl', body'') + + onTySolution params' body' (unconstrained, solution) = do + let fun_t = + foldFunType + (map patternType params') + (RetType [] $ toRes Nonunique (typeOf body')) + (generalised, unconstrained') = + generalise fun_t unconstrained solution + solution' <- doDefaults (map typeParamName generalised) solution + pure + ( generalised, + -- See #1552 for why we resolve unconstrained and + -- un-generalised type variables to (). + M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' + ) checkSingleExp :: ExpBase NoInfo VName -> From d85bd68e0716ac1a445523753e57c53a27c66aed Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 5 Mar 2024 09:19:19 -0800 Subject: [PATCH 164/258] Fix return types of partially applied functions in AM nests. --- src/Futhark/Internalise/FullNormalise.hs | 12 ++++++++---- src/Language/Futhark/Prop.hs | 10 ++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 52dc0fb0f5..9b27a6cf46 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -354,7 +354,7 @@ getOrdering final (AppExp (Match expr cs loc) resT) = do -- a complete separtion of states. transformBody :: (MonadFreshNames m) => Exp -> m Exp transformBody e = do - (e', pre_eval) <- runOrdering . getOrdering True =<< expandAMAnnotations e + (e', pre_eval) <- runOrdering $ getOrdering True e pure $ foldl f e' pre_eval where appRes = case e of @@ -368,7 +368,7 @@ transformBody e = do transformValBind :: (MonadFreshNames m) => ValBind -> m ValBind transformValBind valbind = do - body' <- transformBody $ valBindBody valbind + body' <- transformBody <=< expandAMAnnotations $ valBindBody valbind pure $ valbind {valBindBody = body'} transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind] @@ -508,10 +508,14 @@ expandAMAnnotations e = do let diets = funDiets $ typeOf f withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do inner_f <- setNewType f' $ innerFType (typeOf f') ams - let (_, ret) = unfoldFunType $ typeOf inner_f + let rettype = + case unfoldFunTypeWithRet $ typeOf inner_f of + Nothing -> error "Function type expected." + Just (ptypes, f_ret) -> + foldFunType (drop (length args') ptypes) f_ret pure $ mkApply inner_f (zip3 exts (repeat mempty) args') $ - res {appResType = snd $ unfoldFunType $ typeOf inner_f} + res {appResType = rettype} (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x y' <- expandAMAnnotations y diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 280b531286..d45cc3d294 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -53,6 +53,7 @@ module Language.Futhark.Prop arrayShape, orderZero, unfoldFunType, + unfoldFunTypeWithRet, foldFunType, typeVars, isAccType, @@ -522,6 +523,15 @@ unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) = in (second (const d) t1 : ps, r) unfoldFunType t = ([], toStruct t) +-- | Extract the parameter types and 'RetTypeBase' from a function type. +-- If the type is not an arrow type, returns 'Nothing'. +unfoldFunTypeWithRet :: TypeBase dim as -> Maybe ([TypeBase dim Diet], RetTypeBase dim Uniqueness) +unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 (RetType _ t2@(Scalar Arrow {})))) = do + (ps, r) <- unfoldFunTypeWithRet t2 + pure (second (const d) t1 : ps, r) +unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 r@RetType {})) = Just ([second (const d) t1], r) +unfoldFunTypeWithRet _ = Nothing + -- | The type scheme of a value binding, comprising the type -- parameters and the actual type. valBindTypeScheme :: ValBindBase Info VName -> ([TypeParamBase VName], StructType) From 43c94793c704e97089b2905eb5d5fa90b2c671d8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 5 Mar 2024 09:27:37 -0800 Subject: [PATCH 165/258] Only return a single solution when doing rank analysis. --- src/Language/Futhark/TypeChecker/Rank.hs | 21 ++++++++++++++++++++- src/Language/Futhark/TypeChecker/Terms.hs | 16 ++-------------- src/Language/Futhark/TypeChecker/Terms2.hs | 13 ++++++------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9672e9aa60..aa04ea1401 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -1,4 +1,8 @@ -module Language.Futhark.TypeChecker.Rank (rankAnalysis) where +module Language.Futhark.TypeChecker.Rank + ( rankAnalysis, + rankAnalysis1, + ) +where import Control.Monad.Reader import Control.Monad.State @@ -13,6 +17,7 @@ import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Futhark.Util (debugTraceM) +import Futhark.Util.Pretty import Language.Futhark hiding (ScalarType) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints @@ -211,6 +216,20 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs +rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m (([Ct], TyVars), Exp) +rankAnalysis1 loc cs tyVars body = do + solutions <- rankAnalysis loc cs tyVars body + case solutions of + [sol] -> pure sol + sols -> do + let (_, bodies') = unzip sols + typeError loc mempty $ + stack $ + [ "Rank ILP is ambiguous.", + "Choices:" + ] + ++ map pretty bodies' + rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] rankAnalysis loc cs tyVars body = do diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 22ecc9ebc0..8f039c0d9b 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1636,20 +1636,8 @@ checkFunDef :: ResRetType, Exp ) -checkFunDef (fname, retdecl, tparams, params, body, loc) = do - solutions <- - Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - case solutions of - [(maybe_tysubsts, params', retdecl', body')] -> - doChecks (maybe_tysubsts, params', retdecl', body') - ls -> do - let (_, _, _, bodies') = unzip4 ls - typeError loc mempty $ - stack $ - [ "Rank ILP is ambiguous.", - "Choices:" - ] - ++ map pretty bodies' +checkFunDef (fname, retdecl, tparams, params, body, loc) = + doChecks =<< Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) where -- TODO: Print out the possibilities. (And also potentially eliminate --- some of the possibilities to disambiguate). diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d9fc4f82d0..da553c4089 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1185,12 +1185,11 @@ checkValDef :: SrcLoc ) -> TypeM - [ ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), - [Pat ParamType], - Maybe (TypeExp Exp VName), - Exp - ) - ] + ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), + [Pat ParamType], + Maybe (TypeExp Exp VName), + Exp + ) checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do (params', body', retdecl') <- bindParams tparams params $ \params' -> do @@ -1214,7 +1213,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - mapM (onRankSolution params' retdecl') =<< rankAnalysis loc cts tyvars body' + onRankSolution params' retdecl' =<< rankAnalysis1 loc cts tyvars body' where onRankSolution params' retdecl' ((cts', tyvars'), body'') = do solution <- From 155f2bbb4f0a195685e63ccb78951bbad433910b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 14:31:27 +0100 Subject: [PATCH 166/258] Strangle some warnings. --- src/Language/Futhark/TypeChecker/Terms.hs | 28 ++++++++--------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 8f039c0d9b..26e0705ca8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -22,7 +22,7 @@ import Data.Bifunctor import Data.Bitraversable import Data.Char (isAscii) import Data.Either -import Data.List (delete, find, genericLength, partition, unzip4) +import Data.List (delete, find, genericLength, partition) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe @@ -465,7 +465,7 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp e@(AppExp (Apply fe args loc) _) = do +checkExp (AppExp (Apply fe args loc) _) = do fe' <- checkExp fe let ams = fmap (snd . unInfo . fst) args args' <- mapM (checkExp . snd) args @@ -727,7 +727,7 @@ checkExp (OpSection op (Info op_t) loc) = do checkExp (OpSectionLeft op (Info op_t) e (Info (_, _, _, am), _) _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' am + (t1, rt, argext, retext, am') <- checkApply loc (Just op, 0) ftype e' am case (ftype, rt) of (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 (RetType ds rt2))) -> pure $ @@ -735,8 +735,8 @@ checkExp (OpSectionLeft op (Info op_t) e (Info (_, _, _, am), _) _ loc) = do op (Info ftype) e' - (Info (m1, toParam d1 t1, argext, am), Info (m2, toParam d2 t2)) - (Info $ RetType ds $ arrayOfWithAliases (uniqueness rt2) (autoFrame am) rt2, Info retext) + (Info (m1, toParam d1 t1, argext, am'), Info (m2, toParam d2 t2)) + (Info $ RetType ds $ arrayOfWithAliases (uniqueness rt2) (autoFrame am') rt2, Info retext) loc _ -> typeError loc mempty $ @@ -746,7 +746,7 @@ checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do - (t2', arrow', argext, _, am) <- + (t2', arrow', argext, _, am') <- checkApply loc (Just op, 1) @@ -760,8 +760,8 @@ checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do op (Info ftype) e' - (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext, am)) - (Info $ RetType dims2' $ arrayOfWithAliases (uniqueness ret') (autoFrame am) ret') + (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext, am')) + (Info $ RetType dims2' $ arrayOfWithAliases (uniqueness ret') (autoFrame am') ret') loc _ -> error $ "OpSectionRight: impossible type\n" <> prettyString arrow' _ -> @@ -922,14 +922,6 @@ dimUses = flip execState mempty . traverseDims f where fv = freeInExp e `freeWithout` bound --- | Try to find out how many dimensions of the argument we are --- mapping. Returns the shape mapped and the remaining type. -stripToMatch :: StructType -> StructType -> (Shape Size, StructType) -stripToMatch paramt argt | toStructural paramt == toStructural argt = (mempty, argt) -stripToMatch paramt (Array _ (Shape (d : ds)) argt) = - first (Shape [d] <>) $ stripToMatch paramt $ arrayOf (Shape ds) (Scalar argt) -stripToMatch _ argt = (mempty, argt) - splitArrayAt :: Int -> StructType -> (Shape Size, StructType) splitArrayAt x t = (Shape $ take x $ shapeDims $ arrayShape t, stripArray x t) @@ -1018,14 +1010,14 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d in pure (Nothing, applySubst parsubst $ toStruct tp2') _ -> pure (Nothing, toStruct tp2') - let am = + let am' = AutoMap { autoRep = am_rep_shape, autoMap = am_map_shape, autoFrame = am_frame_shape } - pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am) + pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am') where distribute :: TypeBase dim u -> TypeBase dim u distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = From bc063036c1a51d949f29ec9d6a65063b31dd6ab1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 14:59:19 +0100 Subject: [PATCH 167/258] Fix typo. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 26e0705ca8..2c70460821 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1017,7 +1017,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d autoFrame = am_frame_shape } - pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am') + pure (tp1, distribute (arrayOf (autoMap am') tp2''), argext, ext, am') where distribute :: TypeBase dim u -> TypeBase dim u distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = From 07b507d15ee73c960adc605109084089aede8871 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 15:06:21 +0100 Subject: [PATCH 168/258] Remove unneeded things. --- futhark.cabal | 1 - src/Futhark/IR/Syntax/Core.hs | 6 ------ 2 files changed, 7 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index f4fb4062bd..b6f570ff80 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -495,7 +495,6 @@ library , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 -- remove me later - , process , glpk-hs , silently diff --git a/src/Futhark/IR/Syntax/Core.hs b/src/Futhark/IR/Syntax/Core.hs index 982fadcdec..227c25b23b 100644 --- a/src/Futhark/IR/Syntax/Core.hs +++ b/src/Futhark/IR/Syntax/Core.hs @@ -15,7 +15,6 @@ module Futhark.IR.Syntax.Core ShapeBase (..), Shape, stripDims, - takeDims, Ext (..), ExtSize, ExtShape, @@ -129,11 +128,6 @@ instance Monoid (ShapeBase d) where stripDims :: Int -> ShapeBase d -> ShapeBase d stripDims n (Shape dims) = Shape $ drop n dims --- | @takeDims n shape@ takes the outer @n@ dimensions from --- @shape@. If @shape@ has m <= n dimensions, it returns $shape$. -takeDims :: Int -> ShapeBase d -> ShapeBase d -takeDims n (Shape dims) = Shape $ take n dims - -- | The size of an array as a list of subexpressions. If a variable, -- that variable must be in scope where this array is used. type Shape = ShapeBase SubExp From a49537458d8f36ff0f6704a17f331d8f90e65219 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 15:15:45 +0100 Subject: [PATCH 169/258] Off-by-truth. --- src/Language/Futhark/TypeChecker/Consumption.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index 5ebe2996df..8c92e54d20 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -489,7 +489,7 @@ consumeAsNeeded loc pt t = checkArg :: [(Exp, TypeAliases)] -> ParamType -> AutoMap -> Exp -> CheckM (Exp, TypeAliases) checkArg prev p_t am e = do ((e', e_als), e_cons) <- - contain $ if autoRep am == mempty then noAliases e else checkExp e + contain $ if autoRep am /= mempty then noAliases e else checkExp e consumed e_cons let e_t = typeOf e' when (e_cons /= mempty && not (orderZero e_t)) $ From c11e72e88d25f0994e0ee2eafb9b1792f119c142 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 15:46:09 +0100 Subject: [PATCH 170/258] Add bindingParam. --- .../Futhark/TypeChecker/Terms/Loop.hs | 40 +++++++++---------- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 12 +++++- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index 51d1c8ceba..e1afdfdb4c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -223,16 +223,15 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do uboundexp' <- checkExp uboundexp it <- expType uboundexp' let i' = i {identType = Info it} - bindingIdent i' . bindingPat [] mergepat merge_t $ - \mergepat' -> incLevel $ do - loopbody' <- checkExp loopbody - (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' - pure - ( sparams, - mergepat'', - For i' uboundexp', - loopbody' - ) + bindingIdent i' . bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do + loopbody' <- checkExp loopbody + (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' + pure + ( sparams, + mergepat'', + For i' uboundexp', + loopbody' + ) ForIn xpat e -> do (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1 e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e @@ -241,22 +240,21 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do _ | Just t' <- peelArray 1 t -> bindingPat [] xpat t' $ \xpat' -> - bindingPat [] mergepat merge_t $ - \mergepat' -> incLevel $ do - loopbody' <- checkExp loopbody - (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' - pure - ( sparams, - mergepat'', - ForIn (fmap toStruct xpat') e', - loopbody' - ) + bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do + loopbody' <- checkExp loopbody + (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' + pure + ( sparams, + mergepat'', + ForIn (fmap toStruct xpat') e', + loopbody' + ) | otherwise -> typeError (srclocOf e) mempty $ "Iteratee of a for-in loop must be an array, but expression has type" <+> pretty t While cond -> - bindingPat [] mergepat merge_t $ \mergepat' -> + bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do cond' <- checkExp cond diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 4485c20f4c..b86dd63616 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -2,6 +2,7 @@ module Language.Futhark.TypeChecker.Terms.Pat ( binding, bindingParams, + bindingParam, bindingPat, bindingIdent, bindingSizes, @@ -15,7 +16,6 @@ import Data.List (find, isPrefixOf, sort) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Futhark.Util import Futhark.Util.Pretty hiding (group, space) import Language.Futhark import Language.Futhark.TypeChecker.Monad hiding (BoundV) @@ -204,6 +204,16 @@ checkPat sizes p t m = do [] -> m p' +-- | Check and bind a single parameter. +bindingParam :: + Pat ParamType -> + StructType -> + (Pat ParamType -> TermTypeM a) -> + TermTypeM a +bindingParam p t m = do + checkPat mempty p (Ascribed t) $ \p' -> + binding (patIdents (fmap toStruct p')) $ m p' + -- | Check and bind a @let@-pattern. bindingPat :: [SizeBinder VName] -> From 47d24deb79809c8d2d7be7e02bc199764e825a0d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 16:33:27 +0100 Subject: [PATCH 171/258] Allow touching TypeExps here. --- src/Language/Futhark/Traversals.hs | 42 +++++++++++++++++------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index 94b440b2ff..cd889944cd 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -27,6 +27,7 @@ module Language.Futhark.Traversals where import Data.Bifunctor +import Data.Bitraversable import Data.List.NonEmpty qualified as NE import Language.Futhark.Syntax @@ -337,31 +338,36 @@ instance ASTMappable (IdentBase Info VName StructType) where astMap tv (Ident name (Info t) loc) = Ident name <$> (Info <$> mapOnStructType tv t) <*> pure loc -traversePat :: (Monad m) => (t1 -> m t2) -> PatBase Info VName t1 -> m (PatBase Info VName t2) -traversePat f (Id name (Info t) loc) = +traversePat :: + (Monad m) => + (t1 -> m t2) -> + (ExpBase Info VName -> m (ExpBase Info VName)) -> + PatBase Info VName t1 -> + m (PatBase Info VName t2) +traversePat f _ (Id name (Info t) loc) = Id name <$> (Info <$> f t) <*> pure loc -traversePat f (TuplePat pats loc) = - TuplePat <$> mapM (traversePat f) pats <*> pure loc -traversePat f (RecordPat fields loc) = - RecordPat <$> mapM (traverse $ traversePat f) fields <*> pure loc -traversePat f (PatParens pat loc) = - PatParens <$> traversePat f pat <*> pure loc -traversePat f (PatAscription pat t loc) = - PatAscription <$> traversePat f pat <*> pure t <*> pure loc -traversePat f (Wildcard (Info t) loc) = +traversePat f g (TuplePat pats loc) = + TuplePat <$> mapM (traversePat f g) pats <*> pure loc +traversePat f g (RecordPat fields loc) = + RecordPat <$> mapM (traverse $ traversePat f g) fields <*> pure loc +traversePat f g (PatParens pat loc) = + PatParens <$> traversePat f g pat <*> pure loc +traversePat f g (PatAscription pat t loc) = + PatAscription <$> traversePat f g pat <*> bitraverse g pure t <*> pure loc +traversePat f _ (Wildcard (Info t) loc) = Wildcard <$> (Info <$> f t) <*> pure loc -traversePat f (PatLit v (Info t) loc) = +traversePat f _ (PatLit v (Info t) loc) = PatLit v <$> (Info <$> f t) <*> pure loc -traversePat f (PatConstr n (Info t) ps loc) = - PatConstr n <$> (Info <$> f t) <*> mapM (traversePat f) ps <*> pure loc -traversePat f (PatAttr attr p loc) = - PatAttr attr <$> traversePat f p <*> pure loc +traversePat f g (PatConstr n (Info t) ps loc) = + PatConstr n <$> (Info <$> f t) <*> mapM (traversePat f g) ps <*> pure loc +traversePat f g (PatAttr attr p loc) = + PatAttr attr <$> traversePat f g p <*> pure loc instance ASTMappable (PatBase Info VName StructType) where - astMap tv = traversePat $ mapOnStructType tv + astMap tv = traversePat (mapOnStructType tv) (mapOnExp tv) instance ASTMappable (PatBase Info VName ParamType) where - astMap tv = traversePat $ mapOnParamType tv + astMap tv = traversePat (mapOnParamType tv) (mapOnExp tv) instance ASTMappable (FieldBase Info VName) where astMap tv (RecordFieldExplicit name e loc) = From b34f3ab8e18eeb961c09acab4d325cf81fcf8cde Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 16:34:57 +0100 Subject: [PATCH 172/258] Also perform AM on expressions in params. We are certainly missing params in expressions, such as in Lambda. --- src/Language/Futhark/TypeChecker/Rank.hs | 39 ++++++++++++---------- src/Language/Futhark/TypeChecker/Terms2.hs | 6 ++-- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index aa04ea1401..9767c35902 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -123,7 +123,7 @@ addCt (CtAM r m f) = do b_m <- binVar m b_max <- VName "c_max" <$> incCounter tr <- VName ("T_" <> baseName r) <$> incCounter - addConstraints $ [bin b_max, var b_max ~<=~ var tr] + addConstraints [bin b_max, var b_max ~<=~ var tr] addConstraints $ oneIsZero (b_r, r) (b_m, m) addConstraints $ LP.max b_max (constant 0) (rank r ~-~ rank f) (var tr) addObj m @@ -216,13 +216,13 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m (([Ct], TyVars), Exp) -rankAnalysis1 loc cs tyVars body = do - solutions <- rankAnalysis loc cs tyVars body +rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), [Pat ParamType], Exp) +rankAnalysis1 loc cs tyVars params body = do + solutions <- rankAnalysis loc cs tyVars params body case solutions of [sol] -> pure sol sols -> do - let (_, bodies') = unzip sols + let (_, _, bodies') = unzip3 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -230,13 +230,14 @@ rankAnalysis1 loc cs tyVars body = do ] ++ map pretty bodies' -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] -rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] -rankAnalysis loc cs tyVars body = do +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), [Pat ParamType], Exp)] +rankAnalysis _ [] tyVars params body = pure [(([], tyVars), params, body)] +rankAnalysis loc cs tyVars params body = do rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps let bodys = map (`updAM` body) rank_maps - pure $ zip cts_tyvars' bodys + params' = map ((`map` params) . updAMPat) rank_maps + pure $ zip3 cts_tyvars' params' bodys where cs' = foldMap (splitFuncs . distribute) cs splitFuncs @@ -325,7 +326,7 @@ rankToShape x = do addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do rs <- asks envRanks - if (fromMaybe 0 (rs M.!? t) == 0) + if fromMaybe 0 (rs M.!? t) == 0 then do old_tyvars <- asks envTyVars case old_tyvars M.!? t of @@ -334,7 +335,7 @@ addRankInfo t = do -- -- is anyPrimType right here? -- modify $ -- \s -> s {substTyVars = M.insert t (lvl, TyVarPrim anyPrimType) $ substTyVars s} - _ -> do + _ -> pure () else do new_vars <- gets substNewVars @@ -381,13 +382,7 @@ updAM rank_map e = case e of AppExp (Apply f args loc) res -> let f' = updAM rank_map f - args' = - fmap - ( bimap - (fmap $ bimap id upd) - (updAM rank_map) - ) - args + args' = fmap (bimap (fmap $ second upd) (updAM rank_map)) args in AppExp (Apply f' args' loc) res AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res @@ -403,3 +398,11 @@ updAM rank_map e = identityMapper { mapOnExp = pure . updAM rank_map } + +updAMPat :: M.Map VName Int -> Pat ParamType -> Pat ParamType +updAMPat rank_map p = runIdentity $ astMap m p + where + m = + identityMapper + { mapOnExp = pure . updAM rank_map + } diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index da553c4089..376576a13b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1213,9 +1213,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - onRankSolution params' retdecl' =<< rankAnalysis1 loc cts tyvars body' + onRankSolution retdecl' =<< rankAnalysis1 loc cts tyvars params' body' where - onRankSolution params' retdecl' ((cts', tyvars'), body'') = do + onRankSolution retdecl' ((cts', tyvars'), params', body'') = do solution <- bitraverse pure (onTySolution params' body'') $ solve cts' tyvars' debugTraceM 3 $ @@ -1267,7 +1267,7 @@ checkSizeExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars - (cts_tyvars', es') <- unzip <$> rankAnalysis (srclocOf e) cts tyvars e' + (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars [] e' solutions <- forM cts_tyvars' $ From 6305172d9063a9bbde839cd279ecefc57b89b4ad Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 6 Mar 2024 08:28:59 -0800 Subject: [PATCH 173/258] Remove lingering `PuLP` stuff. --- shell.nix | 3 --- src/Futhark/Solve/LP.hs | 35 ----------------------------------- 2 files changed, 38 deletions(-) diff --git a/shell.nix b/shell.nix index d5199b0c02..a5ddb63ab0 100644 --- a/shell.nix +++ b/shell.nix @@ -52,9 +52,6 @@ pkgs.stdenv.mkDerivation { python3Packages.sphinx python3Packages.sphinxcontrib-bibtex imagemagick # needed for literate tests - # remove (needed for PuLP) - python - cbc glpk ] ++ lib.optionals (stdenv.isLinux) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index a2224617ea..f1b7d18939 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -28,7 +28,6 @@ module Futhark.Solve.LP (~<=~), (~>=~), rowEchelonLPE, - linearProgToPulp, ) where @@ -168,40 +167,6 @@ instance (Ord v) => Vars (LinearProg v a) v where vars (objective lp) <> foldMap vars (constraints lp) --- For debugging -linearProgToPulp :: (Unbox a, IsName v, Ord v, Pretty a, Eq a, Num a) => LinearProg v a -> String -linearProgToPulp prog = - map rm_subscript $ - unlines - [ "from pulp import *", - "prob = LpProblem('', " <> lptype <> ")", - unlines vars, - unlines $ map (("prob += " <>) . prettyString) $ constraints prog, - "status = prob.solve()", - "print(f'status: {status}')", - unlines res - ] - where - lptype = - case optType prog of - Maximize -> "LpMaximize" - Minimize -> "LpMinimize" - prog_vars = Map.elems $ snd $ linearProgToLP prog - vars = - map - ( \v -> - show (prettyName v) - <> " = " - <> "LpVariable(" - <> "'" - <> show (prettyName v) - <> "_'" - <> ", lowBound = 0, cat = 'Integer')" - ) - prog_vars - res = map (\v -> "print(f'" <> show (prettyName v) <> ": {value(" <> show (prettyName v) <> ")}')") prog_vars - rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" - bigM :: (Num a) => a bigM = 2 ^ 10 From c8a4348d5448fa6e19e77044b086477a31b8097a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 12:49:10 +0100 Subject: [PATCH 174/258] Start adding location info. --- .../Futhark/TypeChecker/Constraints.hs | 88 +++++++++++-------- src/Language/Futhark/TypeChecker/Rank.hs | 10 +-- src/Language/Futhark/TypeChecker/Terms.hs | 6 +- src/Language/Futhark/TypeChecker/Terms2.hs | 23 ++--- 4 files changed, 71 insertions(+), 56 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ebac367dfb..e78bf993b2 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,5 +1,6 @@ module Language.Futhark.TypeChecker.Constraints - ( SVar, + ( Reason (..), + SVar, SComp (..), Type, toType, @@ -16,14 +17,22 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor +import Data.Loc import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S -import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Monad (TypeError (..)) import Language.Futhark.TypeChecker.Types (substTyVars) +-- | The reason for a type constraint. Used to generate type error +-- messages. +newtype Reason = Reason + { reasonLoc :: Loc + } + deriving (Eq, Ord, Show) + type SVar = VName -- | A shape component. `SDim` is a single dimension of unspecified @@ -62,23 +71,31 @@ instance Pretty Ct where type Constraints = [Ct] --- | Information about a type variable. +-- | Information about a type variable. Every type variable is +-- associated with a location, which is the original syntax element +-- that it is the type of. data TyVarInfo = -- | Can be substituted with anything. - TyVarFree + TyVarFree Loc | -- | Can only be substituted with these primitive types. - TyVarPrim [PrimType] + TyVarPrim Loc [PrimType] | -- | Must be a record with these fields. - TyVarRecord (M.Map Name Type) + TyVarRecord Loc (M.Map Name Type) | -- | Must be a sum type with these fields. - TyVarSum (M.Map Name [Type]) + TyVarSum Loc (M.Map Name [Type]) deriving (Show, Eq) instance Pretty TyVarInfo where - pretty TyVarFree = "free" - pretty (TyVarPrim pts) = "∈" <+> pretty pts - pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs - pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs + pretty (TyVarFree _) = "free" + pretty (TyVarPrim _ pts) = "∈" <+> pretty pts + pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs + pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs + +instance Located TyVarInfo where + locOf (TyVarFree loc) = loc + locOf (TyVarPrim loc _) = loc + locOf (TyVarRecord loc _) = loc + locOf (TyVarSum loc _) = loc type TyVar = VName @@ -134,20 +151,20 @@ solution s = mkSubst (TyVarLink v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (TyVarUnsol _ (TyVarPrim pts)) = Just $ Left pts + mkSubst (TyVarUnsol _ (TyVarPrim _ pts)) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, TyVarUnsol _ TyVarFree) = Just v + unconstrained (v, TyVarUnsol _ (TyVarFree _)) = Just v unconstrained _ = Nothing -newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} - deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) +newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} + deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) occursCheck :: VName -> Type -> SolveM () occursCheck v tp = do vars <- gets solverTyVars let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . throwError . docText $ + when (v `S.member` typeVars tp') . throwError . TypeError mempty mempty $ "Occurs check: cannot instantiate" <+> prettyName v <+> "with" @@ -166,7 +183,7 @@ linkTyVar v t = do modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} tyvars' <- case (M.lookup v tyvars, M.lookup t tyvars) of - (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree)) -> + (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl (TyVarFree _))) -> pure $ M.insert t (TyVarUnsol lvl info) tyvars -- TODO: handle more cases. _ -> pure tyvars @@ -211,7 +228,7 @@ solveCt ct = CtEq t1 t2 -> solveCt' (t1, t2) CtAM {} -> pure () -- Good vibes only. where - bad = throwError $ "Unsolvable: " <> prettyText ct + bad = throwError $ TypeError mempty mempty $ "Unsolvable:" <+> pretty ct solveCt' (t1, t2) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of @@ -249,39 +266,36 @@ solveCt ct = Just eqs -> mapM_ solveCt' eqs solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () -solveTyVar (tv, (_, TyVarFree {})) = pure () -solveTyVar (tv, (_, TyVarPrim pts)) = do +solveTyVar (_, (_, TyVarFree {})) = pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do t <- lookupTyVar tv case t of Nothing -> pure () Just t' | t' `elem` map (Scalar . Prim) pts -> pure () | otherwise -> - throwError $ - "Type variable " - <> prettyNameText tv - <> " must be one of\n" - <> prettyText pts - <> "\nbut inferred to be\n" - <> prettyText t' -solveTyVar (tv, (_, TyVarRecord fs1)) = do + throwError . TypeError loc mempty $ + "Type must be one of" + indent 2 (pretty pts) + "but inferred to be" + indent 2 (pretty t') +solveTyVar (tv, (_, TyVarRecord loc fs1)) = do tv_t <- lookupTyVar tv case tv_t of Nothing -> pure () Just (Scalar (Record fs2)) | all (`M.member` fs2) (M.keys fs1) -> - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(k, (t1, t2)) -> + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> solveCt $ CtEq t1 t2 Just tv_t' -> throwError $ - "Type variable " - <> prettyNameText tv - <> " must be record with fields\n" - <> prettyText (Scalar (Record fs1)) - <> " but inferred to be\n" - <> prettyText tv_t' - -solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) + TypeError loc mempty $ + "Type must be record with fields" + indent 2 (pretty (Scalar (Record fs1))) + "but inferred to be" + indent 2 (pretty tv_t') + +solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) solve constraints tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9767c35902..e42ea67d52 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -130,12 +130,12 @@ addCt (CtAM r m f) = do addObj tr addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () -addTyVarInfo _ (_, TyVarFree) = pure () -addTyVarInfo tv (_, TyVarPrim _) = +addTyVarInfo _ (_, TyVarFree _) = pure () +addTyVarInfo tv (_, TyVarPrim {}) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarRecord _) = +addTyVarInfo tv (_, TyVarRecord {}) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarSum _) = +addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: [Ct] -> TyVars -> LinearProg @@ -346,7 +346,7 @@ addRankInfo t = do old_tyvars <- asks envTyVars let info = fromJust $ old_tyvars M.!? t modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree mempty) $ substTyVars s} class SubstRanks a where substRanks :: (MonadTypeChecker m) => a -> SubstT m a diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2c70460821..c54c0b5805 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1039,7 +1039,7 @@ checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) checkOneExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of - Left err -> typeError e' mempty $ pretty err + Left err -> throwError err Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' @@ -1057,7 +1057,7 @@ checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of - Left err -> typeError e' mempty $ pretty err + Left err -> throwError err Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' when (hasBinding e'') $ @@ -1636,7 +1636,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = doChecks (maybe_tysubsts, params', retdecl', body') = case maybe_tysubsts of - Left err -> typeError loc mempty $ pretty err + Left err -> throwError err Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 376576a13b..d73c6758d9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -220,8 +220,8 @@ incCounter = do tyVarType :: u -> TyVar -> TypeBase dim u tyVarType u v = Scalar $ TypeVar u (qualName v) [] -newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar -newTyVarWith _loc desc info = do +newTyVarWith :: Name -> TyVarInfo -> TermM TyVar +newTyVarWith desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i lvl <- curLevel @@ -229,24 +229,25 @@ newTyVarWith _loc desc info = do pure v newTyVar :: (Located loc) => loc -> Name -> TermM TyVar -newTyVar loc desc = newTyVarWith loc desc TyVarFree +newTyVar loc desc = newTyVarWith desc $ TyVarFree $ locOf loc newType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) newType loc desc u = tyVarType u <$> newTyVar loc desc newTypeWithField :: SrcLoc -> Name -> Name -> Type -> TermM Type newTypeWithField loc desc k t = - tyVarType NoUniqueness <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) + tyVarType NoUniqueness + <$> newTyVarWith desc (TyVarRecord (locOf loc) $ M.singleton k t) newTypeWithConstr :: SrcLoc -> Name -> u -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) newTypeWithConstr loc desc u k ts = - tyVarType u <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') + tyVarType u <$> newTyVarWith desc (TyVarSum (locOf loc) $ M.singleton k ts') where ts' = map (`setUniqueness` NoUniqueness) ts newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniqueness) newTypeOverloaded loc name pts = - tyVarType NoUniqueness <$> newTyVarWith loc name (TyVarPrim pts) + tyVarType NoUniqueness <$> newTyVarWith name (TyVarPrim (locOf loc) pts) newSVar :: (Located loc) => loc -> Name -> TermM SVar newSVar _loc desc = do @@ -1185,7 +1186,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), + ( Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1226,7 +1227,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, + in either (docString . prettyTypeError) (unlines . map p . M.toList . snd) solution, either (const mempty) (unlines . ("## generalised:" :) . map prettyString . fst) solution ] pure (solution, params', retdecl', body'') @@ -1248,7 +1249,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do checkSingleExp :: ExpBase NoInfo VName -> - TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints @@ -1261,7 +1262,7 @@ checkSingleExp e = runTermM $ do -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> - TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints @@ -1276,4 +1277,4 @@ checkSizeExp e = runTermM $ do case (solutions, es') of ([solution], [e'']) -> pure (solution, e'') - _ -> pure (Left "Ambiguous size expression", e') + _ -> pure (Left $ TypeError (locOf e) mempty "Ambiguous size expression", e') From 56c99a89c8d32ccec5dc229519ed58b9444354ac Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 7 Mar 2024 05:23:31 -0800 Subject: [PATCH 175/258] More location info. --- src/Language/Futhark/TypeChecker/Rank.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index e42ea67d52..85e1766e13 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -344,9 +344,9 @@ addRankInfo t = do new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars - let info = fromJust $ old_tyvars M.!? t - modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree mempty) $ substTyVars s} + let (level, tvinfo) = fromJust $ old_tyvars M.!? t + modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree $ locOf tvinfo) $ substTyVars s} class SubstRanks a where substRanks :: (MonadTypeChecker m) => a -> SubstT m a From 05fe9a0a5673631e672ddc8d380b5aaa51159910 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 14:30:56 +0100 Subject: [PATCH 176/258] Also put locations in constraints. --- .../Futhark/TypeChecker/Constraints.hs | 77 +++++++++++------ src/Language/Futhark/TypeChecker/Rank.hs | 12 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 86 +++++++++---------- 3 files changed, 99 insertions(+), 76 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index e78bf993b2..eddc37691c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -33,6 +33,9 @@ newtype Reason = Reason } deriving (Eq, Ord, Show) +instance Located Reason where + locOf = reasonLoc + type SVar = VName -- | A shape component. `SDim` is a single dimension of unspecified @@ -61,13 +64,20 @@ toType :: TypeBase Size u -> TypeBase SComp u toType = first (const SDim) data Ct - = CtEq Type Type - | CtAM SVar SVar (Shape SComp) + = CtEq Reason Type Type + | CtAM Reason SVar SVar (Shape SComp) deriving (Show) +ctReason :: Ct -> Reason +ctReason (CtEq r _ _) = r +ctReason (CtAM r _ _ _) = r + +instance Located Ct where + locOf = locOf . ctReason + instance Pretty Ct where - pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtAM r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" + pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2 + pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" type Constraints = [Ct] @@ -160,25 +170,25 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) -occursCheck :: VName -> Type -> SolveM () -occursCheck v tp = do +occursCheck :: Reason -> VName -> Type -> SolveM () +occursCheck reason v tp = do vars <- gets solverTyVars let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . throwError . TypeError mempty mempty $ + when (v `S.member` typeVars tp') . throwError . TypeError (locOf reason) mempty $ "Occurs check: cannot instantiate" <+> prettyName v <+> "with" <+> pretty tp <> "." -subTyVar :: VName -> Int -> Type -> SolveM () -subTyVar v lvl t = do - occursCheck v t +subTyVar :: Reason -> VName -> Int -> Type -> SolveM () +subTyVar reason v lvl t = do + occursCheck reason v t modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} -linkTyVar :: VName -> VName -> SolveM () -linkTyVar v t = do - occursCheck v $ Scalar $ TypeVar NoUniqueness (qualName t) [] +linkTyVar :: Reason -> VName -> VName -> SolveM () +linkTyVar reason v t = do + occursCheck reason v $ Scalar $ TypeVar NoUniqueness (qualName t) [] tyvars <- gets solverTyVars modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} tyvars' <- @@ -222,13 +232,18 @@ unify t1 t2 Just [(t1', t2')] unify _ _ = Nothing -solveCt :: Ct -> SolveM () -solveCt ct = - case ct of - CtEq t1 t2 -> solveCt' (t1, t2) - CtAM {} -> pure () -- Good vibes only. +solveEq :: Reason -> Type -> Type -> SolveM () +solveEq reason orig_t1 orig_t2 = do + solveCt' (orig_t1, orig_t2) where - bad = throwError $ TypeError mempty mempty $ "Unsolvable:" <+> pretty ct + cannotUnify = do + tyvars <- gets solverTyVars + throwError . TypeError (locOf reason) mempty $ + "Cannot unify" + indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) + "with" + indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) + solveCt' (t1, t2) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of @@ -249,22 +264,28 @@ solveCt ct = | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (Nothing, Nothing) -> bad - (Just lvl, Nothing) -> subTyVar v1 lvl t2' - (Nothing, Just lvl) -> subTyVar v2 lvl t1' + (Nothing, Nothing) -> cannotUnify + (Just lvl, Nothing) -> subTyVar reason v1 lvl t2' + (Nothing, Just lvl) -> subTyVar reason v2 lvl t1' (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> linkTyVar v1 v2 - | otherwise -> linkTyVar v2 v1 + | lvl1 <= lvl2 -> linkTyVar reason v1 v2 + | otherwise -> linkTyVar reason v2 v1 (Scalar (TypeVar _ (QualName [] v1) []), t2') | Just lvl <- flexible v1 -> - subTyVar v1 lvl t2' + subTyVar reason v1 lvl t2' (t1', Scalar (TypeVar _ (QualName [] v2) [])) | Just lvl <- flexible v2 -> - subTyVar v2 lvl t1' + subTyVar reason v2 lvl t1' (t1', t2') -> case unify t1' t2' of - Nothing -> bad + Nothing -> cannotUnify Just eqs -> mapM_ solveCt' eqs +solveCt :: Ct -> SolveM () +solveCt ct = + case ct of + CtEq reason t1 t2 -> solveEq reason t1 t2 + CtAM {} -> pure () -- Good vibes only. + solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () solveTyVar (_, (_, TyVarFree {})) = pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do @@ -286,7 +307,7 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do Just (Scalar (Record fs2)) | all (`M.member` fs2) (M.keys fs1) -> forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> - solveCt $ CtEq t1 t2 + solveCt $ CtEq (Reason loc) t1 t2 Just tv_t' -> throwError $ TypeError loc mempty $ diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 85e1766e13..9af5623d5e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -72,7 +72,7 @@ instance Distribute (TypeBase dim u) where distribute t = t instance Distribute Ct where - distribute (CtEq t1 t2) = distribute t1 `CtEq` distribute t2 + distribute (CtEq r t1 t2) = CtEq r (distribute t1) (distribute t2) distribute c = c data RankState = RankState @@ -117,8 +117,8 @@ addObj sv = modify $ \s -> s {rankObj = rankObj s ~+~ var sv} addCt :: Ct -> RankM () -addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 -addCt (CtAM r m f) = do +addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2 +addCt (CtAM _ r m f) = do b_r <- binVar r b_m <- binVar m b_max <- VName "c_max" <$> incCounter @@ -242,10 +242,11 @@ rankAnalysis loc cs tyVars params body = do cs' = foldMap (splitFuncs . distribute) cs splitFuncs ( CtEq + reason (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) ) = - splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq t1r' t2r') + splitFuncs (CtEq reason t1a t2a) ++ splitFuncs (CtEq reason t1r' t2r') where t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness @@ -312,6 +313,7 @@ newTyVar t = do substNewCts = substNewCts s ++ [ CtEq + (Reason mempty) -- FIXME (Scalar (TypeVar mempty (QualName [] t) [])) (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) ] @@ -374,7 +376,7 @@ instance SubstRanks (TypeBase SComp u) where substRanks t = pure t instance SubstRanks Ct where - substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 + substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" updAM :: Map VName Int -> Exp -> Exp diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d73c6758d9..6b1e06fc54 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -267,27 +267,27 @@ asStructType loc (Scalar (Sum cs)) = Scalar . Sum <$> traverse (mapM (asStructType loc)) cs asStructType loc t@(Scalar (TypeVar u _ _)) = do t' <- newType loc "artificial" u - ctEq (toType t') t + ctEq (Reason (locOf loc)) (toType t') t pure t' asStructType loc t@(Array u _ _) = do t' <- newType loc "artificial" u - ctEq (toType t') t + ctEq (Reason (locOf loc)) (toType t') t pure t' addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} -ctEq :: TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () -ctEq t1 t2 = +ctEq :: Reason -> TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () +ctEq reason t1 t2 = -- As a minor optimisation, do not add constraint if the types are -- equal. - unless (t1' == t2') $ addCt $ CtEq t1' t2' + unless (t1' == t2') $ addCt $ CtEq reason t1' t2' where t1' = t1 `setUniqueness` NoUniqueness t2' = t2 `setUniqueness` NoUniqueness -ctAM :: SVar -> SVar -> Shape SComp -> TermM () -ctAM r m f = addCt $ CtAM r m f +ctAM :: Reason -> SVar -> SVar -> Shape SComp -> TermM () +ctAM reason r m f = addCt $ CtAM reason r m f localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -361,11 +361,11 @@ arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why [pt] e = do - ctEq (Scalar $ Prim pt) (expType e) + ctEq (Reason (locOf e)) (Scalar $ Prim pt) (expType e) pure e require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts - ctEq t $ expType e + ctEq (Reason (locOf e)) t $ expType e pure e -- | Instantiate a type scheme with fresh type variables for its type @@ -448,7 +448,7 @@ patLitMkType (PatLitPrim v) _ = checkSizeExp' :: ExpBase NoInfo VName -> TermM Exp checkSizeExp' e = do e' <- checkExp e - ctEq (expType e') (Scalar (Prim (Signed Int64))) + ctEq (Reason (locOf e)) (expType e') (Scalar (Prim (Signed Int64))) pure e' checkPat' :: @@ -477,7 +477,7 @@ checkPat' (TuplePat ps loc) (Ascribed t) <*> pure loc | otherwise = do ps_t <- replicateM (length ps) (newType loc "t" Observe) - ctEq (toType (Scalar (tupleRecord ps_t))) (toType t) + ctEq (Reason (locOf loc)) (toType (Scalar (tupleRecord ps_t))) (toType t) TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc @@ -487,7 +487,7 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs - ctEq (Scalar (Record p_fs')) $ toType t + ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) $ toType t st <- asStructType loc $ Scalar (Record p_fs') checkPat' p $ Ascribed $ toParam Observe st where @@ -507,7 +507,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do case maybe_outer_t of Ascribed outer_t -> do - ctEq (toType st') (toType outer_t) + ctEq (Reason (locOf loc)) (toType st') (toType outer_t) PatAscription <$> checkPat' p (Ascribed st') <*> pure t' @@ -519,7 +519,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do t' <- patLitMkType l loc - ctEq (toType t') (toType t) + ctEq (Reason (locOf loc)) (toType t') (toType t) pure $ PatLit l (Info t') loc checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc @@ -542,7 +542,7 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do p_t <- newType (srclocOf p) "t" Observe checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' - ctEq t' (toType t) + ctEq (Reason (locOf loc)) t' (toType t) t'' <- asStructType loc t' pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do @@ -640,8 +640,8 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m)) a - ctAM r m $ fmap toSComp (toShape m_var <> fframe) - ctEq lhs rhs + ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) + ctEq (Reason (locOf loc)) lhs rhs debugTraceM 3 $ unlines [ "## checkApplyOne", @@ -688,7 +688,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do split ftype' = do a <- newType loc "arg" NoUniqueness b <- newType loc "res" Nonunique - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] @@ -708,17 +708,17 @@ isSlice DimFix {} = False -- Add constraints saying that the first type has a (potentially -- nested) field containing the second type. mustHaveFields :: SrcLoc -> Type -> [Name] -> Type -> TermM () -mustHaveFields _ t [] ve_t = +mustHaveFields loc t [] ve_t = -- This case is probably never reached. - ctEq t ve_t + ctEq (Reason (locOf loc)) t ve_t mustHaveFields loc t [f] ve_t = do rt :: Type <- newTypeWithField loc "ft" f ve_t - ctEq t rt + ctEq (Reason (locOf loc)) t rt mustHaveFields loc t (f : fs) ve_t = do ft <- newType loc "ft" NoUniqueness rt <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t - ctEq t rt + ctEq (Reason (locOf loc)) t rt checkCase :: StructType -> @@ -741,7 +741,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq (toType c_t) (toType cs_t) + ctEq (Reason (locOf c)) (toType c_t) (toType cs_t) pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -779,7 +779,7 @@ checkRetDecl :: checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (expType body) (toType st) + ctEq (Reason (locOf body)) (expType body) (toType st) pure $ Just te' checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) @@ -830,7 +830,7 @@ checkExp (ArrayLit es _ loc) = do et <- newType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e - ctEq (expType e') (toType et) + ctEq (Reason (locOf loc)) (expType e') (toType et) pure e' let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et pure $ ArrayLit es' (Info arr_t) loc @@ -1003,19 +1003,19 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do start' <- require "use in range expression" anyIntType =<< checkExp start let check e = do e' <- checkExp e - ctEq (expType start') (expType e') + ctEq (Reason (locOf e')) (expType start') (expType e') pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end range_t <- newType loc "range" NoUniqueness - ctEq (toType range_t) (arrayOfRank 1 (expType start')) + ctEq (Reason (locOf start')) (toType range_t) (arrayOfRank 1 (expType start')) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e kt <- newType loc "kt" NoUniqueness t <- newTypeWithField loc "t" k kt - ctEq (expType e') t + ctEq (Reason (locOf e')) (expType e') t kt' <- asStructType loc kt pure $ Project k e' (Info kt') loc -- @@ -1031,8 +1031,8 @@ checkExp (IndexSection slice NoInfo loc) = do index_elem_t <- newType loc "index_elem" NoUniqueness index_res_t <- newType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (toType index_arg_t) $ arrayOfRank num_slices index_elem_t - ctEq index_res_t $ arrayOfRank (length slice) index_elem_t + ctEq (Reason (locOf loc)) (toType index_arg_t) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t index_res_t' <- asStructType loc index_res_t let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' pure $ IndexSection slice' (Info ft) loc @@ -1043,8 +1043,8 @@ checkExp (AppExp (Index e slice loc) _) = do index_t <- newType loc "index" NoUniqueness index_elem_t <- newType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (toType index_t) $ arrayOfRank num_slices index_elem_t - ctEq (expType e') $ arrayOfRank (length slice) index_elem_t + ctEq (Reason (locOf loc)) (toType index_t) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf e')) (expType e') $ arrayOfRank (length slice) index_elem_t pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) -- checkExp (Update src slice ve loc) = do @@ -1053,8 +1053,8 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (expType src') $ arrayOfRank (length slice) update_elem_t - ctEq (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf src')) (expType src') $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do @@ -1065,8 +1065,8 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (toType src_t) $ arrayOfRank (length slice) update_elem_t - ctEq (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf loc)) (toType src_t) $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) @@ -1076,8 +1076,8 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do e2' <- checkExp e2 e3' <- checkExp e3 - ctEq (expType e1') (Scalar (Prim Bool)) - ctEq (expType e2') (expType e3') + ctEq (Reason (locOf e1')) (expType e1') (Scalar (Prim Bool)) + ctEq (Reason (locOf loc)) (expType e2') (expType e3') pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) -- @@ -1096,17 +1096,17 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do let i' = Ident i (Info (typeOf bound')) iloc bind [i'] $ do body' <- checkExp body - ctEq (expType arg') (expType body') + ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (For i' bound', body') While cond -> do cond' <- checkExp cond body' <- checkExp body - ctEq (expType arg') (expType body') + ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" NoUniqueness - ctEq (expType arr') $ arrayOfRank 1 (toType elem_t) + ctEq (Reason (locOf arr')) (expType arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') @@ -1118,12 +1118,12 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do checkExp (Ascript e te loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (expType e') (toType st) + ctEq (Reason (locOf e')) (expType e') (toType st) pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (expType e') (toType st) + ctEq (Reason (locOf e')) (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc doDefault :: From ef20fbe82a1f8df507874b913ecfc17fa7091b43 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 7 Mar 2024 05:38:00 -0800 Subject: [PATCH 177/258] Easy fix. --- src/Language/Futhark/TypeChecker/Rank.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9af5623d5e..7443c38a47 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -47,7 +47,7 @@ instance Rank (Shape SComp) where instance Rank ScalarType where rank Prim {} = constant 0 - rank (TypeVar _ (QualName [] v) []) = var v -- FIXME - might not be a type variable. + rank (TypeVar _ (QualName [] v) []) = var v rank (TypeVar {}) = constant 0 rank (Arrow {}) = constant 0 rank (Record {}) = constant 0 From 37cc601699cde242d57a42297c01efbea8b253f9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 7 Mar 2024 05:41:20 -0800 Subject: [PATCH 178/258] More FIXME extermination. --- src/Language/Futhark/TypeChecker/Rank.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 7443c38a47..5f3775a748 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -307,13 +307,14 @@ newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar newTyVar t = do t' <- lift $ newTypeName (baseName t) shape <- rankToShape t + loc <- (locOf . snd . fromJust . (M.!? t)) <$> asks envTyVars modify $ \s -> s { substNewVars = M.insert t t' $ substNewVars s, substNewCts = substNewCts s ++ [ CtEq - (Reason mempty) -- FIXME + (Reason loc) (Scalar (TypeVar mempty (QualName [] t) [])) (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) ] From 75e5be4b5e96ccbe8f6df54d8b622f7510e8475a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 15:28:24 +0100 Subject: [PATCH 179/258] Proper AUTOMAP for single expressions. --- src/Language/Futhark/TypeChecker/Terms2.hs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6b1e06fc54..5252a93de9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1249,14 +1249,27 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do checkSingleExp :: ExpBase NoInfo VName -> - TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - solution <- - bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars - pure (solution, e') + ((cts', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars [] e' + case solve cts' tyvars' of + Left err -> pure (Left err, e'') + Right (unconstrained, solution) -> do + let (generalised, unconstrained') = + generalise (typeOf e'') unconstrained solution + solution' <- doDefaults (map typeParamName generalised) solution + pure + ( Right + ( generalised, + -- See #1552 for why we resolve unconstrained and + -- un-generalised type variables to (). + M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' + ), + e'' + ) -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. From 941dec2237f5c3decb344a5cbff88f695b2916d4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 15:40:39 +0100 Subject: [PATCH 180/258] Has to be written like this. --- tests/automap/mri-q.fut | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index 3a4648c7b9..f53b5df7a6 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -18,7 +18,7 @@ def main_orig [numK][numX] let qr = map1 (map f32.cos >-> map2 (*) phiMag >-> f32.sum) expArgs let qi = map1 (map f32.sin >-> map2 (*) phiMag >-> f32.sum) expArgs in (qr, qi) - + def main_am [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) @@ -31,11 +31,11 @@ def main_am [numK][numX] let qr = f32.sum (f32.cos expArgs * phiMag) let qi = f32.sum (f32.sin expArgs * phiMag) in (qr, qi) - + entry main [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) (phiR: [numK]f32) (phiI: [numK]f32) = let (qr, qi) = main_orig kx ky kz x y z phiR phiI let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI - in and (map2 (==) qr qr_am && qi == qi_am) + in and (map2 (==) qr qr_am && map2 (==) qi qi_am) From 82c8458dc7dda7dbf2f6a4ca171ebcea37ce06d1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 15:40:47 +0100 Subject: [PATCH 181/258] Reduce duplication. --- src/Language/Futhark/TypeChecker/Terms2.hs | 40 ++++++++++------------ 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 5252a93de9..4226fef858 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1177,6 +1177,22 @@ generalise fun_t unconstrained solution = Just (Right t) -> foldMap expandTyVars $ typeVars t _ -> S.singleton v +generaliseAndDefaults :: + [VName] -> + Solution -> + StructType -> + TermM ([TypeParam], M.Map VName (TypeBase () NoUniqueness)) +generaliseAndDefaults unconstrained solution t = do + let (generalised, unconstrained') = + generalise t unconstrained solution + solution' <- doDefaults (map typeParamName generalised) solution + pure + ( generalised, + -- See #1552 for why we resolve unconstrained and + -- un-generalised type variables to (). + M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' + ) + checkValDef :: ( VName, Maybe (TypeExp (ExpBase NoInfo VName) VName), @@ -1237,15 +1253,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do foldFunType (map patternType params') (RetType [] $ toRes Nonunique (typeOf body')) - (generalised, unconstrained') = - generalise fun_t unconstrained solution - solution' <- doDefaults (map typeParamName generalised) solution - pure - ( generalised, - -- See #1552 for why we resolve unconstrained and - -- un-generalised type variables to (). - M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' - ) + generaliseAndDefaults unconstrained solution fun_t checkSingleExp :: ExpBase NoInfo VName -> @@ -1258,18 +1266,8 @@ checkSingleExp e = runTermM $ do case solve cts' tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do - let (generalised, unconstrained') = - generalise (typeOf e'') unconstrained solution - solution' <- doDefaults (map typeParamName generalised) solution - pure - ( Right - ( generalised, - -- See #1552 for why we resolve unconstrained and - -- un-generalised type variables to (). - M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' - ), - e'' - ) + x <- generaliseAndDefaults unconstrained solution $ typeOf e'' + pure (Right x, e'') -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. From f0bef225284a8e07efcfb55fc81e8434084b9a36 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 11:47:22 +0100 Subject: [PATCH 182/258] Work on sum types. --- .../Futhark/TypeChecker/Constraints.hs | 33 ++++++++++++++----- .../Futhark/TypeChecker/Terms/Monad.hs | 1 + src/Language/Futhark/TypeChecker/Terms/Pat.hs | 12 ++++++- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index eddc37691c..71765d7e0d 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -289,8 +289,8 @@ solveCt ct = solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () solveTyVar (_, (_, TyVarFree {})) = pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do - t <- lookupTyVar tv - case t of + tv_t <- lookupTyVar tv + case tv_t of Nothing -> pure () Just t' | t' `elem` map (Scalar . Prim) pts -> pure () @@ -309,12 +309,29 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> solveCt $ CtEq (Reason loc) t1 t2 Just tv_t' -> - throwError $ - TypeError loc mempty $ - "Type must be record with fields" - indent 2 (pretty (Scalar (Record fs1))) - "but inferred to be" - indent 2 (pretty tv_t') + throwError . TypeError loc mempty $ + "Type must be record with fields" + indent 2 (pretty (Scalar (Record fs1))) + "but inferred to be" + indent 2 (pretty tv_t') +solveTyVar (tv, (_, TyVarSum loc cs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Nothing -> pure () + Just (Scalar (Sum cs2)) + | all (`M.member` cs2) (M.keys cs1), + cs3 <- M.toList $ M.intersectionWith (,) cs1 cs2, + all (sameLength . snd) cs3 -> + forM_ cs3 $ \(_k, (t1s, t2s)) -> + mapM_ solveCt $ zipWith (CtEq (Reason loc)) t1s t2s + Just tv_t' -> + throwError . TypeError loc mempty $ + "Type must be sum type with constructors" + indent 2 (pretty (Scalar (Sum cs1))) + "but inferred to be" + indent 2 (pretty tv_t') + where + sameLength (x, y) = length x == length y solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) solve constraints tyvars = diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 018a3e920a..42c5d53ab2 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -87,6 +87,7 @@ unusedSize p = data Inferred t = NoneInferred | Ascribed t + deriving (Show) instance Functor Inferred where fmap _ NoneInferred = NoneInferred diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index b86dd63616..265882c541 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -178,9 +178,19 @@ checkPat' sizes (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' _ (PatLit l info loc) _ = pure $ PatLit l info loc -checkPat' sizes (PatConstr n info ps loc) _ = do +checkPat' sizes (PatConstr n info ps loc) NoneInferred = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps pure $ PatConstr n info ps' loc +checkPat' sizes (PatConstr n info ps loc) (Ascribed (Scalar (Sum cs))) + | Just ts <- M.lookup n cs = do + ps' <- zipWithM (\p t -> checkPat' sizes p (Ascribed t)) ps ts + pure $ PatConstr n info ps' loc +checkPat' _ p t = + error . unlines $ + [ "checkPat': bad case", + prettyString p, + show t + ] checkPat :: [(SizeBinder VName, QualName VName)] -> From cefc91b0071ff6d30bfa0c5725b9617e005b3f51 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 12:59:01 +0100 Subject: [PATCH 183/258] Fix for-in loops. --- src/Language/Futhark/TypeChecker/Terms2.hs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 4226fef858..7cd3f10a1c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1096,12 +1096,10 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do let i' = Ident i (Info (typeOf bound')) iloc bind [i'] $ do body' <- checkExp body - ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (For i' bound', body') While cond -> do cond' <- checkExp cond body' <- checkExp body - ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr @@ -1110,6 +1108,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') + ctEq (Reason (locOf loc)) (expType arg') (expType body') pure $ AppExp (Loop [] pat' arg' form' body' loc) From 80c6d80dc06bef4d0ea57032c46fde262c78417f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 13:19:04 +0100 Subject: [PATCH 184/258] Detect duplicate fields. --- src/Language/Futhark/TypeChecker/Terms2.hs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 7cd3f10a1c..13c9ef2cd8 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -59,7 +59,7 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.FreshNames qualified as FreshNames import Futhark.MonadFreshNames hiding (newName) -import Futhark.Util (debugTraceM, mapAccumLM) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Constraints @@ -481,6 +481,16 @@ checkPat' (TuplePat ps loc) (Ascribed t) TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc +checkPat' p@(RecordPat p_fs loc) _ + | Just (f, fp) <- L.find (("_" `T.isPrefixOf`) . nameToText . fst) p_fs = + typeError fp mempty $ + "Underscore-prefixed fields are not allowed." + "Did you mean" + <> dquotes (pretty (T.drop 1 (nameToText f)) <> "=_") + <> "?" + | nubOrd (map fst p_fs) /= map fst p_fs = + typeError loc mempty $ + "Duplicate fields in record pattern" <+> pretty p <> "." checkPat' p@(RecordPat p_fs loc) (Ascribed t) | Scalar (Record t_fs) <- t, L.sort (map fst p_fs) == L.sort (M.keys t_fs) = From c1fb806f8e51b132d248d5a09cba1213bb52d1ef Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 15:05:27 +0100 Subject: [PATCH 185/258] Handle type arguments here. --- src/Language/Futhark/TypeChecker/Types.hs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index b89802c5e0..02126dbfd8 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -534,10 +534,13 @@ substTypesAny lookupSubst ot = -- | Substitution without caring about sizes. substTyVars :: (Monoid u) => (VName -> Maybe (TypeBase d NoUniqueness)) -> TypeBase d u -> TypeBase d u -substTyVars f t@(Scalar (TypeVar u (QualName qs v) args)) = - case f v of +substTyVars f (Scalar (TypeVar u qn args)) = + case f $ qualLeaf qn of Just t' -> second (const mempty) $ substTyVars f t' - Nothing -> t + Nothing -> Scalar (TypeVar u qn (map onArg args)) + where + onArg (TypeArgType t) = TypeArgType $ substTyVars f t + onArg (TypeArgDim e) = TypeArgDim e substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt substTyVars f (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars f) fs substTyVars f (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars f) cs From c4e11a9dcb81384f7a1b28e2923e984717393211 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 15:10:28 +0100 Subject: [PATCH 186/258] let should not be generalised. --- docs/language-reference.rst | 8 +++++--- tests/types/inference22.fut | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/language-reference.rst b/docs/language-reference.rst index 51719d62a3..b9d3b6ff37 100644 --- a/docs/language-reference.rst +++ b/docs/language-reference.rst @@ -994,9 +994,11 @@ Syntactic sugar for ``let a = a with [i] = v in a``. ............................... Bind ``f`` to a function with the given parameters and definition -(``e``) and evaluate ``body``. The function will be treated as -aliasing any free variables in ``e``. The function is not in scope of -itself, and hence cannot be recursive. +(``e``) and evaluate ``body``. The function will be treated as +aliasing any free variables in ``e``. The function is not in scope of +itself, and hence cannot be recursive. While the function can be made +polymorphic by putting in explicit size parameters, it is not +automatically generalised the way top level functions are. ``loop pat = initial for x in a do loopbody`` ............................................. diff --git a/tests/types/inference22.fut b/tests/types/inference22.fut index 4e367db82f..dbe574e411 100644 --- a/tests/types/inference22.fut +++ b/tests/types/inference22.fut @@ -2,5 +2,5 @@ -- == def main (x: i32) (y: bool) = - let f x y = (y,x) + let f 'a 'b (x: a) (y: b) = (y,x) in (f x y, f y x) From 8cb7c7d710176b3e4a3df6df9fc30781a1960d17 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 11 Mar 2024 07:27:05 -0700 Subject: [PATCH 187/258] Stop erroneously changing the type of automapped functions. --- src/Futhark/Internalise/FullNormalise.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 9b27a6cf46..789e0d3c85 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -507,14 +507,13 @@ expandAMAnnotations e = do arg_es' <- mapM expandAMAnnotations arg_es let diets = funDiets $ typeOf f withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do - inner_f <- setNewType f' $ innerFType (typeOf f') ams let rettype = - case unfoldFunTypeWithRet $ typeOf inner_f of + case unfoldFunTypeWithRet $ typeOf f' of Nothing -> error "Function type expected." Just (ptypes, f_ret) -> foldFunType (drop (length args') ptypes) f_ret pure $ - mkApply inner_f (zip3 exts (repeat mempty) args') $ + mkApply f' (zip3 exts (repeat mempty) args') $ res {appResType = rettype} (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x From c6fa39eb066536cb55ac95ed965c49a10b85b578 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 16:31:56 +0100 Subject: [PATCH 188/258] Some sum fixes. --- src/Language/Futhark/TypeChecker/Terms.hs | 6 +++--- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index c54c0b5805..b3afad3954 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -783,10 +783,10 @@ checkExp (AppExp (Loop _ mergepat mergeexp form loopbody loc) _) = do AppExp (Loop sparams mergepat' mergeexp' form' loopbody' loc) (Info appres) -checkExp (Constr name es _ loc) = do - t <- newTypeVar loc "t" +checkExp (Constr name es (Info t) loc) = do + t' <- replaceTyVars loc t es' <- mapM checkExp es - pure $ Constr name es' (Info t) loc + pure $ Constr name es' (Info t') loc checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e mt <- expType e' diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 265882c541..2a648ea8ab 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -181,10 +181,10 @@ checkPat' _ (PatLit l info loc) _ = checkPat' sizes (PatConstr n info ps loc) NoneInferred = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps pure $ PatConstr n info ps' loc -checkPat' sizes (PatConstr n info ps loc) (Ascribed (Scalar (Sum cs))) +checkPat' sizes (PatConstr n _ ps loc) (Ascribed (Scalar (Sum cs))) | Just ts <- M.lookup n cs = do ps' <- zipWithM (\p t -> checkPat' sizes p (Ascribed t)) ps ts - pure $ PatConstr n info ps' loc + pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' _ p t = error . unlines $ [ "checkPat': bad case", From 4de7ecf60a7cba572412b28dac8eaf5fd2b3c51d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 16:39:41 +0100 Subject: [PATCH 189/258] Detect more ambiguities. --- src/Language/Futhark/TypeChecker/Constraints.hs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 71765d7e0d..461b256e51 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -303,7 +303,11 @@ solveTyVar (tv, (_, TyVarPrim loc pts)) = do solveTyVar (tv, (_, TyVarRecord loc fs1)) = do tv_t <- lookupTyVar tv case tv_t of - Nothing -> pure () + Nothing -> + throwError . TypeError loc mempty $ + "Type is ambiguous." + "Must be a record with fields" + indent 2 (pretty (Scalar (Record fs1))) Just (Scalar (Record fs2)) | all (`M.member` fs2) (M.keys fs1) -> forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> @@ -317,7 +321,11 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do solveTyVar (tv, (_, TyVarSum loc cs1)) = do tv_t <- lookupTyVar tv case tv_t of - Nothing -> pure () + Nothing -> + throwError . TypeError loc mempty $ + "Type is ambiguous." + "Must be a sum type with constructors" + indent 2 (pretty (Scalar (Sum cs1))) Just (Scalar (Sum cs2)) | all (`M.member` cs2) (M.keys cs1), cs3 <- M.toList $ M.intersectionWith (,) cs1 cs2, From c4610ba6e37ef4eb0ac248f23fa3215b655cb141 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 17:03:17 +0100 Subject: [PATCH 190/258] Add notion of equality type. --- src/Language/Futhark/TypeChecker/Constraints.hs | 6 ++++++ src/Language/Futhark/TypeChecker/Rank.hs | 2 ++ src/Language/Futhark/TypeChecker/Terms2.hs | 3 ++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 461b256e51..4a7d30ae98 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -93,6 +93,8 @@ data TyVarInfo TyVarRecord Loc (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum Loc (M.Map Name [Type]) + | -- | Must be a type that supports equality. + TyVarEql Loc deriving (Show, Eq) instance Pretty TyVarInfo where @@ -100,12 +102,14 @@ instance Pretty TyVarInfo where pretty (TyVarPrim _ pts) = "∈" <+> pretty pts pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs + pretty (TyVarEql _) = "equality" instance Located TyVarInfo where locOf (TyVarFree loc) = loc locOf (TyVarPrim loc _) = loc locOf (TyVarRecord loc _) = loc locOf (TyVarSum loc _) = loc + locOf (TyVarEql loc) = loc type TyVar = VName @@ -340,6 +344,8 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do indent 2 (pretty tv_t') where sameLength (x, y) = length x == length y +solveTyVar (_, (_, TyVarEql _)) = + pure () solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) solve constraints tyvars = diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 5f3775a748..3826de135c 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -137,6 +137,8 @@ addTyVarInfo tv (_, TyVarRecord {}) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarEql {}) = + addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: [Ct] -> TyVars -> LinearProg mkLinearProg cs tyVars = diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 13c9ef2cd8..51bac62776 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -410,7 +410,8 @@ lookupVar loc qn@(QualName qs name) = do outer_env <- asks termOuterEnv asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do - argtype <- newTypeOverloaded loc "t" anyPrimType + argtype <- + tyVarType NoUniqueness <$> newTyVarWith "t" (TyVarEql (locOf loc)) pure $ foldFunType [toParam Observe argtype, toParam Observe argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts From 521846f08121a09ce9251fc6c1e4e54b47e6798d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 17:04:17 +0100 Subject: [PATCH 191/258] Remove duplicate comment. --- src/Language/Futhark/Prop.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index d45cc3d294..154fce6e59 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -1067,8 +1067,6 @@ intrinsics = ++ [Bool] ) ++ - -- This overrides the ! from Primitive. - -- This overrides the ! from Primitive. [ ( "!", IntrinsicOverloadedFun From 300cce6492d1cd9ff173279b0f82bfaeb1baa8f2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 11:53:47 +0100 Subject: [PATCH 192/258] Fix check. --- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 51bac62776..94ffe58977 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -489,7 +489,7 @@ checkPat' p@(RecordPat p_fs loc) _ "Did you mean" <> dquotes (pretty (T.drop 1 (nameToText f)) <> "=_") <> "?" - | nubOrd (map fst p_fs) /= map fst p_fs = + | length (nubOrd (map fst p_fs)) /= length (map fst p_fs) = typeError loc mempty $ "Duplicate fields in record pattern" <+> pretty p <> "." checkPat' p@(RecordPat p_fs loc) (Ascribed t) From 6a690521dec2ec33f19c7c867f6de5b3445ace15 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:19:36 +0100 Subject: [PATCH 193/258] Avoid artificial type variables in constraints. --- src/Language/Futhark/Prop.hs | 6 +- src/Language/Futhark/TypeChecker/Rank.hs | 30 +- src/Language/Futhark/TypeChecker/Terms2.hs | 408 +++++++++++++-------- 3 files changed, 280 insertions(+), 164 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 154fce6e59..0507139fdc 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -497,7 +497,7 @@ typeOf (Attr _ e _) = typeOf e typeOf (AppExp _ (Info res)) = appResType res -- | The type of a function with the given parameters and return type. -funType :: [Pat ParamType] -> ResRetType -> StructType +funType :: [Pat (TypeBase d Diet)] -> RetTypeBase d Uniqueness -> TypeBase d NoUniqueness funType params ret = let RetType _ t = foldr (arrow . patternParam) ret params in toStruct t @@ -507,7 +507,7 @@ funType params ret = -- | @foldFunType ts ret@ creates a function type ('Arrow') that takes -- @ts@ as parameters and returns @ret@. -foldFunType :: [ParamType] -> ResRetType -> StructType +foldFunType :: [TypeBase d Diet] -> RetTypeBase d Uniqueness -> TypeBase d NoUniqueness foldFunType ps ret = let RetType _ t = foldr arrow ret ps in toStruct t @@ -621,7 +621,7 @@ patternStructType = toStruct . patternType -- | When viewed as a function parameter, does this pattern correspond -- to a named parameter of some type? -patternParam :: Pat ParamType -> (PName, Diet, StructType) +patternParam :: Pat (TypeBase d Diet) -> (PName, Diet, TypeBase d NoUniqueness) patternParam (PatParens p _) = patternParam p patternParam (PatAttr _ p _) = diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 3826de135c..fba519b544 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -218,13 +218,13 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), [Pat ParamType], Exp) -rankAnalysis1 loc cs tyVars params body = do - solutions <- rankAnalysis loc cs tyVars params body +rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp) +rankAnalysis1 loc cs tyVars artificial params body = do + solutions <- rankAnalysis loc cs tyVars artificial params body case solutions of [sol] -> pure sol sols -> do - let (_, _, bodies') = unzip3 sols + let (_, _, _, bodies') = L.unzip4 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -232,14 +232,15 @@ rankAnalysis1 loc cs tyVars params body = do ] ++ map pretty bodies' -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), [Pat ParamType], Exp)] -rankAnalysis _ [] tyVars params body = pure [(([], tyVars), params, body)] -rankAnalysis loc cs tyVars params body = do +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp)] +rankAnalysis _ [] tyVars artificial params body = pure [(([], tyVars), artificial, params, body)] +rankAnalysis loc cs tyVars artificial params body = do rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps let bodys = map (`updAM` body) rank_maps params' = map ((`map` params) . updAMPat) rank_maps - pure $ zip3 cts_tyvars' params' bodys + artificial' <- mapM (substRankInfoArtificial tyVars artificial) rank_maps + pure $ L.zip4 cts_tyvars' artificial' params' bodys where cs' = foldMap (splitFuncs . distribute) cs splitFuncs @@ -265,6 +266,12 @@ substRankInfo cs tyVars rankmap = do isCtAM (CtAM {}) = True isCtAM _ = False +substRankInfoArtificial :: (MonadTypeChecker m) => TyVars -> M.Map VName Type -> Map VName Int -> m (M.Map VName Type) +substRankInfoArtificial tyvars artificial rankmap = do + (artificial', _, _) <- + runSubstT tyvars rankmap $ traverse substRanks artificial + pure artificial' + runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) runSubstT tyVars rankmap (SubstT m) = do let env = @@ -372,6 +379,7 @@ instance SubstRanks (TypeBase SComp u) where ta' <- substRanks ta tr' <- substRanks tr pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) + substRanks (Scalar (Record fs)) = Scalar . Record <$> traverse substRanks fs substRanks (Array u shape t) = do shape' <- substRanks shape t' <- substRanks $ Scalar t @@ -391,15 +399,15 @@ updAM rank_map e = in AppExp (Apply f' args' loc) res AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res - _ -> runIdentity $ astMap m e + _ -> runIdentity $ astMap mapper e where dimToRank (Var (QualName [] x) _ _) = replicate (rank_map M.! x) (TupLit mempty mempty) - dimToRank e = error $ prettyString e + dimToRank e' = error $ prettyString e' shapeToRank = Shape . foldMap dimToRank upd (AutoMap r m f) = AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) - m = + mapper = identityMapper { mapOnExp = pure . updAM rank_map } diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 94ffe58977..6839cd3d29 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -79,14 +79,11 @@ instance Functor Inferred where fmap f (Ascribed t) = Ascribed (f t) data ValBinding - = BoundV [TypeParam] StructType + = BoundV [TypeParam] Type | OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType) | EqualityF deriving (Show) -expType :: Exp -> Type -expType = toType . typeOf - data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding, scopeTypeTable :: M.Map VName TypeBinding, @@ -116,7 +113,9 @@ data TermState = TermState termTyVars :: TyVars, termCounter :: !Int, termWarnings :: Warnings, - termNameSource :: VNameSource + termNameSource :: VNameSource, + -- | Mapping from artificial type variables to the actual types they represent. + termArtificial :: M.Map TyVar Type } newtype TermM a @@ -143,7 +142,7 @@ envToTermScope env = } where vtable = M.map valBinding $ envVtable env - valBinding (TypeM.BoundV tps v) = BoundV tps v + valBinding (TypeM.BoundV tps v) = BoundV tps $ toType v initialTermScope :: TermScope initialTermScope = @@ -169,7 +168,7 @@ initialTermScope = addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = Just ( name, - BoundV tvs $ foldFunType pts rt + BoundV tvs $ toType $ foldFunType pts rt ) addIntrinsicF (name, IntrinsicEquality) = Just (name, EqualityF) @@ -194,7 +193,8 @@ runTermM (TermM m) = do termTyVars = mempty, termWarnings = mempty, termNameSource = src, - termCounter = 0 + termCounter = 0, + termArtificial = mempty } case runExcept (runStateT (runReaderT m initial_env) initial_state) of Left (ws, e) -> do @@ -254,26 +254,42 @@ newSVar _loc desc = do i <- incCounter newID $ mkTypeVarName desc i -asStructType :: SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) -asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt -asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] -asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do - t1' <- asStructType loc t1 - t2' <- asStructType loc t2 - pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' -asStructType loc (Scalar (Record fs)) = - Scalar . Record <$> traverse (asStructType loc) fs -asStructType loc (Scalar (Sum cs)) = - Scalar . Sum <$> traverse (mapM (asStructType loc)) cs -asStructType loc t@(Scalar (TypeVar u _ _)) = do - t' <- newType loc "artificial" u - ctEq (Reason (locOf loc)) (toType t') t - pure t' -asStructType loc t@(Array u _ _) = do - t' <- newType loc "artificial" u - ctEq (Reason (locOf loc)) (toType t') t +newArtificial :: u -> TypeBase SComp u -> TermM (TypeBase Size u) +newArtificial u t = do + v <- newID "artificial" + let t' = tyVarType u v + modify $ \s -> s {termArtificial = M.insert v (second (const NoUniqueness) t) $ termArtificial s} pure t' +-- The AST requires annotations to be StructTypes, but the type +-- checker works with Types. This creates artificial type "variables" +-- that allow us to connect the AST annotations with the actual +-- inferred types. The artificial variables should never occur in +-- constraints - they can be substituted away with asType. +asStructType :: TypeBase SComp u -> TermM (TypeBase Size u) +asStructType (Scalar (Prim pt)) = pure $ Scalar $ Prim pt +asStructType (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] +asStructType (Scalar (Arrow u pname d t1 (RetType ext t2))) = do + t1' <- asStructType t1 + t2' <- asStructType t2 + pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' +asStructType (Scalar (Record fs)) = + Scalar . Record <$> traverse asStructType fs +asStructType (Scalar (Sum cs)) = + Scalar . Sum <$> traverse (mapM asStructType) cs +asStructType t@(Scalar (TypeVar u _ _)) = + newArtificial u t +asStructType t@(Array u _ _) = do + newArtificial u t + +asType :: (Monoid u) => TypeBase Size u -> TermM (TypeBase SComp u) +asType t = do + artificial <- gets termArtificial + pure $ substTyVars (`M.lookup` artificial) (toType t) + +expType :: Exp -> TermM Type +expType = asType . typeOf -- NOTE: Only place you should use typeOf. + addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} @@ -336,8 +352,10 @@ instance MonadTypeChecker TermM where i <- incCounter newID $ mkTypeVarName name i - bindVal v (TypeM.BoundV tps t) = localScope $ \scope -> - scope {scopeVtable = M.insert v (BoundV tps t) $ scopeVtable scope} + bindVal v (TypeM.BoundV tps t) m = do + t' <- asType t + let f scope = scope {scopeVtable = M.insert v (BoundV tps t') $ scopeVtable scope} + localScope f m lookupType qn = do outer_env <- asks termOuterEnv @@ -361,11 +379,13 @@ arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why [pt] e = do - ctEq (Reason (locOf e)) (Scalar $ Prim pt) (expType e) + e_t <- expType e + ctEq (Reason (locOf e)) (Scalar $ Prim pt) e_t pure e require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts - ctEq (Reason (locOf e)) t $ expType e + e_t <- expType e + ctEq (Reason (locOf e)) t e_t pure e -- | Instantiate a type scheme with fresh type variables for its type @@ -375,18 +395,18 @@ instTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> - StructType -> - TermM ([VName], StructType) + Type -> + TermM ([VName], Type) instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> case tparam of TypeParamType _ v _ -> do v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType NoUniqueness v')) + pure $ Just (v, (typeParamName tparam, tyVarType NoUniqueness v')) TypeParamDim {} -> pure Nothing - let t' = applySubst (`lookup` substs) t + let t' = substTyVars (`lookup` substs) t pure (names, t') lookupMod :: QualName VName -> TermM Mod @@ -396,7 +416,7 @@ lookupMod qn@(QualName _ name) = do Nothing -> error $ "lookupMod: " <> show qn Just m -> pure m -lookupVar :: SrcLoc -> QualName VName -> TermM StructType +lookupVar :: SrcLoc -> QualName VName -> TermM Type lookupVar loc qn@(QualName qs name) = do scope <- lookupQualNameEnv qn case M.lookup name $ scopeVtable scope of @@ -407,16 +427,15 @@ lookupVar loc qn@(QualName qs name) = do then pure t else do (tnames, t') <- instTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' + -- TODO - qualify type names, like in the old type checker. + pure t' Just EqualityF -> do - argtype <- - tyVarType NoUniqueness <$> newTyVarWith "t" (TyVarEql (locOf loc)) - pure $ foldFunType [toParam Observe argtype, toParam Observe argtype] $ RetType [] $ Scalar $ Prim Bool + argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarEql (locOf loc)) + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + pure $ foldFunType (map (second $ const Observe) pts') $ RetType [] $ second (const Nonunique) rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -427,13 +446,16 @@ bind :: [Ident StructType] -> TermM a -> TermM a -bind idents = localScope (`bindVars` idents) +bind idents m = do + let names = map identName idents + ts <- mapM (asType . unInfo . identType) idents + localScope (`bindVars` zip names ts) m where bindVars = foldl bindVar - bindVar scope (Ident name (Info tp) _) = + bindVar scope (name, t) = scope - { scopeVtable = M.insert name (BoundV [] tp) $ scopeVtable scope + { scopeVtable = M.insert name (BoundV [] t) $ scopeVtable scope } -- All this complexity is just so we can handle un-suffixed numeric @@ -449,24 +471,27 @@ patLitMkType (PatLitPrim v) _ = checkSizeExp' :: ExpBase NoInfo VName -> TermM Exp checkSizeExp' e = do e' <- checkExp e - ctEq (Reason (locOf e)) (expType e') (Scalar (Prim (Signed Int64))) + e_t <- expType e' + ctEq (Reason (locOf e)) e_t (Scalar (Prim (Signed Int64))) pure e' checkPat' :: PatBase NoInfo VName ParamType -> - Inferred ParamType -> + Inferred (TypeBase SComp Diet) -> TermM (Pat ParamType) checkPat' (PatParens p loc) t = PatParens <$> checkPat' p t <*> pure loc checkPat' (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' p t <*> pure loc -checkPat' (Id name NoInfo loc) (Ascribed t) = - pure $ Id name (Info t) loc +checkPat' (Id name NoInfo loc) (Ascribed t) = do + t' <- asStructType t + pure $ Id name (Info t') loc checkPat' (Id name NoInfo loc) NoneInferred = do t <- newType loc "t" Observe pure $ Id name (Info t) loc -checkPat' (Wildcard _ loc) (Ascribed t) = - pure $ Wildcard (Info t) loc +checkPat' (Wildcard _ loc) (Ascribed t) = do + t' <- asStructType t + pure $ Wildcard (Info t') loc checkPat' (Wildcard NoInfo loc) NoneInferred = do t <- newType loc "t" Observe pure $ Wildcard (Info t) loc @@ -477,9 +502,9 @@ checkPat' (TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t <- replicateM (length ps) (newType loc "t" Observe) - ctEq (Reason (locOf loc)) (toType (Scalar (tupleRecord ps_t))) (toType t) - TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc + ps_tvs <- replicateM (length ps) (newTyVar loc "t") + ctEq (Reason (locOf loc)) (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) t + TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) _ @@ -498,9 +523,8 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs - ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) $ toType t - st <- asStructType loc $ Scalar (Record p_fs') - checkPat' p $ Ascribed $ toParam Observe st + ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t + checkPat' p $ Ascribed $ const Observe <$> Scalar (Record p_fs') where check t_fs = traverse (uncurry checkPat') $ @@ -514,11 +538,11 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do -- Uniqueness kung fu to make the Monoid(mempty) instance give what -- we expect. We should perhaps stop being so implicit. - st' <- asStructType loc $ toType $ resToParam st + st' <- asType $ resToParam st case maybe_outer_t of Ascribed outer_t -> do - ctEq (Reason (locOf loc)) (toType st') (toType outer_t) + ctEq (Reason (locOf loc)) st' outer_t PatAscription <$> checkPat' p (Ascribed st') <*> pure t' @@ -530,7 +554,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do t' <- patLitMkType l loc - ctEq (Reason (locOf loc)) (toType t') (toType t) + ctEq (Reason (locOf loc)) (toType t') t pure $ PatLit l (Info t') loc checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc @@ -547,28 +571,29 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) <+> pretty (length ts) <+> "arguments." ps' <- zipWithM checkPat' ps $ map Ascribed ts - pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc + cs' <- traverse (mapM (asStructType)) cs + pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do p_t <- newType (srclocOf p) "t" Observe checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' - ctEq (Reason (locOf loc)) t' (toType t) - t'' <- asStructType loc t' + ctEq (Reason (locOf loc)) t' t + t'' <- asStructType t' pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps t <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' - t' <- asStructType loc t + t' <- asStructType t pure $ PatConstr n (Info $ toParam Observe t') ps' loc checkPat :: PatBase NoInfo VName (TypeBase Size u) -> - Inferred StructType -> + Inferred Type -> (Pat ParamType -> TermM a) -> TermM a checkPat p t m = - m =<< checkPat' (fmap (toParam Observe) p) (fmap (toParam Observe) t) + m =<< checkPat' (fmap (toParam Observe) p) (fmap (fmap (const Observe)) t) -- | Bind @let@-bound sizes. This is usually followed by 'bindletPat' -- immediately afterwards. @@ -581,7 +606,7 @@ bindSizes sizes m = bind (map sizeWithType sizes) m bindLetPat :: PatBase NoInfo VName (TypeBase Size u) -> - StructType -> + Type -> (Pat ParamType -> TermM a) -> TermM a bindLetPat p t m = do @@ -628,11 +653,15 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> NE.NonEmpty (Shape Size, Type) -> TermM (StructType, NE.NonEmpty AutoMap) +checkApply :: + SrcLoc -> + Maybe (QualName VName) -> + (Shape Size, Type) -> + NE.NonEmpty (Shape Size, Type) -> + TermM (Type, NE.NonEmpty AutoMap) checkApply loc fname (fframe, ftype) args = do ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args - rt' <- asStructType loc rt - pure (rt', argts) + pure (rt, argts) where onArg (i, f_f, f_t) (argframe, argtype) = do (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) @@ -732,18 +761,19 @@ mustHaveFields loc t (f : fs) ve_t = do ctEq (Reason (locOf loc)) t rt checkCase :: - StructType -> + Type -> CaseBase NoInfo VName -> - TermM (CaseBase Info VName, StructType) + TermM (CaseBase Info VName, Type) checkCase mt (CasePat p e loc) = bindLetPat p mt $ \p' -> do e' <- checkExp e - pure (CasePat (fmap toStruct p') e' loc, typeOf e') + e_t <- expType e' + pure (CasePat (fmap toStruct p') e' loc, e_t) checkCases :: - StructType -> + Type -> NE.NonEmpty (CaseBase NoInfo VName) -> - TermM (NE.NonEmpty (CaseBase Info VName), StructType) + TermM (NE.NonEmpty (CaseBase Info VName), Type) checkCases mt rest_cs = case NE.uncons rest_cs of (c, Nothing) -> do @@ -752,7 +782,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq (Reason (locOf c)) (toType c_t) (toType cs_t) + ctEq (Reason (locOf c)) c_t cs_t pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -790,16 +820,18 @@ checkRetDecl :: checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (Reason (locOf body)) (expType body) (toType st) + body_t <- expType body + st' <- asType st + ctEq (Reason (locOf body)) body_t st' pure $ Just te' checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- checkExp (Var qn _ loc) = do - t <- lookupVar loc qn + t <- asStructType =<< lookupVar loc qn pure $ Var qn (Info t) loc checkExp (OpSection op _ loc) = do - ftype <- lookupVar loc op + ftype <- asStructType =<< lookupVar loc op pure $ OpSection op (Info ftype) loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg @@ -841,7 +873,9 @@ checkExp (ArrayLit es _ loc) = do et <- newType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e - ctEq (Reason (locOf loc)) (expType e') (toType et) + e_t <- expType e' + et' <- asType et + ctEq (Reason (locOf loc)) e_t et' pure e' let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et pure $ ArrayLit es' (Info arr_t) loc @@ -856,7 +890,8 @@ checkExp (RecordLit fs loc) = errIfAlreadySet (baseName name) rloc t <- lift $ lookupVar rloc $ qualName name modify $ M.insert (baseName name) rloc - pure $ RecordFieldImplicit name (Info t) rloc + t' <- lift $ asStructType t + pure $ RecordFieldImplicit name (Info t') rloc errIfAlreadySet f rloc = do maybe_sloc <- gets $ M.lookup f @@ -880,7 +915,8 @@ checkExp (Assert e1 e2 NoInfo loc) = do -- checkExp (Constr name es NoInfo loc) = do es' <- mapM checkExp es - t <- newTypeWithConstr loc "t" NoUniqueness name $ map expType es' + es_ts <- mapM expType es' + t <- newTypeWithConstr loc "t" NoUniqueness name es_ts pure $ Constr name es' (Info t) loc -- checkExp (AppExp (Apply fe args loc) NoInfo) = do @@ -891,13 +927,15 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do args ( \(_, arg) -> do arg' <- checkExp arg - pure (arg', (frameOf arg', expType arg')) + arg_t <- expType arg' + pure (arg', (frameOf arg', arg_t)) ) - (rt, ams) <- checkApply loc fname (frameOf fe', expType fe') argts' + fe_t <- expType fe' + (rt, ams) <- checkApply loc fname (frameOf fe', fe_t) argts' + rt' <- asStructType rt pure $ AppExp (Apply fe' (NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args') loc) $ - Info $ - AppRes rt [] + Info (AppRes rt' []) where fname = case fe of @@ -906,59 +944,71 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op e1' <- checkExp e1 + e1_t <- expType e1' e2' <- checkExp e2 + e2_t <- expType e2' + (rt, ams) <- checkApply loc (Just op) - (mempty, toType ftype) - ((frameOf e1', toType $ typeOf e1') NE.:| [(frameOf e2', toType $ typeOf e2')]) + (mempty, ftype) + ((frameOf e1', e1_t) NE.:| [(frameOf e2', e2_t)]) + rt' <- asStructType rt let (am1 NE.:| [am2]) = ams + ftype' <- asStructType ftype pure $ AppExp - (BinOp (op, oploc) (Info ftype) (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) - (Info (AppRes rt [])) + (BinOp (op, oploc) (Info ftype') (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) + (Info (AppRes rt' [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e + e_t <- expType e' t2 <- newType loc "t" NoUniqueness - t2' <- asStructType loc t2 - let t1 = typeOf e' - f1 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((f1, toType t1) NE.:| [(mempty, t2)]) + t2' <- asStructType t2 + let f1 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((f1, e_t) NE.:| [(mempty, t2)]) + rt' <- asStructType rt let (am1 NE.:| _) = ams + t1 <- asStructType e_t + optype' <- asStructType optype pure $ OpSectionLeft op - (Info optype) + (Info optype') e' ( Info (Unnamed, toParam Observe t1, Nothing, am1), Info (Unnamed, toParam Observe t2') ) - (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) + (Info (RetType [] (rt' `setUniqueness` Nonunique)), Info []) loc checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e + e_t <- expType e' t1 <- newType loc "t" NoUniqueness - t1' <- asStructType loc t1 - let t2 = typeOf e' - f2 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((mempty, t1) NE.:| [(f2, toType t2)]) + t1' <- asStructType t1 + let f2 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((mempty, t1) NE.:| [(f2, e_t)]) + rt' <- asStructType rt let (_ NE.:| [am2]) = ams + t2 <- asStructType e_t + + optype' <- asStructType optype pure $ OpSectionRight op - (Info optype) + (Info optype') e' -- Dummy types. ( Info (Unnamed, toParam Observe t1'), Info (Unnamed, toParam Observe t2, Nothing, am2) ) - (Info $ RetType [] (rt `setUniqueness` Nonunique)) + (Info $ RetType [] (rt' `setUniqueness` Nonunique)) loc -- checkExp (ProjectSection fields NoInfo loc) = do @@ -971,70 +1021,91 @@ checkExp (ProjectSection fields NoInfo loc) = do checkExp (Lambda params body retdecl NoInfo loc) = do bindParams [] params $ \params' -> do body' <- checkExp body + body_t <- expType body' + + body_t' <- asStructType body_t retdecl' <- checkRetDecl body' retdecl - let ret = RetType [] $ toRes Nonunique $ typeOf body' + let ret = RetType [] $ toRes Nonunique body_t' pure $ Lambda params' body' retdecl' (Info ret) loc -- checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e + e_t <- expType e' - bindSizes sizes . incLevel . bindLetPat pat (typeOf e') $ \pat' -> do + bindSizes sizes . incLevel . bindLetPat pat e_t $ \pat' -> do body' <- incLevel $ checkExp body + body_t <- expType body' + + body_t' <- asStructType body_t pure $ AppExp (LetPat sizes (fmap toStruct pat') e' body' loc) - (Info $ AppRes (typeOf body') []) + (Info $ AppRes body_t' []) -- checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) = do (tparams', params', retdecl', rettype, e') <- bindParams tparams params $ \params' -> do e' <- checkExp e - let ret = RetType [] $ toRes Nonunique $ typeOf e' + e_t <- expType e' + let ret = fmap (const Nonunique) e_t retdecl' <- checkRetDecl e' retdecl pure (tparams, params', retdecl', ret, e') - let entry = BoundV tparams' $ funType params' rettype + params'' <- mapM (traverse asType) params' + + let entry = BoundV tparams' $ funType params'' $ RetType [] rettype bindF scope = scope { scopeVtable = M.insert name entry $ scopeVtable scope } body' <- localScope bindF $ checkExp body + body_t <- expType body' + body_t' <- asStructType body_t + rettype' <- asStructType rettype pure $ AppExp ( LetFun name - (tparams', params', retdecl', Info rettype, e') + (tparams', params', retdecl', Info (RetType [] rettype'), e') body' loc ) - (Info $ AppRes (typeOf body') []) + (Info $ AppRes body_t' []) -- checkExp (AppExp (Range start maybe_step end loc) _) = do start' <- require "use in range expression" anyIntType =<< checkExp start let check e = do e' <- checkExp e - ctEq (Reason (locOf e')) (expType start') (expType e') + start_t <- expType start' + e_t <- expType e' + ctEq (Reason (locOf e')) start_t e_t pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end range_t <- newType loc "range" NoUniqueness - ctEq (Reason (locOf start')) (toType range_t) (arrayOfRank 1 (expType start')) + range_t' <- asType range_t + start_t <- expType start' + ctEq (Reason (locOf start')) range_t' (arrayOfRank 1 start_t) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e kt <- newType loc "kt" NoUniqueness t <- newTypeWithField loc "t" k kt - ctEq (Reason (locOf e')) (expType e') t - kt' <- asStructType loc kt + e_t <- expType e' + ctEq (Reason (locOf e')) e_t t + kt' <- asStructType kt pure $ Project k e' (Info kt') loc -- checkExp (RecordUpdate src fields ve NoInfo loc) = do src' <- checkExp src + src_t <- expType src' ve' <- checkExp ve - mustHaveFields loc (expType src') fields (expType ve') - pure $ RecordUpdate src' fields ve' (Info (typeOf src')) loc + ve_t <- expType ve' + mustHaveFields loc src_t fields ve_t + src_t' <- asStructType src_t + pure $ RecordUpdate src' fields ve' (Info src_t') loc -- checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice @@ -1044,67 +1115,84 @@ checkExp (IndexSection slice NoInfo loc) = do let num_slices = length $ filter isSlice slice ctEq (Reason (locOf loc)) (toType index_arg_t) $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t - index_res_t' <- asStructType loc index_res_t + index_res_t' <- asStructType index_res_t let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' pure $ IndexSection slice' (Info ft) loc -- checkExp (AppExp (Index e slice loc) _) = do e' <- checkExp e + e_t <- expType e' slice' <- checkSlice slice - index_t <- newType loc "index" NoUniqueness + index_tv <- newTyVar loc "index" index_elem_t <- newType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (Reason (locOf loc)) (toType index_t) $ arrayOfRank num_slices index_elem_t - ctEq (Reason (locOf e')) (expType e') $ arrayOfRank (length slice) index_elem_t - pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) + ctEq (Reason (locOf loc)) (tyVarType NoUniqueness index_tv) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf e')) e_t $ arrayOfRank (length slice) index_elem_t + pure $ AppExp (Index e' slice' loc) (Info $ AppRes (tyVarType NoUniqueness index_tv) []) -- checkExp (Update src slice ve loc) = do src' <- checkExp src + src_t <- expType src' slice' <- checkSlice slice ve' <- checkExp ve + ve_t <- expType ve' let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (Reason (locOf src')) (expType src') $ arrayOfRank (length slice) update_elem_t - ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf src')) src_t $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do src_t <- lookupVar (srclocOf src) $ qualName $ identName src - let src' = src {identType = Info src_t} - dest' = dest {identType = Info src_t} + src_t' <- asStructType src_t + let src' = src {identType = Info src_t'} + dest' = dest {identType = Info src_t'} slice' <- checkSlice slice ve' <- checkExp ve + ve_t <- expType ve' let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (Reason (locOf loc)) (toType src_t) $ arrayOfRank (length slice) update_elem_t - ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf loc)) src_t $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body - pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) + body_t <- expType body' + body_t' <- asStructType body_t + pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t' []) -- checkExp (AppExp (If e1 e2 e3 loc) _) = do e1' <- checkExp e1 + e1_t <- expType e1' e2' <- checkExp e2 + e2_t <- expType e2' e3' <- checkExp e3 + e3_t <- expType e3' - ctEq (Reason (locOf e1')) (expType e1') (Scalar (Prim Bool)) - ctEq (Reason (locOf loc)) (expType e2') (expType e3') + ctEq (Reason (locOf e1')) e1_t (Scalar (Prim Bool)) + ctEq (Reason (locOf loc)) e2_t e3_t - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) + e2_t' <- asStructType e2_t + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes e2_t' []) -- checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e - (cs', t) <- checkCases (typeOf e') cs - pure $ AppExp (Match e' cs' loc) (Info $ AppRes t []) + e_t <- expType e' + + (cs', t) <- checkCases e_t cs + t' <- asStructType t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes t' []) -- checkExp (AppExp (Loop _ pat arg form body loc) _) = do arg' <- checkExp arg - bindLetPat pat (typeOf arg') $ \pat' -> do + arg_t <- expType arg' + bindLetPat pat arg_t $ \pat' -> do (form', body') <- case form of For (Ident i _ iloc) bound -> do bound' <- require "loop bound" anyIntType =<< checkExp bound - let i' = Ident i (Info (typeOf bound')) iloc + bound_t <- expType bound' + bound_t' <- asStructType bound_t + let i' = Ident i (Info bound_t') iloc bind [i'] $ do body' <- checkExp body pure (For i' bound', body') @@ -1115,11 +1203,14 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" NoUniqueness - ctEq (Reason (locOf arr')) (expType arr') $ arrayOfRank 1 (toType elem_t) - bindLetPat elemp elem_t $ \elemp' -> do + arr_t <- expType arr' + elem_t' <- asType elem_t + ctEq (Reason (locOf arr')) arr_t $ arrayOfRank 1 elem_t' + bindLetPat elemp elem_t' $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') - ctEq (Reason (locOf loc)) (expType arg') (expType body') + body_t <- expType body' + ctEq (Reason (locOf loc)) arg_t body_t pure $ AppExp (Loop [] pat' arg' form' body' loc) @@ -1128,12 +1219,16 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do checkExp (Ascript e te loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (Reason (locOf e')) (expType e') (toType st) + e_t <- expType e' + st' <- asType st + ctEq (Reason (locOf e')) e_t st' pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (Reason (locOf e')) (expType e') (toType st) + e_t <- expType e' + st' <- asType st + ctEq (Reason (locOf e')) e_t st' pure $ Coerce e' te' (Info (toStruct st)) loc doDefault :: @@ -1172,7 +1267,7 @@ doDefaults tyvars_at_toplevel substs = do pure $ M.map (substTyVars (`M.lookup` substs')) substs' generalise :: - StructType -> [VName] -> Solution -> ([TypeParam], [VName]) + TypeBase () NoUniqueness -> [VName] -> Solution -> ([TypeParam], [VName]) generalise fun_t unconstrained solution = -- Candidates for let-generalisation are those type variables that -- are used in fun_t. @@ -1190,7 +1285,7 @@ generalise fun_t unconstrained solution = generaliseAndDefaults :: [VName] -> Solution -> - StructType -> + TypeBase () NoUniqueness -> TermM ([TypeParam], M.Map VName (TypeBase () NoUniqueness)) generaliseAndDefaults unconstrained solution t = do let (generalised, unconstrained') = @@ -1225,8 +1320,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do pure (params', body', retdecl') cts <- gets termConstraints - tyvars <- gets termTyVars + artificial <- gets termArtificial debugTraceM 3 $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" @@ -1237,14 +1332,20 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do "## body:", prettyString body', "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + "## artificial:", + unlines $ map (\(v, t) -> prettyNameString v <> " => " <> prettyString t) (M.toList artificial) ] - onRankSolution retdecl' =<< rankAnalysis1 loc cts tyvars params' body' + onRankSolution retdecl' + =<< rankAnalysis1 loc cts tyvars artificial params' body' where - onRankSolution retdecl' ((cts', tyvars'), params', body'') = do + onRankSolution retdecl' ((cts', tyvars'), artificial, params', body'') = do solution <- - bitraverse pure (onTySolution params' body'') $ solve cts' tyvars' + bitraverse + pure + (fmap (second (onArtificial artificial)) . onTySolution params' body'') + $ solve cts' tyvars' debugTraceM 3 $ unlines [ "## constraints:", @@ -1259,12 +1360,16 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do pure (solution, params', retdecl', body'') onTySolution params' body' (unconstrained, solution) = do + body_t <- expType body' let fun_t = foldFunType - (map patternType params') - (RetType [] $ toRes Nonunique (typeOf body')) + (map (first (const ()) . patternType) params') + (RetType [] $ bimap (const ()) (const Nonunique) body_t) generaliseAndDefaults unconstrained solution fun_t + onArtificial artificial solution = + M.map (substTyVars (`M.lookup` solution) . first (const ())) artificial <> solution + checkSingleExp :: ExpBase NoInfo VName -> TypeM (Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), Exp) @@ -1272,11 +1377,13 @@ checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - ((cts', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars [] e' + artificial <- gets termArtificial + ((cts', tyvars'), _, _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' case solve cts' tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do - x <- generaliseAndDefaults unconstrained solution $ typeOf e'' + e_t <- expType e'' + x <- generaliseAndDefaults unconstrained solution $ first (const ()) e_t pure (Right x, e'') -- | Type-check a single size expression in isolation. This expression may @@ -1288,8 +1395,9 @@ checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints tyvars <- gets termTyVars + artificial <- gets termArtificial - (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars [] e' + (cts_tyvars', _, _, es') <- L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- forM cts_tyvars' $ From 2fbe6c22b6565074cf9e36dc1de0c5c0b5f1e13f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:30:20 +0100 Subject: [PATCH 194/258] Better to do this in same pass. --- src/Language/Futhark/TypeChecker/Rank.hs | 53 ++++++++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 15 +++--- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index fba519b544..c74fad677c 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -218,13 +218,21 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp) +rankAnalysis1 :: + (MonadTypeChecker m) => + SrcLoc -> + [Ct] -> + TyVars -> + M.Map TyVar Type -> + [Pat ParamType] -> + Exp -> + m (([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp) rankAnalysis1 loc cs tyVars artificial params body = do solutions <- rankAnalysis loc cs tyVars artificial params body case solutions of [sol] -> pure sol sols -> do - let (_, _, _, bodies') = L.unzip4 sols + let (_, _, bodies') = unzip3 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -232,15 +240,23 @@ rankAnalysis1 loc cs tyVars artificial params body = do ] ++ map pretty bodies' -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp)] -rankAnalysis _ [] tyVars artificial params body = pure [(([], tyVars), artificial, params, body)] +rankAnalysis :: + (MonadTypeChecker m) => + SrcLoc -> + [Ct] -> + TyVars -> + M.Map TyVar Type -> + [Pat ParamType] -> + Exp -> + m [(([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp)] +rankAnalysis _ [] tyVars artificial params body = + pure [(([], artificial, tyVars), params, body)] rankAnalysis loc cs tyVars artificial params body = do rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) - cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps + cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps let bodys = map (`updAM` body) rank_maps params' = map ((`map` params) . updAMPat) rank_maps - artificial' <- mapM (substRankInfoArtificial tyVars artificial) rank_maps - pure $ L.zip4 cts_tyvars' artificial' params' bodys + pure $ zip3 cts_tyvars' params' bodys where cs' = foldMap (splitFuncs . distribute) cs splitFuncs @@ -255,23 +271,22 @@ rankAnalysis loc cs tyVars artificial params body = do t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] -substRankInfo :: (MonadTypeChecker m) => [Ct] -> TyVars -> Map VName Int -> m ([Ct], TyVars) -substRankInfo cs tyVars rankmap = do - (cs', new_cs, new_tyVars) <- +substRankInfo :: + (MonadTypeChecker m) => + [Ct] -> + M.Map VName Type -> + TyVars -> + Map VName Int -> + m ([Ct], M.Map VName Type, TyVars) +substRankInfo cs artificial tyVars rankmap = do + ((cs', artificial'), new_cs, new_tyVars) <- runSubstT tyVars rankmap $ - substRanks $ - filter (not . isCtAM) cs - pure (cs' <> new_cs, new_tyVars <> tyVars) + (,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial + pure (cs' <> new_cs, artificial', new_tyVars <> tyVars) where isCtAM (CtAM {}) = True isCtAM _ = False -substRankInfoArtificial :: (MonadTypeChecker m) => TyVars -> M.Map VName Type -> Map VName Int -> m (M.Map VName Type) -substRankInfoArtificial tyvars artificial rankmap = do - (artificial', _, _) <- - runSubstT tyvars rankmap $ traverse substRanks artificial - pure artificial' - runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) runSubstT tyVars rankmap (SubstT m) = do let env = diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6839cd3d29..371f139f99 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -524,7 +524,7 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) | otherwise = do p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t - checkPat' p $ Ascribed $ const Observe <$> Scalar (Record p_fs') + checkPat' p $ Ascribed $ Observe <$ Scalar (Record p_fs') where check t_fs = traverse (uncurry checkPat') $ @@ -571,7 +571,7 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) <+> pretty (length ts) <+> "arguments." ps' <- zipWithM checkPat' ps $ map Ascribed ts - cs' <- traverse (mapM (asStructType)) cs + cs' <- traverse (mapM asStructType) cs pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do @@ -1340,7 +1340,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do onRankSolution retdecl' =<< rankAnalysis1 loc cts tyvars artificial params' body' where - onRankSolution retdecl' ((cts', tyvars'), artificial, params', body'') = do + onRankSolution retdecl' ((cts', artificial, tyvars'), params', body'') = do solution <- bitraverse pure @@ -1378,7 +1378,8 @@ checkSingleExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars artificial <- gets termArtificial - ((cts', tyvars'), _, _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' + ((cts', artificial', tyvars'), _, e'') <- + rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' case solve cts' tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do @@ -1397,11 +1398,11 @@ checkSizeExp e = runTermM $ do tyvars <- gets termTyVars artificial <- gets termArtificial - (cts_tyvars', _, _, es') <- L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' + (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- - forM cts_tyvars' $ - bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + forM cts_tyvars' $ \(cts', artificial', tyvars') -> + bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' case (solutions, es') of ([solution], [e'']) -> From f16ee6aa371192a50adce27fe5f795e37b2131dc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:37:08 +0100 Subject: [PATCH 195/258] More cleanup. --- src/Language/Futhark/TypeChecker/Terms2.hs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 371f139f99..5917860e60 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1014,8 +1014,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do checkExp (ProjectSection fields NoInfo loc) = do a <- newType loc "a" NoUniqueness b <- newType loc "b" NoUniqueness - mustHaveFields loc (toType a) fields (toType b) - let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique + mustHaveFields loc a fields b + ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc -- checkExp (Lambda params body retdecl NoInfo loc) = do @@ -1047,9 +1047,8 @@ checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) bindParams tparams params $ \params' -> do e' <- checkExp e e_t <- expType e' - let ret = fmap (const Nonunique) e_t retdecl' <- checkRetDecl e' retdecl - pure (tparams, params', retdecl', ret, e') + pure (tparams, params', retdecl', fmap (const Nonunique) e_t, e') params'' <- mapM (traverse asType) params' @@ -1113,10 +1112,9 @@ checkExp (IndexSection slice NoInfo loc) = do index_elem_t <- newType loc "index_elem" NoUniqueness index_res_t <- newType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (Reason (locOf loc)) (toType index_arg_t) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf loc)) index_arg_t $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t - index_res_t' <- asStructType index_res_t - let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' + ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe index_arg_t $ second (const Nonunique) $ RetType [] index_res_t pure $ IndexSection slice' (Info ft) loc -- checkExp (AppExp (Index e slice loc) _) = do From 74be1732e22138842e86847a9a8f5b527e65318a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:38:23 +0100 Subject: [PATCH 196/258] Consistency. --- src/Language/Futhark/TypeChecker/Terms2.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 5917860e60..aceabbbe22 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -888,10 +888,9 @@ checkExp (RecordLit fs loc) = RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc checkField (RecordFieldImplicit name NoInfo rloc) = do errIfAlreadySet (baseName name) rloc - t <- lift $ lookupVar rloc $ qualName name + t <- lift $ asStructType =<< lookupVar rloc (qualName name) modify $ M.insert (baseName name) rloc - t' <- lift $ asStructType t - pure $ RecordFieldImplicit name (Info t') rloc + pure $ RecordFieldImplicit name (Info t) rloc errIfAlreadySet f rloc = do maybe_sloc <- gets $ M.lookup f From 64f6a86174e995e58552a1e01ef6efed73b5d276 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 19:41:27 -0700 Subject: [PATCH 197/258] Fix crashing when LP objective is a constant. This fix is jank(ish). --- src/Futhark/Solve/GLPK.hs | 10 +++++++--- src/Futhark/Solve/LP.hs | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index b2d340d683..b4b0b4602b 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -48,8 +48,12 @@ glpk lp = do res glpk' :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) -glpk' lp = do - (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp - pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres +glpk' lp + | F.isConstant (F.objective lp) -- FIXME + = + pure $ pure (0, M.fromList $ map (,0) $ S.toList $ F.vars lp) + | otherwise = do + (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp + pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres where opts = mipDefaults {msgLev = MsgAll} diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index f1b7d18939..47804b738b 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -28,6 +28,7 @@ module Futhark.Solve.LP (~<=~), (~>=~), rowEchelonLPE, + isConstant, ) where @@ -109,6 +110,9 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where ) $ Map.toList m +isConstant :: (Ord v) => LSum v a -> Bool +isConstant (LSum m) = Map.keysSet m `S.isSubsetOf` S.singleton Nothing + instance Functor (LSum v) where fmap f (LSum m) = LSum $ fmap f m From 5666c2b20ade746507e5bdbb83e5e333019d37b7 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 19:58:39 -0700 Subject: [PATCH 198/258] Fix tuples/records in rank analysis. --- src/Language/Futhark/TypeChecker/Rank.hs | 94 ++++++++++++++++-------- 1 file changed, 64 insertions(+), 30 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c74fad677c..2f2bbcda6b 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -57,23 +57,63 @@ instance Rank Type where rank (Scalar t) = rank t rank (Array _ shape t) = rank shape ~+~ rank t -class Distribute a where - distribute :: a -> a - -instance Distribute (TypeBase dim u) where - distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ - Arrow - u - Unnamed - mempty - (arrayOf s ta) - (RetType rd $ distribute (arrayOfWithAliases Nonunique s tr)) - distribute t = t - -instance Distribute Ct where - distribute (CtEq r t1 t2) = CtEq r (distribute t1) (distribute t2) - distribute c = c +distribAndSplitArrows :: Ct -> [Ct] +distribAndSplitArrows (CtEq r t1 t2) = + splitArrows $ CtEq r (distribute t1) (distribute t2) + where + distribute :: TypeBase dim as -> TypeBase dim as + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute $ arrayOfWithAliases Nonunique s tr) + distribute t = t + + splitArrows + ( CtEq + reason + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) + ) = + splitArrows (CtEq reason t1a t2a) ++ splitArrows (CtEq reason t1r' t2r') + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness + splitArrows c = [c] +distribAndSplitArrows ct = [ct] + +distribAndSplitCnstrs :: Ct -> [Ct] +distribAndSplitCnstrs ct@(CtEq r t1 t2) = + ct : splitCnstrs (CtEq r (distribute1 t1) (distribute1 t2)) + where + distribute1 :: TypeBase dim as -> TypeBase dim as + distribute1 (Array u s (Record ts1)) = + Scalar $ Record $ fmap (arrayOfWithAliases u s) ts1 + distribute1 t = t + + splitCnstrs (CtEq reason (Scalar (Record ts1)) (Scalar (Record ts2))) = + concat $ zipWith (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems ts1) (M.elems ts2) + splitCnstrs c = [] +distribAndSplitCnstrs ct = [ct] + +distributeOverCnstrs :: Ct -> [Ct] +distributeOverCnstrs ct@(CtEq r t1 t2) = + [ct, CtEq r t1' t2'] + where + -- case (t1', t2') of + -- (Nothing, Nothing) -> [ct] + -- _ -> [ct, CtEq r (fromMaybe t1 t1') (fromMaybe t2 t2')] + + distribute :: TypeBase dim as -> TypeBase dim as + distribute (Array u s (Record ts1)) = + Scalar $ Record $ fmap (distribute . arrayOfWithAliases u s) ts1 + distribute t = t + t1' = distribute t1 + t2' = distribute t2 +distributeOverCnstrs c = [c] data RankState = RankState { rankBinVars :: Map VName VName, @@ -258,18 +298,9 @@ rankAnalysis loc cs tyVars artificial params body = do params' = map ((`map` params) . updAMPat) rank_maps pure $ zip3 cts_tyvars' params' bodys where - cs' = foldMap (splitFuncs . distribute) cs - splitFuncs - ( CtEq - reason - (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) - (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) - ) = - splitFuncs (CtEq reason t1a t2a) ++ splitFuncs (CtEq reason t1r' t2r') - where - t1r' = t1r `setUniqueness` NoUniqueness - t2r' = t2r `setUniqueness` NoUniqueness - splitFuncs c = [c] + cs' = + foldMap distribAndSplitCnstrs $ + foldMap distribAndSplitArrows cs substRankInfo :: (MonadTypeChecker m) => @@ -331,7 +362,7 @@ newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar newTyVar t = do t' <- lift $ newTypeName (baseName t) shape <- rankToShape t - loc <- (locOf . snd . fromJust . (M.!? t)) <$> asks envTyVars + loc <- asks ((locOf . snd . fromJust . (M.!? t)) . envTyVars) modify $ \s -> s { substNewVars = M.insert t t' $ substNewVars s, @@ -399,6 +430,9 @@ instance SubstRanks (TypeBase SComp u) where shape' <- substRanks shape t' <- substRanks $ Scalar t pure $ arrayOfWithAliases u shape' t' + substRanks (Scalar (Record fs)) = do + fs' <- mapM substRanks fs + pure $ Scalar $ Record fs' substRanks t = pure t instance SubstRanks Ct where From 7039dcc5603a1e820ddbcefda8cbc2074457eea2 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 21:05:59 -0700 Subject: [PATCH 199/258] Delete this. --- src/Language/Futhark/TypeChecker/Rank.hs | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 2f2bbcda6b..939b1dfb9f 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -99,22 +99,6 @@ distribAndSplitCnstrs ct@(CtEq r t1 t2) = splitCnstrs c = [] distribAndSplitCnstrs ct = [ct] -distributeOverCnstrs :: Ct -> [Ct] -distributeOverCnstrs ct@(CtEq r t1 t2) = - [ct, CtEq r t1' t2'] - where - -- case (t1', t2') of - -- (Nothing, Nothing) -> [ct] - -- _ -> [ct, CtEq r (fromMaybe t1 t1') (fromMaybe t2 t2')] - - distribute :: TypeBase dim as -> TypeBase dim as - distribute (Array u s (Record ts1)) = - Scalar $ Record $ fmap (distribute . arrayOfWithAliases u s) ts1 - distribute t = t - t1' = distribute t1 - t2' = distribute t2 -distributeOverCnstrs c = [c] - data RankState = RankState { rankBinVars :: Map VName VName, rankCounter :: !Int, From e5b4f10018589e2d47dce1fb5d3e3b918e950140 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 21:45:45 -0700 Subject: [PATCH 200/258] Add sum type support. --- src/Language/Futhark/TypeChecker/Rank.hs | 41 +++++++++++++++++++----- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 939b1dfb9f..42c440dd26 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -92,11 +92,16 @@ distribAndSplitCnstrs ct@(CtEq r t1 t2) = distribute1 :: TypeBase dim as -> TypeBase dim as distribute1 (Array u s (Record ts1)) = Scalar $ Record $ fmap (arrayOfWithAliases u s) ts1 + distribute1 (Array u s (Sum cs)) = + Scalar $ Sum $ (fmap . fmap) (arrayOfWithAliases u s) cs distribute1 t = t + -- FIXME. Should check for key set equality here. splitCnstrs (CtEq reason (Scalar (Record ts1)) (Scalar (Record ts2))) = concat $ zipWith (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems ts1) (M.elems ts2) - splitCnstrs c = [] + splitCnstrs (CtEq reason (Scalar (Sum cs1)) (Scalar (Sum cs2))) = + concat $ concat $ (zipWith . zipWith) (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems cs1) (M.elems cs2) + splitCnstrs _ = [] distribAndSplitCnstrs ct = [ct] data RankState = RankState @@ -276,6 +281,14 @@ rankAnalysis :: rankAnalysis _ [] tyVars artificial params body = pure [(([], artificial, tyVars), params, body)] rankAnalysis loc cs tyVars artificial params body = do + debugTraceM 3 $ + unlines $ + [ "##rankAnalysis", + "cs:", + unlines $ map prettyString cs, + "cs':", + unlines $ map prettyString cs' + ] rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps let bodys = map (`updAM` body) rank_maps @@ -294,10 +307,10 @@ substRankInfo :: Map VName Int -> m ([Ct], M.Map VName Type, TyVars) substRankInfo cs artificial tyVars rankmap = do - ((cs', artificial'), new_cs, new_tyVars) <- + ((cs', artificial', tyVars'), new_cs, new_tyVars) <- runSubstT tyVars rankmap $ - (,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial - pure (cs' <> new_cs, artificial', new_tyVars <> tyVars) + (,,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial <*> traverse substRanks tyVars + pure (cs' <> new_cs, artificial', new_tyVars <> tyVars') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -409,20 +422,32 @@ instance SubstRanks (TypeBase SComp u) where ta' <- substRanks ta tr' <- substRanks tr pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) - substRanks (Scalar (Record fs)) = Scalar . Record <$> traverse substRanks fs + substRanks (Scalar (Record fs)) = + Scalar . Record <$> traverse substRanks fs + substRanks (Scalar (Sum cs)) = + Scalar . Sum <$> (traverse . traverse) substRanks cs substRanks (Array u shape t) = do shape' <- substRanks shape t' <- substRanks $ Scalar t pure $ arrayOfWithAliases u shape' t' - substRanks (Scalar (Record fs)) = do - fs' <- mapM substRanks fs - pure $ Scalar $ Record fs' substRanks t = pure t instance SubstRanks Ct where substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" +instance SubstRanks TyVarInfo where + substRanks tv@TyVarFree {} = pure tv + substRanks tv@TyVarPrim {} = pure tv + substRanks (TyVarRecord loc fs) = + TyVarRecord loc <$> traverse substRanks fs + substRanks (TyVarSum loc cs) = + TyVarSum loc <$> (traverse . traverse) substRanks cs + substRanks tv@TyVarEql {} = pure tv + +instance SubstRanks (Int, TyVarInfo) where + substRanks (lvl, tv) = (lvl,) <$> substRanks tv + updAM :: Map VName Int -> Exp -> Exp updAM rank_map e = case e of From 62ddb1d194a6f0a54c2c7a633b98740c63d9cfd9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Mar 2024 12:47:47 +0100 Subject: [PATCH 201/258] Respect return type annotations. --- src/Language/Futhark/TypeChecker/Terms2.hs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index aceabbbe22..cfa88069ca 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -816,14 +816,14 @@ instance Pretty (Unmatched (Pat StructType)) where checkRetDecl :: Exp -> Maybe (TypeExp (ExpBase NoInfo VName) VName) -> - TermM (Maybe (TypeExp Exp VName)) -checkRetDecl _ Nothing = pure Nothing + TermM (Type, Maybe (TypeExp Exp VName)) +checkRetDecl body Nothing = (,Nothing) <$> expType body checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te body_t <- expType body st' <- asType st ctEq (Reason (locOf body)) body_t st' - pure $ Just te' + pure (second (const NoUniqueness) st', Just te') checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- @@ -1020,10 +1020,9 @@ checkExp (ProjectSection fields NoInfo loc) = do checkExp (Lambda params body retdecl NoInfo loc) = do bindParams [] params $ \params' -> do body' <- checkExp body - body_t <- expType body' + (body_t, retdecl') <- checkRetDecl body' retdecl body_t' <- asStructType body_t - retdecl' <- checkRetDecl body' retdecl let ret = RetType [] $ toRes Nonunique body_t' pure $ Lambda params' body' retdecl' (Info ret) loc -- @@ -1045,8 +1044,7 @@ checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) (tparams', params', retdecl', rettype, e') <- bindParams tparams params $ \params' -> do e' <- checkExp e - e_t <- expType e' - retdecl' <- checkRetDecl e' retdecl + (e_t, retdecl') <- checkRetDecl e' retdecl pure (tparams, params', retdecl', fmap (const Nonunique) e_t, e') params'' <- mapM (traverse asType) params' @@ -1313,7 +1311,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do (params', body', retdecl') <- bindParams tparams params $ \params' -> do body' <- checkExp body - retdecl' <- checkRetDecl body' retdecl + (_, retdecl') <- checkRetDecl body' retdecl pure (params', body', retdecl') cts <- gets termConstraints From fc5f05dc2f33b6beb8934c8cdfd684e5654b7787 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Mar 2024 15:58:23 +0100 Subject: [PATCH 202/258] Lovely code. --- tests/automap/mri-q.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index f53b5df7a6..270e18195a 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -38,4 +38,4 @@ entry main [numK][numX] (phiR: [numK]f32) (phiI: [numK]f32) = let (qr, qi) = main_orig kx ky kz x y z phiR phiI let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI - in and (map2 (==) qr qr_am && map2 (==) qi qi_am) + in and (qr == qr_am && qi == qi_am) From 6f3e32d4ccdb77c84bddfee32f42b743f523835c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Mar 2024 11:03:33 +0100 Subject: [PATCH 203/258] Update type annotation here. --- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 2a648ea8ab..e365cd97fc 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -176,8 +176,9 @@ checkPat' sizes (PatAscription p t loc) maybe_outer_t = do <$> checkPat' sizes p (Ascribed (resToParam st)) <*> pure t' <*> pure loc -checkPat' _ (PatLit l info loc) _ = - pure $ PatLit l info loc +checkPat' _ (PatLit l (Info t) loc) _ = do + t' <- replaceTyVars loc t + pure $ PatLit l (Info t') loc checkPat' sizes (PatConstr n info ps loc) NoneInferred = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps pure $ PatConstr n info ps' loc From bf64e471a8e5cbe36138c8deacb77d9679cfa70d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Mar 2024 11:53:24 +0100 Subject: [PATCH 204/258] Handle equality case too. --- src/Language/Futhark/TypeChecker/Constraints.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 4a7d30ae98..8c6c3ca3ef 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -197,7 +197,9 @@ linkTyVar reason v t = do modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} tyvars' <- case (M.lookup v tyvars, M.lookup t tyvars) of - (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl (TyVarFree _))) -> + (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree {})) -> + pure $ M.insert t (TyVarUnsol lvl info) tyvars + (Just (TyVarUnsol _ info@TyVarPrim {}), Just (TyVarUnsol lvl TyVarEql {})) -> pure $ M.insert t (TyVarUnsol lvl info) tyvars -- TODO: handle more cases. _ -> pure tyvars From 5941f88bf68ea4aa0f2cd15c5729ad116881689b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Mar 2024 17:05:16 +0100 Subject: [PATCH 205/258] Do not impose inferred type on polymorphic functions. This is necessary to properly handle polymorphic higher order functions that are passed functions with existential return sizes. --- .../Futhark/TypeChecker/Terms/Monad.hs | 27 ++++++++++--------- src/Language/Futhark/TypeChecker/Unify.hs | 4 +-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 42c5d53ab2..b8371cbfe8 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -440,31 +440,32 @@ instTyVars loc names orig_t1 orig_t2 = do evalStateT (f orig_t1 orig_t2) mempty --- | Instantiate a type scheme with fresh size variables for its size --- parameters. Replaces type parameters with their known --- instantiations. Returns the names of the fresh size variables and --- the instantiated type. +-- | Instantiate a type scheme with fresh variables for its size and +-- type parameters. Returns the names of the fresh size and type +-- variables and the instantiated type. instTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> StructType -> - TypeBase () NoUniqueness -> TermTypeM ([VName], StructType) -instTypeScheme qn loc tparams scheme_t inferred = do - (names, substs) <- fmap (unzip . catMaybes) . forM tparams $ \tparam -> do +instTypeScheme qn loc tparams scheme_t = do + (names, substs) <- fmap unzip . forM tparams $ \tparam -> do case tparam of - TypeParamType {} -> pure Nothing + TypeParamType l v _ -> do + i <- incCounter + v' <- newID $ mkTypeVarName (baseName v) i + constrain v' . NoConstraint l . mkUsage loc . docText $ + "instantiated type parameter of " <> dquotes (pretty qn) + pure (v', (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v') [])) TypeParamDim v _ -> do i <- incCounter v' <- newID $ mkTypeVarName (baseName v) i constrain v' . Size Nothing . mkUsage loc . docText $ "instantiated size parameter of " <> dquotes (pretty qn) - pure $ Just (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) + pure (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) - let tp_names = map typeParamName $ filter isTypeParam tparams - t' <- instTyVars loc tp_names inferred $ applySubst (`lookup` substs) scheme_t - pure (names, t') + pure (names, applySubst (`lookup` substs) scheme_t) lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) @@ -541,7 +542,7 @@ lookupVar loc qn@(QualName qs name) inst_t = do if null tparams && null qs then pure bound_t else do - (tnames, t) <- instTypeScheme qn loc tparams bound_t $ first (const ()) inst_t + (tnames, t) <- instTypeScheme qn loc tparams bound_t outer_env <- asks termOuterEnv pure $ qualifyTypeVars outer_env tnames qs t Just EqualityF -> diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 4493b02b2d..259f62ed0e 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -230,7 +230,7 @@ prettySource ctx loc RigidCoerce = <+> pretty (locStrRel ctx loc) <> "." prettySource _ _ RigidUnify = - "is an artificial size invented during unification of functions with anonymous sizes." + textwrap "is an artificial size invented during unification of functions with anonymous sizes." prettySource ctx loc (RigidCond t1 t2) = "is unknown due to conditional expression at " <> pretty (locStrRel ctx loc) @@ -514,7 +514,7 @@ unifySizes usage bcs bound nonrigid e1 (Var v2 _ _) not (anyBound bound e1) || (qualLeaf v2 `elem` bound) = linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 unifySizes usage bcs _ _ e1 e2 = do - notes <- (<>) <$> dimNotes usage e2 <*> dimNotes usage e2 + notes <- (<>) <$> dimNotes usage e1 <*> dimNotes usage e2 unifyError usage notes bcs $ "Sizes" <+> dquotes (pretty e1) From 64637d06eadd31ff898eeb428e78a2385d76a8a8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 18 Mar 2024 09:59:22 +0100 Subject: [PATCH 206/258] Fix implicit record fields. --- src/Language/Futhark/TypeChecker/Terms.hs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index b3afad3954..081e9c0702 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -343,6 +343,9 @@ unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp checkExp :: Exp -> TermTypeM Exp +checkExp (Var qn (Info t) loc) = do + t' <- lookupVar loc qn t + pure $ Var qn (Info t') loc checkExp (Literal val loc) = pure $ Literal val loc checkExp (Hole (Info t) loc) = do @@ -363,8 +366,9 @@ checkExp (RecordLit fs loc) = where checkField (RecordFieldExplicit f e rloc) = RecordFieldExplicit f <$> checkExp e <*> pure rloc - checkField (RecordFieldImplicit name (Info t) rloc) = - RecordFieldImplicit name <$> (Info <$> replaceTyVars rloc t) <*> pure rloc + checkField (RecordFieldImplicit name (Info t) rloc) = do + t' <- lookupVar rloc (qualName name) t + pure $ RecordFieldImplicit name (Info t') rloc checkExp (ArrayLit all_es _ loc) = -- Construct the result type and unify all elements with it. We -- only create a type variable for empty arrays; otherwise we use @@ -540,9 +544,6 @@ checkExp (QualParens (modname, modnameloc) e loc) = do ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." -checkExp (Var qn (Info t) loc) = do - t' <- lookupVar loc qn t - pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- checkExp arg pure $ Negate arg' loc From 616a6122a69ab2940817bf48009840c580f00bc1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 18 Mar 2024 11:42:19 +0100 Subject: [PATCH 207/258] Fix type checking of project sections. --- src/Language/Futhark/Prop.hs | 9 +++++++++ src/Language/Futhark/TypeChecker/Terms.hs | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 0507139fdc..4a339ee124 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -57,6 +57,7 @@ module Language.Futhark.Prop foldFunType, typeVars, isAccType, + recordField, -- * Operations on types peelArray, @@ -251,6 +252,14 @@ diet (Array d _ _) = d diet (Scalar (TypeVar d _ _)) = d diet (Scalar (Sum cs)) = foldl max Observe $ foldMap (map diet) cs +-- | Look up this record field if it exists. +recordField :: [Name] -> TypeBase dim u -> Maybe (TypeBase dim u) +recordField [] t = Just t +recordField (f : fs) (Scalar (Record fts)) + | Just ft <- M.lookup f fts = + recordField fs ft +recordField _ _ = Nothing + -- | Convert any type to one that has rank information, no alias -- information, and no embedded names. toStructural :: diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 081e9c0702..75dd2947ea 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -770,6 +770,11 @@ checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do "Operator section with invalid operator of type" <+> pretty ftype checkExp (ProjectSection fields (Info t) loc) = do t' <- replaceTyVars loc t + case t' of + Scalar (Arrow _ _ _ t'' (RetType _ rt)) + | Just ft <- recordField fields t'' -> + unify (mkUsage loc "result of projection") ft $ toStruct rt + _ -> error $ "checkExp ProjectSection: " <> show t' pure $ ProjectSection fields (Info t') loc checkExp (IndexSection slice _ loc) = do slice' <- checkSlice slice From ac0472881efd8a3251a6b859b767337f17e60d70 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 27 Mar 2024 13:31:34 +0100 Subject: [PATCH 208/258] Rework type constraint solving. --- .../Futhark/TypeChecker/Constraints.hs | 291 ++++++++++++++---- 1 file changed, 232 insertions(+), 59 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 8c6c3ca3ef..e6759a9ff7 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -17,6 +17,7 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor +import Data.List qualified as L import Data.Loc import Data.Map qualified as M import Data.Maybe @@ -139,14 +140,14 @@ substTyVar m v = Just (TyVarUnsol {}) -> Nothing Nothing -> Nothing -lookupTyVar :: TyVar -> SolveM (Maybe Type) +lookupTyVar :: TyVar -> SolveM (Int, Either TyVarInfo Type) lookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of Nothing -> error $ "Unknown tyvar: " <> prettyNameString v - Just (TyVarSol _ t) -> pure $ Just t + Just (TyVarSol lvl t) -> pure (lvl, Right t) Just (TyVarLink v') -> f v' - Just (TyVarUnsol {}) -> pure Nothing + Just (TyVarUnsol lvl info) -> pure (lvl, Left info) f orig -- | A solution maps a type variable to its substitution. This @@ -185,25 +186,229 @@ occursCheck reason v tp = do <+> pretty tp <> "." +unifySharedConstructors :: + Reason -> + M.Map Name [Type] -> + M.Map Name [Type] -> + SolveM () +unifySharedConstructors reason cs1 cs2 = + forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> + if length ts1 == length ts2 + then zipWithM (solveEq reason) ts1 ts2 + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructor" + indent 2 (pretty (Sum (M.singleton c ts1))) + "with type of constructor" + indent 2 (pretty (Sum (M.singleton c ts2))) + "because they differ in arity." + +unifySharedFields :: + Reason -> + M.Map Name Type -> + M.Map Name Type -> + SolveM () +unifySharedFields reason fs1 fs2 = + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_f, (ts1, ts2)) -> + solveEq reason ts1 ts2 + +mustSupportEql :: Reason -> Type -> SolveM () +mustSupportEql reason t = pure () + +-- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v lvl t = do occursCheck reason v t - modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} - + v_info <- gets $ M.lookup v . solverTyVars + case (v_info, t) of + (Just (TyVarUnsol _ TyVarFree {}), _) -> + pure () + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + _ + ) -> + if t `elem` map (Scalar . Prim) v_pts + then pure () + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with" + indent 2 (pretty t) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Scalar (Sum cs2) + ) -> + if all (`elem` M.keys cs2) (M.keys cs1) + then unifySharedConstructors reason cs1 cs2 + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Sum cs2)) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + _ + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty t) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Scalar (Record fs2) + ) -> + if all (`elem` M.keys fs2) (M.keys fs1) + then unifySharedFields reason fs1 fs2 + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with record type" + indent 2 (pretty (Record fs2)) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + _ + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty t) + (Just (TyVarUnsol _ (TyVarEql _)), _) -> + mustSupportEql reason t + -- + -- Internal error cases + (Just TyVarSol {}, _) -> + error $ "Type variable already solved: " <> prettyNameString v + (Just TyVarLink {}, _) -> + error $ "Type variable already linked: " <> prettyNameString v + (Nothing, _) -> + error $ "linkTyVar: Nothing v: " <> prettyNameString v + + setInfo v (TyVarSol lvl t) + +setInfo :: TyVar -> TyVarSol -> SolveM () +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v info $ solverTyVars s} + +-- Precondition: 'v' is currently flexible and 't' has no solution. linkTyVar :: Reason -> VName -> VName -> SolveM () linkTyVar reason v t = do occursCheck reason v $ Scalar $ TypeVar NoUniqueness (qualName t) [] - tyvars <- gets solverTyVars - modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} - tyvars' <- - case (M.lookup v tyvars, M.lookup t tyvars) of - (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree {})) -> - pure $ M.insert t (TyVarUnsol lvl info) tyvars - (Just (TyVarUnsol _ info@TyVarPrim {}), Just (TyVarUnsol lvl TyVarEql {})) -> - pure $ M.insert t (TyVarUnsol lvl info) tyvars - -- TODO: handle more cases. - _ -> pure tyvars - modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) tyvars'} + v_info <- gets $ M.lookup v . solverTyVars + (lvl, t') <- lookupTyVar t + case (v_info, t') of + -- When either is completely unconstrained. + (Just (TyVarUnsol _ TyVarFree {}), _) -> + pure () + ( Just (TyVarUnsol _ info), + Left (TyVarFree {}) + ) -> + setInfo t (TyVarUnsol lvl info) + -- + -- TyVarPrim cases + ( Just (TyVarUnsol _ info@TyVarPrim {}), + Left TyVarEql {} + ) -> + setInfo t (TyVarUnsol lvl info) + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + Left (TyVarPrim t_loc t_pts) + ) -> + let pts = L.intersect v_pts t_pts + in if null pts + then + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be one of" + indent 2 (pretty t_pts) + else setInfo t (TyVarUnsol lvl (TyVarPrim t_loc pts)) + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + Left TyVarRecord {} + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be record." + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + Left TyVarSum {} + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be sum." + -- + -- TyVarSum cases + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Left (TyVarSum loc cs2) + ) -> do + unifySharedConstructors reason cs1 cs2 + let cs3 = cs1 <> cs2 + setInfo t (TyVarUnsol lvl (TyVarSum loc cs3)) + ( Just (TyVarUnsol _ TyVarSum {}), + Left (TyVarPrim _ pts) + ) -> + throwError . TypeError (locOf reason) mempty $ + "A sum type cannot be one of" + indent 2 (pretty pts) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Left (TyVarRecord _ fs) + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Scalar (Record fs))) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Left (TyVarEql _) + ) -> + mapM_ (mapM_ (mustSupportEql reason)) cs1 + -- + -- TyVarRecord cases + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Left (TyVarRecord loc fs2) + ) -> do + unifySharedFields reason fs1 fs2 + let fs3 = fs1 <> fs2 + setInfo t (TyVarUnsol lvl (TyVarRecord loc fs3)) + ( Just (TyVarUnsol _ TyVarRecord {}), + Left (TyVarPrim _ pts) + ) -> + throwError . TypeError (locOf reason) mempty $ + "A record type cannot be one of" + indent 2 (pretty pts) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Left (TyVarSum _ cs) + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify record type" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty (Scalar (Sum cs))) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Left (TyVarEql _) + ) -> + mapM_ (mustSupportEql reason) fs1 + -- + -- TyVarEql cases + (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarPrim {}) -> + pure () + (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarEql {}) -> + pure () + (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarRecord _ fs)) -> + mustSupportEql reason $ Scalar $ Record fs + (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarSum _ cs)) -> + mustSupportEql reason $ Scalar $ Sum cs + -- + -- Internal error cases + (Just TyVarSol {}, _) -> + error $ "Type variable already solved: " <> prettyNameString v + (Just TyVarLink {}, _) -> + error $ "Type variable already linked: " <> prettyNameString v + (Nothing, _) -> + error $ "linkTyVar: Nothing v: " <> prettyNameString v + (_, Right t'') -> + error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' + + -- Finally insert the actual link. + setInfo v (TyVarLink t) -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -293,60 +498,28 @@ solveCt ct = CtAM {} -> pure () -- Good vibes only. solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () -solveTyVar (_, (_, TyVarFree {})) = pure () -solveTyVar (tv, (_, TyVarPrim loc pts)) = do - tv_t <- lookupTyVar tv - case tv_t of - Nothing -> pure () - Just t' - | t' `elem` map (Scalar . Prim) pts -> pure () - | otherwise -> - throwError . TypeError loc mempty $ - "Type must be one of" - indent 2 (pretty pts) - "but inferred to be" - indent 2 (pretty t') solveTyVar (tv, (_, TyVarRecord loc fs1)) = do - tv_t <- lookupTyVar tv + (_, tv_t) <- lookupTyVar tv case tv_t of - Nothing -> + Left _ -> throwError . TypeError loc mempty $ - "Type is ambiguous." + "Type" + <+> prettyName tv + <+> "is ambiguous." "Must be a record with fields" indent 2 (pretty (Scalar (Record fs1))) - Just (Scalar (Record fs2)) - | all (`M.member` fs2) (M.keys fs1) -> - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> - solveCt $ CtEq (Reason loc) t1 t2 - Just tv_t' -> - throwError . TypeError loc mempty $ - "Type must be record with fields" - indent 2 (pretty (Scalar (Record fs1))) - "but inferred to be" - indent 2 (pretty tv_t') + Right _ -> + pure () solveTyVar (tv, (_, TyVarSum loc cs1)) = do - tv_t <- lookupTyVar tv + (_, tv_t) <- lookupTyVar tv case tv_t of - Nothing -> + Left _ -> throwError . TypeError loc mempty $ "Type is ambiguous." "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) - Just (Scalar (Sum cs2)) - | all (`M.member` cs2) (M.keys cs1), - cs3 <- M.toList $ M.intersectionWith (,) cs1 cs2, - all (sameLength . snd) cs3 -> - forM_ cs3 $ \(_k, (t1s, t2s)) -> - mapM_ solveCt $ zipWith (CtEq (Reason loc)) t1s t2s - Just tv_t' -> - throwError . TypeError loc mempty $ - "Type must be sum type with constructors" - indent 2 (pretty (Scalar (Sum cs1))) - "but inferred to be" - indent 2 (pretty tv_t') - where - sameLength (x, y) = length x == length y -solveTyVar (_, (_, TyVarEql _)) = + Right _ -> pure () +solveTyVar (_, _) = pure () solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) From 2bdca0efcde1e906e86857d10619082bc96b99e3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 Mar 2024 14:15:21 +0100 Subject: [PATCH 209/258] Substitute dependent sizes when expanding automaps. --- src/Futhark/Internalise/Defunctionalise.hs | 6 +++--- src/Futhark/Internalise/FullNormalise.hs | 10 ++++++++-- src/Language/Futhark/Interpreter.hs | 2 +- src/Language/Futhark/Prop.hs | 17 ++++++++++------- src/Language/Futhark/TypeChecker.hs | 2 +- src/Language/Futhark/TypeChecker/Consumption.hs | 2 +- src/Language/Futhark/TypeChecker/Types.hs | 2 +- 7 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 82cc845d69..783c374409 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -851,7 +851,7 @@ unRetType (RetType ext t) = do defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal) defuncApplyFunction e@(Var qn (Info t) loc) num_args = do - let (argtypes, rettype) = unfoldFunType t + let (argtypes, rettype) = first (map snd) $ unfoldFunType t sv <- lookupVar (toStruct t) (qualLeaf qn) case sv of @@ -1001,8 +1001,8 @@ defuncApply f args appres loc = do (argtypes, _) = unfoldFunType $ typeOf f fmap (first $ updateReturn appres) $ foldM (defuncApplyArg fname) (f', f_sv) $ - NE.zip args $ - NE.tails argtypes + NE.zip args . NE.tails . map snd $ + argtypes where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 789e0d3c85..2fe87f6244 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -499,7 +499,7 @@ transformProg = mapM transformValBind -- | Expands 'AutoMap' annotations into explicit @map@s and @replicates@. expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp -expandAMAnnotations e = do +expandAMAnnotations e = case e of (AppExp (Apply f args loc) (Info res)) -> do let ((exts, ams), arg_es) = first unzip $ unzip $ map (first unInfo) $ NE.toList args @@ -511,7 +511,9 @@ expandAMAnnotations e = do case unfoldFunTypeWithRet $ typeOf f' of Nothing -> error "Function type expected." Just (ptypes, f_ret) -> - foldFunType (drop (length args') ptypes) f_ret + let parsubsts = mapMaybe parSub $ zip ptypes args' + in applySubst (`lookup` parsubsts) $ + foldFunType (drop (length args') $ map snd ptypes) f_ret pure $ mkApply f' (zip3 exts (repeat mempty) args') $ res {appResType = rettype} @@ -531,6 +533,10 @@ expandAMAnnotations e = do (Info res {appResType = stripArray (shapeRank $ autoFrame yam) (appResType res)}) _ -> astMap identityMapper {mapOnExp = expandAMAnnotations} e where + parSub ((Named v, Scalar (Prim (Signed Int64))), arg) = + Just (v, ExpSubst arg) + parSub _ = Nothing + setNewType e t = astMap identityMapper {mapOnStructType = const $ pure t} e funDiets :: TypeBase dim as -> [Diet] diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 506a8b715d..2ee02502c9 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -2031,7 +2031,7 @@ checkEntryArgs entry args entry_t "Got input of types" indent 2 (stack (map pretty args_ts)) where - (param_ts, _) = unfoldFunType entry_t + param_ts = map snd $ fst $ unfoldFunType entry_t args_ts = map (valueStructType . valueType) args expected | null param_ts = diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index aa284e0122..6cff2d91a4 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -526,19 +526,22 @@ foldFunType ps ret = -- | Extract the parameter types and return type from a type. -- If the type is not an arrow type, the list of parameter types is empty. -unfoldFunType :: TypeBase dim as -> ([TypeBase dim Diet], TypeBase dim NoUniqueness) -unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) = +unfoldFunType :: TypeBase dim as -> ([(PName, TypeBase dim Diet)], TypeBase dim NoUniqueness) +unfoldFunType (Scalar (Arrow _ p d t1 (RetType _ t2))) = let (ps, r) = unfoldFunType t2 - in (second (const d) t1 : ps, r) + in ((p, second (const d) t1) : ps, r) unfoldFunType t = ([], toStruct t) -- | Extract the parameter types and 'RetTypeBase' from a function type. -- If the type is not an arrow type, returns 'Nothing'. -unfoldFunTypeWithRet :: TypeBase dim as -> Maybe ([TypeBase dim Diet], RetTypeBase dim Uniqueness) -unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 (RetType _ t2@(Scalar Arrow {})))) = do +unfoldFunTypeWithRet :: + TypeBase dim as -> + Maybe ([(PName, TypeBase dim Diet)], RetTypeBase dim Uniqueness) +unfoldFunTypeWithRet (Scalar (Arrow _ p d t1 (RetType _ t2@(Scalar Arrow {})))) = do (ps, r) <- unfoldFunTypeWithRet t2 - pure (second (const d) t1 : ps, r) -unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 r@RetType {})) = Just ([second (const d) t1], r) + pure ((p, second (const d) t1) : ps, r) +unfoldFunTypeWithRet (Scalar (Arrow _ p d t1 r@RetType {})) = + Just ([(p, second (const d) t1)], r) unfoldFunTypeWithRet _ = Nothing -- | The type scheme of a value binding, comprising the type diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index a69722ee2b..f87e280330 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -690,7 +690,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype where (RetType _ rettype_t) = rettype (rettype_params, rettype') = unfoldFunType rettype_t - param_ts = map patternType params ++ rettype_params + param_ts = map patternType params ++ map snd rettype_params checkValBind :: ValBindBase NoInfo Name -> TypeM (Env, ValBind) checkValBind vb = do diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index 8c92e54d20..2f194f9c85 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -810,7 +810,7 @@ checkExp (AppExp (LetFun fname (typarams, params, te, Info (RetType ext ret), fu -- checkExp (AppExp (BinOp (op, oploc) opt (x, xp) (y, yp) loc) appres) = do op_als <- observeVar (locOf oploc) (qualLeaf op) (unInfo opt) - let at1 : at2 : _ = fst $ unfoldFunType op_als + let (_, at1) : (_, at2) : _ = fst $ unfoldFunType op_als (x', x_als) <- checkArg [] at1 mempty x (y', y_als) <- checkArg [(x', x_als)] at2 mempty y res_als <- checkFuncall loc (Just op) op_als [(mempty, x_als), (mempty, y_als)] diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 02126dbfd8..3070e273ca 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -60,7 +60,7 @@ mustBeExplicitInBinding :: StructType -> S.Set VName mustBeExplicitInBinding bind_t = let (ts, ret) = unfoldFunType bind_t alsoRet = M.unionWith (&&) $ M.fromList $ map (,True) (S.toList (fvVars (freeInType ret))) - in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty $ map toStruct ts + in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty $ map (toStruct . snd) ts where onType uses t = uses <> mustBeExplicitAux t -- Left-biased union. From 06d7232fb10884f1936e85d3904d04573ba3fd80 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 Mar 2024 14:31:43 +0100 Subject: [PATCH 210/258] Workaround for wrong return type handling. --- src/Futhark/Internalise/FullNormalise.hs | 36 +++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 2fe87f6244..3afecad4bb 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -501,22 +501,26 @@ transformProg = mapM transformValBind expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp expandAMAnnotations e = case e of - (AppExp (Apply f args loc) (Info res)) -> do - let ((exts, ams), arg_es) = first unzip $ unzip $ map (first unInfo) $ NE.toList args - f' <- expandAMAnnotations f - arg_es' <- mapM expandAMAnnotations arg_es - let diets = funDiets $ typeOf f - withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do - let rettype = - case unfoldFunTypeWithRet $ typeOf f' of - Nothing -> error "Function type expected." - Just (ptypes, f_ret) -> - let parsubsts = mapMaybe parSub $ zip ptypes args' - in applySubst (`lookup` parsubsts) $ - foldFunType (drop (length args') $ map snd ptypes) f_ret - pure $ - mkApply f' (zip3 exts (repeat mempty) args') $ - res {appResType = rettype} + (AppExp (Apply f args loc) (Info res)) + | ((exts, ams), arg_es) <- + first unzip $ unzip $ map (first unInfo) $ NE.toList args, + any (/= mempty) ams -> do + f' <- expandAMAnnotations f + arg_es' <- mapM expandAMAnnotations arg_es + let diets = funDiets $ typeOf f + withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do + let rettype = + case unfoldFunTypeWithRet $ typeOf f' of + Nothing -> error "Function type expected." + Just (ptypes, f_ret) -> + let parsubsts = mapMaybe parSub $ zip ptypes args' + in applySubst (`lookup` parsubsts) $ + foldFunType (drop (length args') $ map snd ptypes) f_ret + when (appResExt res /= []) $ + error "expandAMAnnotations: cannot handle existential yet." + pure $ + mkApply f' (zip3 exts (repeat mempty) args') $ + res {appResType = rettype} (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x y' <- expandAMAnnotations y From a1628f3c120987e781da824aa2ddafd20c560701 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 Mar 2024 14:31:50 +0100 Subject: [PATCH 211/258] Handle special case. --- src/Language/Futhark/Interpreter/Values.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index b3fb36ac8c..de3bd2468d 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE LambdaCase #-} + -- | The value representation used in the interpreter. -- -- Kept simple and free of unnecessary operational details (in @@ -210,7 +212,9 @@ toArray' rowshape vs = ValueArray shape (listArray (0, length vs - 1) vs) -- | Produce multidimensional array from a flat list of values. toArrayR :: [Int64] -> ValueShape -> [Value m] -> Value m -toArrayR [] _ = error "toArrayR: empty shape" +toArrayR [] _ = \case + [v] -> v + _ -> error "toArrayR: empty shape" toArrayR [_] elemshape = toArray' elemshape toArrayR (n : ns) elemshape = toArray (foldr ShapeDim elemshape (n : ns)) From 6ce8ed3258e42618784f97d8f11272cb5fa07ef0 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 2 Apr 2024 12:34:05 +0200 Subject: [PATCH 212/258] Must expand here. --- src/Language/Futhark/Interpreter.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 2ee02502c9..06acadc6e6 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -900,7 +900,7 @@ evalAppExp env (Apply f args loc) = do f' <- eval env f foldM apply' f' args' where - ft = typeOf f + ft = expandType env $ typeOf f apply' f' (v', am') = applyAM loc env (f', ft) am' v' evalArg' (Info (ext, am), x) = evalArg env x ext am evalAppExp env (Index e is loc) = do From 05311de15be790bccd10fea029dc35f82d501f9a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 15:11:40 +0200 Subject: [PATCH 213/258] Strangle warnings. --- src/Futhark/Internalise/Exps.hs | 4 +- src/Futhark/Internalise/FullNormalise.hs | 54 +++---- src/Futhark/Solve/BranchAndBound.hs | 3 +- src/Futhark/Solve/GLPK.hs | 5 +- src/Futhark/Solve/LP.hs | 129 +++++++---------- src/Futhark/Solve/Matrix.hs | 18 +-- src/Futhark/Solve/Simplex.hs | 135 +++++++++--------- .../Futhark/TypeChecker/Constraints.hs | 2 +- src/Language/Futhark/TypeChecker/Terms.hs | 4 +- .../Futhark/TypeChecker/Terms/Monad.hs | 3 - src/Language/Futhark/TypeChecker/Terms2.hs | 20 +-- 11 files changed, 153 insertions(+), 224 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 1002b1bfa8..1669ecd5e7 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1583,7 +1583,7 @@ isIntrinsicFunction :: [E.Exp] -> SrcLoc -> Maybe (String -> InternaliseM [SubExp]) -isIntrinsicFunction qname args loc = do +isIntrinsicFunction qname all_args loc = do guard $ baseTag (qualLeaf qname) <= maxIntrinsicTag let handlers = [ handleSign, @@ -1593,7 +1593,7 @@ isIntrinsicFunction qname args loc = do handleAD, handleRest ] - msum [h args $ baseString $ qualLeaf qname | h <- handlers] + msum [h all_args $ baseString $ qualLeaf qname | h <- handlers] where handleSign [x] "sign_i8" = Just $ toSigned I.Int8 x handleSign [x] "sign_i16" = Just $ toSigned I.Int16 x diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 97f8c58513..06ddb255f0 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -23,17 +23,14 @@ module Futhark.Internalise.FullNormalise (transformProg) where import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor -import Data.Functor.Identity import Data.List (zip4) import Data.List.NonEmpty qualified as NE import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T -import Debug.Trace import Futhark.MonadFreshNames import Futhark.Util.Pretty import Language.Futhark -import Language.Futhark.Pretty import Language.Futhark.Primitive (intValue) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types @@ -502,14 +499,14 @@ transformProg = mapM transformValBind expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp expandAMAnnotations e = case e of - (AppExp (Apply f args loc) (Info res)) + (AppExp (Apply f args _) (Info res)) | ((exts, ams), arg_es) <- first unzip $ unzip $ map (first unInfo) $ NE.toList args, any (/= mempty) ams -> do f' <- expandAMAnnotations f arg_es' <- mapM expandAMAnnotations arg_es let diets = funDiets $ typeOf f - withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do + withMapNest (zip4 exts ams arg_es' diets) $ \args' -> do let rettype = case unfoldFunTypeWithRet $ typeOf f' of Nothing -> error "Function type expected." @@ -525,7 +522,7 @@ expandAMAnnotations e = (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x y' <- expandAMAnnotations y - withMapNest loc [(xext, xam, x', Observe), (yext, yam, y', Observe)] $ \[x'', y''] -> + withMapNest [(xext, xam, x', Observe), (yext, yam, y', Observe)] $ \[x'', y''] -> pure $ AppExp ( BinOp @@ -542,31 +539,13 @@ expandAMAnnotations e = Just (v, ExpSubst arg) parSub _ = Nothing - setNewType e t = astMap identityMapper {mapOnStructType = const $ pure t} e - funDiets :: TypeBase dim as -> [Diet] funDiets (Scalar (Arrow _ _ d _ (RetType _ t2))) = d : funDiets t2 funDiets _ = [] - dropDims :: Int -> TypeBase dim as -> TypeBase dim as - dropDims n (Scalar (Arrow u p diet t1 (RetType ds t2))) = - Scalar (Arrow u p diet (stripArray n t1) (RetType ds (dropDims n t2))) - dropDims n t = stripArray n t - - innerFType :: TypeBase dim as -> [AutoMap] -> TypeBase dim as - innerFType (Scalar (Arrow u p diet t1 (RetType ds t2))) ams = - Scalar $ Arrow u p diet t1 $ RetType ds $ innerFType' t2 ams - where - innerFType' t [] = t - innerFType' (Scalar (Arrow u p diet t1 (RetType ds t2))) (am : ams) = - Scalar $ Arrow u p diet (dropDims (shapeRank (autoMap am)) t1) $ RetType ds $ innerFType' t2 ams - innerFType' t [am] = dropDims (shapeRank (autoMap am)) t - innerFType' _ _ = error "" - innerFType _ _ = error "" - type Level = Int -data AutoMapArg = AutoMapArg +newtype AutoMapArg = AutoMapArg { amArg :: Exp } deriving (Show) @@ -582,13 +561,12 @@ data AutoMapParam = AutoMapParam withMapNest :: forall m. (MonadFreshNames m) => - SrcLoc -> [(Maybe VName, AutoMap, Exp, Diet)] -> ([Exp] -> m Exp) -> m Exp -withMapNest loc args f = do +withMapNest nest_args f = do (param_map, arg_map) <- - bimap combineMaps combineMaps . unzip <$> mapM buildArgMap args + bimap combineMaps combineMaps . unzip <$> mapM buildArgMap nest_args buildMapNest param_map arg_map $ maximum $ M.keys arg_map where combineMaps :: (Ord k) => [M.Map k v] -> M.Map k [v] @@ -609,17 +587,17 @@ withMapNest loc args f = do args = map amArg $ arg_map M.! l body <- buildMapNest param_map arg_map (l - 1) pure $ - mkMap map_dim params body args $ + mkMap params body args $ RetType [] $ arrayOfWithAliases Nonunique (Shape [map_dim]) (typeOf body) buildArgMap :: (Maybe VName, AutoMap, Exp, Diet) -> m (M.Map Level AutoMapParam, M.Map Level AutoMapArg) - buildArgMap (ext, am, arg, diet) = - foldM (mkArgsAndParams arg) mempty $ reverse [0 .. trueLevel am] + buildArgMap (_ext, am, arg, arg_diet) = + foldM mkArgsAndParams mempty $ reverse [0 .. trueLevel am] where - mkArgsAndParams arg (p_map, a_map) l + mkArgsAndParams (p_map, a_map) l | l == 0 = do let arg' = maybe arg (paramToExp . amParam) (p_map M.!? 1) rarg <- mkReplicateShape (autoRep am `shapePrefix` autoFrame am) arg' @@ -628,7 +606,7 @@ withMapNest loc args f = do p <- mkAMParam (typeOf arg) l let d = outerDim am l pure - ( M.insert l (AutoMapParam p d diet) p_map, + ( M.insert l (AutoMapParam p d arg_diet) p_map, M.insert l (AutoMapArg arg) a_map ) | l < trueLevel am && l > 0 = do @@ -639,7 +617,7 @@ withMapNest loc args f = do amParam $ p_map M.! (l + 1) pure - ( M.insert l (AutoMapParam p d diet) p_map, + ( M.insert l (AutoMapParam p d arg_diet) p_map, M.insert l (AutoMapArg arg') a_map ) | otherwise = error "Impossible." @@ -672,13 +650,13 @@ mkReplicate :: (MonadFreshNames m) => Exp -> Exp -> m Exp mkReplicate dim e = do x <- mkParam "x" (Scalar $ Prim $ Unsigned Int64) pure $ - mkMap dim [(Observe, x)] e [xs] $ + mkMap [(Observe, x)] e [xs] $ RetType mempty (arrayOfWithAliases Unique (Shape [dim]) (typeOf e)) where xs = AppExp ( Range - (Literal (UnsignedValue $ intValue Int64 0) mempty) + (Literal (UnsignedValue $ intValue Int64 (0 :: Int)) mempty) Nothing (UpToExclusive dim) mempty @@ -686,8 +664,8 @@ mkReplicate dim e = do ( Info $ AppRes (arrayOf (Shape [dim]) (Scalar $ Prim $ Unsigned Int64)) [] ) -mkMap :: Exp -> [(Diet, Pat ParamType)] -> Exp -> [Exp] -> ResRetType -> Exp -mkMap dim params body arrs rettype = +mkMap :: [(Diet, Pat ParamType)] -> Exp -> [Exp] -> ResRetType -> Exp +mkMap params body arrs rettype = mkApply mapN args (AppRes (toStruct $ retType rettype) []) where args = map (Nothing,mempty,) $ lambda : arrs diff --git a/src/Futhark/Solve/BranchAndBound.hs b/src/Futhark/Solve/BranchAndBound.hs index 846ae4a59a..258757113b 100644 --- a/src/Futhark/Solve/BranchAndBound.hs +++ b/src/Futhark/Solve/BranchAndBound.hs @@ -5,7 +5,6 @@ import Data.Maybe import Data.Set qualified as S import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.LP (LP (..)) import Futhark.Solve.Matrix import Futhark.Solve.Simplex @@ -54,7 +53,7 @@ branchAndBound prob@(LP _ a d) = (zopt,) <$> mopt -- TODO: use isInt x = x == round x -- requires a better 'rowEchelon' implementation for matrices - isInt x = (abs (fromIntegral (round x) - x)) <= 10 ^^ (-10) + isInt x = abs (fromIntegral (round x :: Int) - x) <= 10 ^^ ((-10) :: Int) mkProblem = M.foldrWithKey ( \idx bound acc -> addBound acc idx bound diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index b4b0b4602b..5c8f40fcd8 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -1,6 +1,7 @@ module Futhark.Solve.GLPK (glpk) where import Control.Monad +import Data.Bifunctor import Data.LinearProgram import Data.Map qualified as M import Data.Maybe @@ -8,7 +9,7 @@ import Data.Set qualified as S import Futhark.Solve.LP qualified as F import System.IO.Silently -linearProgToGLPK :: (Ord v, Eq a, Num a) => F.LinearProg v a -> (LP v a) +linearProgToGLPK :: (Ord v, Num a) => F.LinearProg v a -> LP v a linearProgToGLPK prog = LP { direction = cOptType $ F.optType prog, @@ -54,6 +55,6 @@ glpk' lp pure $ pure (0, M.fromList $ map (,0) $ S.toList $ F.vars lp) | otherwise = do (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp - pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres + pure $ bimap truncate (fmap truncate) <$> mres where opts = mipDefaults {msgLev = MsgAll} diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 47804b738b..5011ece9fb 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -32,23 +32,18 @@ module Futhark.Solve.LP ) where -import Control.Monad.LPMonad -import Data.Char (isAscii) -import Data.List qualified as L import Data.Map (Map) -import Data.Map qualified as Map +import Data.Map qualified as M import Data.Maybe import Data.Set (Set) import Data.Set qualified as S import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.Matrix (Matrix (..)) -import Futhark.Solve.Matrix qualified as M +import Futhark.Solve.Matrix qualified as Matrix import Futhark.Util.Pretty import Language.Futhark.Pretty import Prelude hiding (max, min, or) -import Prelude qualified -- | A linear program. 'LP c a d' represents the program -- @@ -79,24 +74,25 @@ data LPE a = LPE } deriving (Eq, Show) -rowEchelonLPE :: (Show a, Unbox a, Fractional a, Ord a) => LPE a -> LPE a +rowEchelonLPE :: (Unbox a, Fractional a, Ord a) => LPE a -> LPE a rowEchelonLPE (LPE c a d) = - LPE c (M.sliceCols (V.generate (ncols a) id) ad) (M.getCol (ncols a) ad) + LPE c (Matrix.sliceCols (V.generate (ncols a) id) ad) (Matrix.getCol (ncols a) ad) where ad = - M.filterRows (V.any (Prelude./= 0)) $ - (M.rowEchelon $ a M.<|> M.fromColVector d) + Matrix.filterRows + (V.any (Prelude./= 0)) + (Matrix.rowEchelon $ a Matrix.<|> Matrix.fromColVector d) -- | Converts an 'LP' into an equivalent 'LPE' by introducing slack -- variables. -convert :: (Show a, Num a, Unbox a) => LP a -> LPE a +convert :: (Num a, Unbox a) => LP a -> LPE a convert (LP c a d) = LPE c' a' d where - a' = a M.<|> M.diagonal (V.replicate (M.nrows a) 1) - c' = c V.++ V.replicate (M.nrows a) 0 + a' = a Matrix.<|> Matrix.diagonal (V.replicate (Matrix.nrows a) 1) + c' = c V.++ V.replicate (Matrix.nrows a) 0 -- | Linear sum of variables. -newtype LSum v a = LSum {lsum :: (Map (Maybe v) a)} +newtype LSum v a = LSum {lsum :: Map (Maybe v) a} deriving (Show, Eq) instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where @@ -108,10 +104,10 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where Nothing -> pretty a Just k' -> (if a == 1 then mempty else pretty a <> "*") <> prettyName k' ) - $ Map.toList m + $ M.toList m isConstant :: (Ord v) => LSum v a -> Bool -isConstant (LSum m) = Map.keysSet m `S.isSubsetOf` S.singleton Nothing +isConstant (LSum m) = M.keysSet m `S.isSubsetOf` S.singleton Nothing instance Functor (LSum v) where fmap f (LSum m) = LSum $ fmap f m @@ -120,7 +116,7 @@ class Vars a v where vars :: a -> Set v instance (Ord v) => Vars (LSum v a) v where - vars = S.fromList . catMaybes . Map.keys . lsum + vars = S.fromList . catMaybes . M.keys . lsum -- | Type of constraint data CType = Equal | LessEq @@ -159,7 +155,7 @@ data LinearProg v a = LinearProg instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where pretty (LinearProg opt obj cs) = - vcat $ + vcat [ pretty opt, indent 2 $ pretty obj, "subject to", @@ -172,10 +168,10 @@ instance (Ord v) => Vars (LinearProg v a) v where <> foldMap vars (constraints lp) bigM :: (Num a) => a -bigM = 2 ^ 10 +bigM = 2 ^ (10 :: Int) -- max{x, y} = z -max :: (Eq a, Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] +max :: (Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] max b x y z = [ z ~>=~ x, z ~>=~ y, @@ -184,7 +180,7 @@ max b x y z = ] -- min{x, y} = z -min :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +min :: (Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] min b x y z = [ var z ~<=~ var x, var z ~<=~ var y, @@ -192,7 +188,7 @@ min b x y z = var z ~>=~ var y ~-~ bigM ~*~ var b ] -oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] +oneIsZero :: (Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = mkC b1 x1 <> mkC b2 x2 @@ -202,7 +198,7 @@ oneIsZero (b1, x1) (b2, x2) = [ var x ~<=~ bigM ~*~ var b ] -or :: (Eq a, Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] +or :: (Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] or b1 b2 c1 c2 = mkC b1 c1 <> mkC b2 c2 @@ -216,94 +212,89 @@ or b1 b2 c1 c2 = [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b) ] -bin :: (Num a, Ord v) => v -> Constraint v a +bin :: (Num a) => v -> Constraint v a bin v = Constraint LessEq (var v) (constant 1) -(~==~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +(~==~) :: LSum v a -> LSum v a -> Constraint v a l ~==~ r = Constraint Equal l r infix 4 ~==~ -(~<=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +(~<=~) :: LSum v a -> LSum v a -> Constraint v a l ~<=~ r = Constraint LessEq l r infix 4 ~<=~ -(~>=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +(~>=~) :: (Num a) => LSum v a -> LSum v a -> Constraint v a l ~>=~ r = Constraint LessEq (neg l) (neg r) infix 4 ~>=~ normalize :: (Eq a, Num a) => LSum v a -> LSum v a -normalize = LSum . Map.filter (/= 0) . lsum +normalize = LSum . M.filter (/= 0) . lsum var :: (Num a) => v -> LSum v a -var v = LSum $ Map.singleton (Just v) (fromInteger 1) +var v = LSum $ M.singleton (Just v) 1 constant :: a -> LSum v a -constant = LSum . Map.singleton Nothing +constant = LSum . M.singleton Nothing cval :: (Num a, Ord v) => LSum v a -> a cval = (! Nothing) -(~+~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a --- (LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y -(LSum x) ~+~ (LSum y) = LSum $ Map.unionWith (+) x y +(~+~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +(LSum x) ~+~ (LSum y) = LSum $ M.unionWith (+) x y infixl 6 ~+~ -(~-~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a -x ~-~ y = x ~+~ (neg y) +(~-~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +x ~-~ y = x ~+~ neg y infixl 6 ~-~ -(~*~) :: (Eq a, Num a, Ord v) => a -> LSum v a -> LSum v a --- a ~*~ s = normalize $ fmap (a *) s +(~*~) :: (Num a) => a -> LSum v a -> LSum v a a ~*~ s = fmap (a *) s infixl 7 ~*~ (!) :: (Num a, Ord v) => LSum v a -> Maybe v -> a -(LSum m) ! v = - case m Map.!? v of - Nothing -> 0 - Just a -> a +(LSum m) ! v = fromMaybe 0 (m M.!? v) -neg :: (Num a, Ord v) => LSum v a -> LSum v a +neg :: (Num a) => LSum v a -> LSum v a neg (LSum x) = LSum $ fmap negate x -- | Converts a linear program given with a list of constraints -- into the standard form. linearProgToLP :: forall v a. - (Unbox a, Num a, Ord v, Eq a) => + (Unbox a, Num a, Ord v) => LinearProg v a -> (LP a, Map Int v) linearProgToLP (LinearProg otype obj cs) = - (LP c a d, idxMap) + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LP c a d, idxMap) where cs' = foldMap (convertEqCType . splitConstraint) cs idxMap = - Map.fromList $ + M.fromList $ zip [0 ..] $ catMaybes $ - Map.keys $ + M.keys $ mconcat $ map (lsum . fst) cs' - mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) - c = mkRow $ convertObj otype obj - a = M.fromVectors $ map (mkRow . fst) cs' - d = V.fromList $ map snd cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) + + convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] + convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] + convertEqCType (LessEq, s, a) = [(s, a)] splitConstraint :: Constraint v a -> (CType, LSum v a, a) splitConstraint (Constraint ctype l r) = let c = negate $ cval (l ~-~ r) in (ctype, l ~-~ r ~-~ constant c, c) - convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] - convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] - convertEqCType (LessEq, s, a) = [(s, a)] - convertObj :: OptType -> LSum v a -> LSum v a convertObj Maximize s = s convertObj Minimize s = neg s @@ -312,24 +303,24 @@ linearProgToLP (LinearProg otype obj cs) = -- into the equational form. Assumes no <= constraints. linearProgToLPE :: forall v a. - (Unbox a, Num a, Ord v, Eq a) => + (Unbox a, Num a, Ord v) => LinearProg v a -> (LPE a, Map Int v) linearProgToLPE (LinearProg otype obj cs) = - (LPE c a d, idxMap) + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LPE c a d, idxMap) where cs' = map (checkOnlyEqType . splitConstraint) cs idxMap = - Map.fromList $ + M.fromList $ zip [0 ..] $ catMaybes $ - Map.keys $ + M.keys $ mconcat $ map (lsum . fst) cs' - mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) - c = mkRow $ convertObj otype obj - a = M.fromVectors $ map (mkRow . fst) cs' - d = V.fromList $ map snd cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) splitConstraint :: Constraint v a -> (CType, LSum v a, a) splitConstraint (Constraint ctype l r) = @@ -343,15 +334,3 @@ linearProgToLPE (LinearProg otype obj cs) = convertObj :: OptType -> LSum v a -> LSum v a convertObj Maximize s = s convertObj Minimize s = neg s - -test1 :: LPE Double -test1 = - LPE - { pc = V.fromList [5.5, 2.1], - pA = - M.fromLists - [ [-1, 1], - [8, 2] - ], - pd = V.fromList [2, 17] - } diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs index ae3bdf6b7c..39ec16a39e 100644 --- a/src/Futhark/Solve/Matrix.hs +++ b/src/Futhark/Solve/Matrix.hs @@ -36,7 +36,6 @@ module Futhark.Solve.Matrix where import Data.List qualified as L -import Data.Map (Map) import Data.Map qualified as M import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V @@ -84,7 +83,7 @@ fromVectors :: (Unbox a) => [Vector a] -> Matrix a fromVectors [] = empty fromVectors vs = Matrix - { elems = V.concat $ vs, + { elems = V.concat vs, nrows = length vs, ncols = V.length $ head vs } @@ -263,18 +262,7 @@ update :: (Unbox a) => Matrix a -> Vector ((Int, Int), a) -> Matrix a update m upds = generate ( \i j -> - case (M.fromList $ V.toList upds) M.!? (i, j) of - Nothing -> m ! (i, j) - Just x -> x - ) - (nrows m) - (ncols m) - -update_ :: (Unbox a) => Matrix a -> Map (Int, Int) a -> Matrix a -update_ m upds = - generate - ( \i j -> - case upds M.!? (i, j) of + case M.fromList (V.toList upds) M.!? (i, j) of Nothing -> m ! (i, j) Just x -> x ) @@ -282,7 +270,7 @@ update_ m upds = (ncols m) -- This version doesn't maintain integrality of the entries. -rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a +rowEchelon :: (Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a rowEchelon = rowEchelon' 0 0 where rowEchelon' h k m@(Matrix _ nr nc) diff --git a/src/Futhark/Solve/Simplex.hs b/src/Futhark/Solve/Simplex.hs index e01c7ce566..362b300038 100644 --- a/src/Futhark/Solve/Simplex.hs +++ b/src/Futhark/Solve/Simplex.hs @@ -12,7 +12,6 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.LP (LP (..), LPE (..), LinearProg (..), convert, linearProgToLPE, rowEchelonLPE) import Futhark.Solve.Matrix @@ -30,35 +29,35 @@ import Futhark.Solve.Matrix -- equal to @z'@. -- | Computes @r@ as given in the tableau above. -comp_r :: +compR :: (Num a, Unbox a) => LPE a -> Matrix a -> Vector Int -> Vector Int -> Vector a -comp_r (LPE c a _) invA_B b n = +compR (LPE c a _) invA_B b n = c @ n .-. c @ b .* invA_B .* a @ n --- | @comp_q_enter prob invA_B b n enter@ computes the @enter@th +-- | @compQEnter prob invA_B b n enter@ computes the @enter@th -- column of @q@. -comp_q_enter :: +compQEnter :: (Num a, Unbox a) => LPE a -> Matrix a -> Int -> Vector a -comp_q_enter (LPE _ a _) invA_B enter = +compQEnter (LPE _ a _) invA_B enter = V.map negate $ invA_B *. getCol enter a -- | Computes the objective given an inversion of @a@ and a basis. -comp_z :: +compZ :: (Num a, Unbox a) => LPE a -> Matrix a -> Vector Int -> a -comp_z (LPE c _ d) invA_B b = +compZ (LPE c _ d) invA_B b = c @ b .* invA_B <.> d -- | Constructs an auxiliary equational linear program to compute the @@ -75,6 +74,57 @@ mkAux (LPE _ a d) = (LPE c_aux a_aux d_aux, b_aux, n_aux) b_aux = V.generate (nrows a) (+ ncols a) n_aux = V.generate (ncols a) id +fixDegenerateBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Int -> + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +fixDegenerateBasis og_prob col prob (invA_B, p, b, n) + | Just exit_idx <- mexit_idx, + V.null (elim_row exit_idx) = + let prob' = + prob + { pA = deleteRow exit_idx (pA prob), + pd = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) $ + pd prob + } + invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B + p' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) p + b' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) b + in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) + | Just exit_idx <- mexit_idx, + (enter, _) <- V.head (elim_row exit_idx) = + let enter_idx = fromJust $ V.findIndex (== enter) n + exit = b V.! exit_idx + in fixDegenerateBasis og_prob col prob $ + pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = + let prob' = + prob + { pc = pc og_prob, + pA = sliceCols (V.generate col id) $ pA prob, + pd = V.map abs $ pd og_prob + } + in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) + where + mexit_idx = + fst <$> V.filter ((>= col) . snd) (V.imap (curry id) b) V.!? 0 + elim_row exit_idx = + V.filter ((/= 0) . snd) $ + V.map (\j -> (j, compQEnter prob invA_B j V.! exit_idx)) $ + V.generate col id + -- | Finds an initial feasible basis for an equational linear program. -- Returns 'Nothing' if the LP has no solution. Inverts some -- equations by multiplying by -1 so it also returns a modified (but @@ -85,64 +135,13 @@ findBasis :: Maybe (LPE a, Matrix a, Vector a, Vector Int, Vector Int) findBasis prob = do (invA_B, p, b, n) <- step p_aux (invA_B_aux, d_aux, b_aux, n_aux) - if comp_z p_aux invA_B b == 0 + if compZ p_aux invA_B b == 0 then Just $ fixDegenerateBasis prob (ncols $ pA prob) p_aux (invA_B, p, b, n) else Nothing where (p_aux@(LPE _ _ d_aux), b_aux, n_aux) = mkAux prob invA_B_aux = identity $ V.length b_aux - fixDegenerateBasis :: - (Unbox a, Ord a, Fractional a, Show a) => - LPE a -> - Int -> - LPE a -> - (Matrix a, Vector a, Vector Int, Vector Int) -> - (LPE a, Matrix a, Vector a, Vector Int, Vector Int) - fixDegenerateBasis og_prob col prob (invA_B, p, b, n) - | Just exit_idx <- mexit_idx, - V.null (elim_row exit_idx) = - let prob' = - prob - { pA = deleteRow exit_idx (pA prob), - pd = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) $ - pd prob - } - invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B - p' = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) p - b' = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) b - in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) - | Just exit_idx <- mexit_idx, - (enter, _) <- V.head (elim_row exit_idx) = - let enter_idx = fromJust $ V.findIndex (== enter) n - exit = b V.! exit_idx - in fixDegenerateBasis og_prob col prob $ - pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) - | otherwise = - let prob' = - prob - { pc = pc og_prob, - pA = sliceCols (V.generate col id) $ pA prob, - pd = V.map abs $ pd og_prob - } - in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) - where - mexit_idx = - fst <$> ((V.filter ((>= col) . snd) (V.imap (curry id) b)) V.!? 0) - elim_row exit_idx = - V.filter ((/= 0) . snd) $ - V.map (\j -> (j, comp_q_enter prob invA_B j V.! exit_idx)) $ - V.generate col id - -- | Solves an equational linear program. Returns 'Nothing' if the -- program is infeasible or unbounded. Otherwise returns the optimal -- value and the solution. @@ -151,10 +150,9 @@ simplex :: LPE a -> Maybe (a, Vector a) simplex lpe = do - let ech_lpe = rowEchelonLPE lpe - res@(lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe + (lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe (invA_B', p', b', n') <- step lpe' (invA_B, p, b, n) - let z = comp_z lpe' invA_B' b' + let z = compZ lpe' invA_B' b' sol = V.map snd $ V.fromList $ @@ -180,12 +178,12 @@ simplexProg :: Maybe (a, Map v a) simplexProg prog = do (z, sol) <- simplex lpe - pure $ (z, M.fromList $ map (\(i, x) -> (idxMap M.! i, x)) $ zip [0 ..] $ V.toList sol) + pure (z, M.fromList $ zipWith (\i x -> (idxMap M.! i, x)) [0 ..] $ V.toList sol) where (lpe, idxMap) = linearProgToLPE prog pivot :: - (Unbox a, Ord a, Fractional a, Show a) => + (Unbox a, Fractional a) => LPE a -> (Matrix a, Vector a, Vector Int, Vector Int) -> (Int, Int) -> @@ -194,7 +192,7 @@ pivot :: pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) = (invA_B', p', b', n') where - q_enter = comp_q_enter prob invA_B enter + q_enter = compQEnter prob invA_B enter b' = b V.// [(exit_idx, enter)] n' = n V.// [(enter_idx, exit)] e_inv_vec = @@ -216,7 +214,7 @@ step :: step prob (invA_B, p, b, n) | Just enter_idx <- menter_idx = let enter = n V.! enter_idx - q_enter = comp_q_enter prob invA_B enter + q_enter = compQEnter prob invA_B enter pq = V.map (\(i, p_', q_) -> (i, -(p_' / q_))) $ V.filter (\(_, _, q_) -> q_ < 0) $ @@ -233,6 +231,5 @@ step prob (invA_B, p, b, n) in step prob $ pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) | otherwise = Just (invA_B, p, b, n) where - r = comp_r prob invA_B b n + r = compR prob invA_B b n menter_idx = V.findIndex (> 0) r - b_zero = V.filter (\(v, i) -> v == 0 && (not $ V.null (V.filter (< i) n))) $ V.zip p b diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index e6759a9ff7..2bee66d8a6 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -213,7 +213,7 @@ unifySharedFields reason fs1 fs2 = solveEq reason ts1 ts2 mustSupportEql :: Reason -> Type -> SolveM () -mustSupportEql reason t = pure () +mustSupportEql _reason _t = pure () -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 75dd2947ea..73a7315356 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1046,7 +1046,7 @@ checkOneExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of Left err -> throwError err - Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' (tparams, _, _) <- @@ -1064,7 +1064,7 @@ checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of Left err -> throwError err - Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' when (hasBinding e'') $ typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index b8371cbfe8..fc1733df5e 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -60,7 +60,6 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified -import Futhark.Util import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals @@ -533,8 +532,6 @@ instance MonadTypeChecker TermTypeM where lookupVar :: SrcLoc -> QualName VName -> StructType -> TermTypeM StructType lookupVar loc qn@(QualName qs name) inst_t = do scope <- lookupQualNameEnv qn - let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) - case M.lookup name $ scopeVtable scope of Nothing -> error $ "lookupVar: " <> show qn diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cfa88069ca..4a65aedf91 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -249,7 +249,7 @@ newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniquen newTypeOverloaded loc name pts = tyVarType NoUniqueness <$> newTyVarWith name (TyVarPrim (locOf loc) pts) -newSVar :: (Located loc) => loc -> Name -> TermM SVar +newSVar :: loc -> Name -> TermM SVar newSVar _loc desc = do i <- incCounter newID $ mkTypeVarName desc i @@ -426,7 +426,7 @@ lookupVar loc qn@(QualName qs name) = do if null tparams && null qs then pure t else do - (tnames, t') <- instTypeScheme qn loc tparams t + (_tnames, t') <- instTypeScheme qn loc tparams t -- TODO - qualify type names, like in the old type checker. pure t' Just EqualityF -> do @@ -707,22 +707,12 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where - -- stripFrame :: Shape Size -> Type -> Type - -- stripFrame frame (Array u ds t) = - -- let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) - -- in case mnew_shape of - -- Nothing -> Scalar t - -- Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t - -- stripFrame _ t = t - - isFunType (Scalar Arrow {}) = True - isFunType _ = False -- (fix) toSComp (Var (QualName [] x) _ _) = SVar x toSComp _ = error "" toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) - split (Array u s t) = do + split (Array _u s t) = do (a, b) <- split $ Scalar t pure (arrayOf s a, arrayOf s b) split ftype' = do @@ -1373,7 +1363,7 @@ checkSingleExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars artificial <- gets termArtificial - ((cts', artificial', tyvars'), _, e'') <- + ((cts', _artificial', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' case solve cts' tyvars' of Left err -> pure (Left err, e'') @@ -1396,7 +1386,7 @@ checkSizeExp e = runTermM $ do (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- - forM cts_tyvars' $ \(cts', artificial', tyvars') -> + forM cts_tyvars' $ \(cts', _artificial', tyvars') -> bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' case (solutions, es') of From 4cbdb8de325b126666b2c0197078742f85051412 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 16:05:31 +0200 Subject: [PATCH 214/258] Warning-free tests. --- .../Futhark/Solve/BranchAndBoundTests.hs | 57 +++++++++---------- unittests/Futhark/Solve/SimplexTests.hs | 48 ++++++---------- 2 files changed, 45 insertions(+), 60 deletions(-) diff --git a/unittests/Futhark/Solve/BranchAndBoundTests.hs b/unittests/Futhark/Solve/BranchAndBoundTests.hs index ed7e04c715..b7e1bfe027 100644 --- a/unittests/Futhark/Solve/BranchAndBoundTests.hs +++ b/unittests/Futhark/Solve/BranchAndBoundTests.hs @@ -1,3 +1,5 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + module Futhark.Solve.BranchAndBoundTests ( tests, ) @@ -10,7 +12,6 @@ import Futhark.Solve.Matrix qualified as M import Test.Tasty import Test.Tasty.HUnit import Prelude hiding (or) -import Prelude qualified tests :: TestTree tests = @@ -68,31 +69,28 @@ tests = case branchAndBound lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (11.8 :: Double), - and $ zipWith (==) (V.toList sol) [1, 3] - ], - testCase "5" $ - let prog = - LinearProg - { optType = Maximize, - objective = var "x1" ~+~ var "x2", - constraints = - [ var "x1" ~<=~ constant 10, - var "x2" ~<=~ constant 5 - ] - <> oneIsZero ("b1", "x1") ("b2", "x2") - } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp - in assertBool - (unlines [show $ branchAndBound lp]) - $ case branchAndBound lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (10 :: Double) - ], + (z `approxEq` (11.8 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 3]), + -- testCase "5" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> oneIsZero ("b1", "x1") ("b2", "x2") + -- } + -- (lp, _idxmap) = linearProgToLP prog + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, _sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ], -- testCase "6" $ -- let prog = -- LinearProg @@ -130,17 +128,16 @@ tests = var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ branchAndBound lp]) $ case branchAndBound lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (0 :: Double) ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool -approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/Futhark/Solve/SimplexTests.hs b/unittests/Futhark/Solve/SimplexTests.hs index 1a52203d12..c29bd10a93 100644 --- a/unittests/Futhark/Solve/SimplexTests.hs +++ b/unittests/Futhark/Solve/SimplexTests.hs @@ -1,17 +1,17 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + module Futhark.Solve.SimplexTests ( tests, ) where import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.LP import Futhark.Solve.Matrix qualified as M import Futhark.Solve.Simplex import Test.Tasty import Test.Tasty.HUnit import Prelude hiding (or) -import Prelude qualified tests :: TestTree tests = @@ -69,10 +69,8 @@ tests = case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (14.08 :: Double), - and $ zipWith approxEq (V.toList sol) [1.3, 3.3] - ], + (z `approxEq` (14.08 :: Double)) + && and (zipWith approxEq (V.toList sol) [1.3, 3.3]), testCase "5" $ let lp = LP @@ -88,10 +86,8 @@ tests = case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (0 :: Double), - and $ zipWith approxEq (V.toList sol) [0] - ], + (z `approxEq` (0 :: Double)) + && and (zipWith approxEq (V.toList sol) [0]), testCase "6" $ let lp = LP @@ -107,10 +103,8 @@ tests = case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (5 :: Double), - and $ zipWith approxEq (V.toList sol) [5] - ], + z `approxEq` (5 :: Double) + && and (zipWith approxEq (V.toList sol) [5]), testCase "7" $ let prog = LinearProg @@ -121,17 +115,14 @@ tests = var "b1" ~+~ var "b2" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (10 :: Double), - and $ zipWith (==) (V.toList sol) [1, 0, 10] - ], + (z `approxEq` (10 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 0, 10]), testCase "8" $ let prog = LinearProg @@ -143,13 +134,12 @@ tests = ] <> oneIsZero ("b1", "x1") ("b2", "x2") } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (15 :: Double) ], @@ -192,13 +182,12 @@ tests = var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (0 :: Double) ], @@ -217,17 +206,16 @@ tests = var "0b_R" ~+~ var "1b_M" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (0 :: Double) ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool -approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) From a4946a133b80b015cc6ed0f7e8df40e977d44e89 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 16:11:16 +0200 Subject: [PATCH 215/258] Link against static glpk. --- default.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/default.nix b/default.nix index 17321c27d6..b3f82aac07 100644 --- a/default.nix +++ b/default.nix @@ -75,6 +75,7 @@ let "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.zlib.static}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" + "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { configureFlags = ["--enable-static"] ++ old.configureFlags;})}/lib" ]; preBuild = '' From 9cf8ab8b68ee1e68e8f938a0055507ff47ce9848 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 16:20:04 +0200 Subject: [PATCH 216/258] This is cleaner. --- default.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/default.nix b/default.nix index b3f82aac07..46d1fdcf3c 100644 --- a/default.nix +++ b/default.nix @@ -75,7 +75,7 @@ let "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.zlib.static}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" - "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { configureFlags = ["--enable-static"] ++ old.configureFlags;})}/lib" + "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { dontDisableStatic = true; })}/lib" ]; preBuild = '' From 323922fdf50554daa7565b94224c31872b683e43 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 6 Jun 2024 18:28:34 +0200 Subject: [PATCH 217/258] Fix typo. --- default.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/default.nix b/default.nix index 026f95049e..4d27d336c9 100644 --- a/default.nix +++ b/default.nix @@ -38,7 +38,7 @@ let haskellPackagesNew.callPackage ./nix/zlib.nix {zlib=pkgs.zlib;}; gasp = - haskellPackagesNew.callPackage ./nix/.nix {}; + haskellPackagesNew.callPackage ./nix/gasp.nix {}; glpk-hs = haskellPackagesNew.callPackage ./nix/glpk-hs.nix {}; From 8a808e037afcbf61a221d96220ea584bdafcbaa3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 8 Jun 2024 20:15:44 +0200 Subject: [PATCH 218/258] let should not be generalised. --- tests/types/inference5.fut | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 tests/types/inference5.fut diff --git a/tests/types/inference5.fut b/tests/types/inference5.fut deleted file mode 100644 index 900704f21a..0000000000 --- a/tests/types/inference5.fut +++ /dev/null @@ -1,7 +0,0 @@ --- Inference for a local function. --- == --- input { 2 } output { 4 } - -def main x = - let apply f x = f x - in apply (apply (i32.+) x) x From e71f281db0622fd3f82fe1f9183282cc0c3a5258 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 8 Jun 2024 22:33:19 +0200 Subject: [PATCH 219/258] Minor refactoring. --- .../Futhark/TypeChecker/Constraints.hs | 150 ++++++++++-------- 1 file changed, 88 insertions(+), 62 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 240f8ca765..f6a8e2be8e 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -115,30 +115,39 @@ instance Located TyVarInfo where type TyVar = VName --- | If a VName is not in this map, it is assumed to be rigid. The --- integer is the level. -type TyVars = M.Map TyVar (Int, TyVarInfo) +-- | The level at which a type variable is bound. Higher means +-- deeper. We can only unify a type variable at level @i@ with a type +-- @t@ if all type names that occur in @t@ are at most at level @i@. +type Level = Int + +-- | If a VName is not in this map, it is assumed to be rigid. +type TyVars = M.Map TyVar (Level, TyVarInfo) data TyVarSol = -- | Has been substituted with this. - TyVarSol Int Type - | -- | Replaced by this other type variable. - TyVarLink VName + TyVarSol Level Type | -- | Not substituted yet; has this constraint. - TyVarUnsol Int TyVarInfo + TyVarUnsol Level TyVarInfo deriving (Show) -newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} +tyVarSolLevel :: TyVarSol -> Level +tyVarSolLevel (TyVarSol lvl _) = lvl +tyVarSolLevel (TyVarUnsol lvl _) = lvl + +newtype SolverState = SolverState + { -- | Left means linked to this other type variable. + solverTyVars :: M.Map TyVar (Either VName TyVarSol) + } initialState :: TyVars -> SolverState -initialState tyvars = SolverState $ M.map (uncurry TyVarUnsol) tyvars +initialState tyvars = SolverState $ M.map (Right . uncurry TyVarUnsol) tyvars -substTyVar :: (Monoid u) => M.Map TyVar TyVarSol -> VName -> Maybe (TypeBase SComp u) +substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = case M.lookup v m of - Just (TyVarLink v') -> substTyVar m v' - Just (TyVarSol _ t') -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' - Just (TyVarUnsol {}) -> Nothing + Just (Left v') -> substTyVar m v' + Just (Right (TyVarSol _ t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing lookupTyVar :: TyVar -> SolveM (Int, Either TyVarInfo Type) @@ -146,9 +155,9 @@ lookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of Nothing -> error $ "Unknown tyvar: " <> prettyNameString v - Just (TyVarSol lvl t) -> pure (lvl, Right t) - Just (TyVarLink v') -> f v' - Just (TyVarUnsol lvl info) -> pure (lvl, Left info) + Just (Left v') -> f v' + Just (Right (TyVarSol lvl t)) -> pure (lvl, Right t) + Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig -- | A solution maps a type variable to its substitution. This @@ -162,15 +171,15 @@ solution s = M.mapMaybe mkSubst $ solverTyVars s ) where - mkSubst (TyVarSol _lvl t) = + mkSubst (Right (TyVarSol _lvl t)) = Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t - mkSubst (TyVarLink v') = + mkSubst (Left v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (TyVarUnsol _ (TyVarPrim _ pts)) = Just $ Left pts + mkSubst (Right (TyVarUnsol _ (TyVarPrim _ pts))) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, TyVarUnsol _ (TyVarFree _)) = Just v + unconstrained (v, Right (TyVarUnsol _ (TyVarFree _))) = Just v unconstrained _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} @@ -216,15 +225,27 @@ unifySharedFields reason fs1 fs2 = mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () +scopeViolation :: Reason -> VName -> Type -> SolveM a +scopeViolation reason v tp = + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type" + indent 2 (pretty tp) + "with" + <+> dquotes (prettyName v) + <+> "(scope violation)." + "This is because" + <+> dquotes (prettyName v) + <+> "is rigidly bound in a deeper scope." + -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars case (v_info, t) of - (Just (TyVarUnsol _ TyVarFree {}), _) -> + (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( Just (Right (TyVarUnsol _ (TyVarPrim _ v_pts))), _ ) -> if t `elem` map (Scalar . Prim) v_pts @@ -235,7 +256,7 @@ subTyVar reason v lvl t = do indent 2 (pretty v_pts) "with" indent 2 (pretty t) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), Scalar (Sum cs2) ) -> if all (`elem` M.keys cs2) (M.keys cs1) @@ -246,7 +267,7 @@ subTyVar reason v lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Sum cs2)) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), _ ) -> throwError . TypeError (locOf reason) mempty $ @@ -254,7 +275,7 @@ subTyVar reason v lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty t) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), Scalar (Record fs2) ) -> if all (`elem` M.keys fs2) (M.keys fs1) @@ -265,7 +286,7 @@ subTyVar reason v lvl t = do indent 2 (pretty (Record fs1)) "with record type" indent 2 (pretty (Record fs2)) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), _ ) -> throwError . TypeError (locOf reason) mempty $ @@ -273,43 +294,48 @@ subTyVar reason v lvl t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty t) - (Just (TyVarUnsol _ (TyVarEql _)), _) -> + (Just (Right (TyVarUnsol _ (TyVarEql _))), _) -> mustSupportEql reason t -- -- Internal error cases - (Just TyVarSol {}, _) -> + (Just (Right TyVarSol {}), _) -> error $ "Type variable already solved: " <> prettyNameString v - (Just TyVarLink {}, _) -> + (Just Left {}, _) -> error $ "Type variable already linked: " <> prettyNameString v (Nothing, _) -> error $ "linkTyVar: Nothing v: " <> prettyNameString v setInfo v (TyVarSol lvl t) +setLink :: TyVar -> VName -> SolveM () +setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} + setInfo :: TyVar -> TyVarSol -> SolveM () -setInfo v info = modify $ \s -> s {solverTyVars = M.insert v info $ solverTyVars s} +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} -- Precondition: 'v' is currently flexible and 't' has no solution. linkTyVar :: Reason -> VName -> VName -> SolveM () linkTyVar reason v t = do - occursCheck reason v $ Scalar $ TypeVar NoUniqueness (qualName t) [] - v_info <- gets $ M.lookup v . solverTyVars + v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars (lvl, t') <- lookupTyVar t + let tp = Scalar $ TypeVar NoUniqueness (qualName t) [] + occursCheck reason v tp + case (v_info, t') of -- When either is completely unconstrained. - (Just (TyVarUnsol _ TyVarFree {}), _) -> + (TyVarUnsol _ TyVarFree {}, _) -> pure () - ( Just (TyVarUnsol _ info), + ( TyVarUnsol _ info, Left (TyVarFree {}) ) -> setInfo t (TyVarUnsol lvl info) -- -- TyVarPrim cases - ( Just (TyVarUnsol _ info@TyVarPrim {}), + ( TyVarUnsol _ info@TyVarPrim {}, Left TyVarEql {} ) -> setInfo t (TyVarUnsol lvl info) - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( TyVarUnsol _ (TyVarPrim _ v_pts), Left (TyVarPrim t_loc t_pts) ) -> let pts = L.intersect v_pts t_pts @@ -321,14 +347,14 @@ linkTyVar reason v t = do "with type that must be one of" indent 2 (pretty t_pts) else setInfo t (TyVarUnsol lvl (TyVarPrim t_loc pts)) - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarRecord {} ) -> throwError . TypeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarSum {} ) -> throwError . TypeError (locOf reason) mempty $ @@ -337,19 +363,19 @@ linkTyVar reason v t = do "with type that must be sum." -- -- TyVarSum cases - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarSum loc cs2) ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 setInfo t (TyVarUnsol lvl (TyVarSum loc cs3)) - ( Just (TyVarUnsol _ TyVarSum {}), + ( TyVarUnsol _ TyVarSum {}, Left (TyVarPrim _ pts) ) -> throwError . TypeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarRecord _ fs) ) -> throwError . TypeError (locOf reason) mempty $ @@ -357,25 +383,25 @@ linkTyVar reason v t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Scalar (Record fs))) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarEql _) ) -> mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarRecord loc fs2) ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 setInfo t (TyVarUnsol lvl (TyVarRecord loc fs3)) - ( Just (TyVarUnsol _ TyVarRecord {}), + ( TyVarUnsol _ TyVarRecord {}, Left (TyVarPrim _ pts) ) -> throwError . TypeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarSum _ cs) ) -> throwError . TypeError (locOf reason) mempty $ @@ -383,33 +409,33 @@ linkTyVar reason v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty (Scalar (Sum cs))) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarEql _) ) -> mapM_ (mustSupportEql reason) fs1 -- -- TyVarEql cases - (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarPrim {}) -> + (TyVarUnsol _ (TyVarEql _), Left TyVarPrim {}) -> pure () - (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarEql {}) -> + (TyVarUnsol _ (TyVarEql _), Left TyVarEql {}) -> pure () - (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarRecord _ fs)) -> + (TyVarUnsol _ (TyVarEql _), Left (TyVarRecord _ fs)) -> mustSupportEql reason $ Scalar $ Record fs - (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarSum _ cs)) -> + (TyVarUnsol _ (TyVarEql _), Left (TyVarSum _ cs)) -> mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases - (Just TyVarSol {}, _) -> - error $ "Type variable already solved: " <> prettyNameString v - (Just TyVarLink {}, _) -> - error $ "Type variable already linked: " <> prettyNameString v - (Nothing, _) -> - error $ "linkTyVar: Nothing v: " <> prettyNameString v + (TyVarSol {}, _) -> + alreadySolved (_, Right t'') -> error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' -- Finally insert the actual link. - setInfo v (TyVarLink t) + setLink v t + where + unknown = error $ "linkTyVar: Nothing v: " <> prettyNameString v + alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v + alreadySolved = error $ "Type variable already solved: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -459,14 +485,14 @@ solveEq reason orig_t1 orig_t2 = do solveCt' (t1, t2) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of - Just (TyVarLink v') -> flexible v' - Just (TyVarUnsol lvl _) -> Just lvl - Just (TyVarSol _ _) -> Nothing + Just (Left v') -> flexible v' + Just (Right (TyVarUnsol lvl _)) -> Just lvl + Just (Right (TyVarSol _ _)) -> Nothing Nothing -> Nothing sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of - Just (TyVarLink v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (TyVarSol _ t') -> sub t' + Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) + Just (Right (TyVarSol _ t')) -> sub t' _ -> t sub t = t case (sub t1, sub t2) of From 0d7e6614f6699b6cce0b9fcb791749bbe16b657a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 3 Jul 2024 10:05:49 +0200 Subject: [PATCH 220/258] Crudely strangle warnings. --- src/Language/Futhark/TypeChecker/Constraints.hs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index f6a8e2be8e..24521c0327 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -11,6 +11,9 @@ module Language.Futhark.TypeChecker.Constraints TyVars, Solution, solve, + -- To hide warnings + tyVarSolLevel, + scopeViolation, ) where From f126abea1519a017d0da1ecf83a7944bcbb8119f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 5 Jul 2024 16:18:40 +0200 Subject: [PATCH 221/258] Track explicit type parameters in constraint solver. I cannot figure out whether this is hacky or OK, but we need it to handle level checks correctly. --- .../Futhark/TypeChecker/Constraints.hs | 39 ++++++++++++++----- src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++++++++++---- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 24521c0327..3eb1f9d010 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -9,6 +9,7 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, + TyParams, Solution, solve, -- To hide warnings @@ -86,8 +87,8 @@ instance Pretty Ct where type Constraints = [Ct] --- | Information about a type variable. Every type variable is --- associated with a location, which is the original syntax element +-- | Information about a flexible type variable. Every type variable +-- is associated with a location, which is the original syntax element -- that it is the type of. data TyVarInfo = -- | Can be substituted with anything. @@ -123,18 +124,26 @@ type TyVar = VName -- @t@ if all type names that occur in @t@ are at most at level @i@. type Level = Int --- | If a VName is not in this map, it is assumed to be rigid. +-- | If a VName is not in this map, it should be in the 'TyParams' - +-- the exception is abstract types, which are just missing (and +-- assumed to have smallest possible level). type TyVars = M.Map TyVar (Level, TyVarInfo) +-- | Explicit type parameters. +type TyParams = M.Map TyVar (Level, Loc) + data TyVarSol = -- | Has been substituted with this. TyVarSol Level Type + | -- | Is an explicit type parameter in the source program. + TyVarParam Level Loc | -- | Not substituted yet; has this constraint. TyVarUnsol Level TyVarInfo deriving (Show) tyVarSolLevel :: TyVarSol -> Level tyVarSolLevel (TyVarSol lvl _) = lvl +tyVarSolLevel (TyVarParam lvl _) = lvl tyVarSolLevel (TyVarUnsol lvl _) = lvl newtype SolverState = SolverState @@ -142,14 +151,18 @@ newtype SolverState = SolverState solverTyVars :: M.Map TyVar (Either VName TyVarSol) } -initialState :: TyVars -> SolverState -initialState tyvars = SolverState $ M.map (Right . uncurry TyVarUnsol) tyvars +initialState :: TyParams -> TyVars -> SolverState +initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars + where + f (lvl, info) = Right $ TyVarUnsol lvl info + g (lvl, loc) = Right $ TyVarParam lvl loc substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = case M.lookup v m of Just (Left v') -> substTyVar m v' Just (Right (TyVarSol _ t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right TyVarParam {}) -> Nothing Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing @@ -160,6 +173,8 @@ lookupTyVar orig = do Nothing -> error $ "Unknown tyvar: " <> prettyNameString v Just (Left v') -> f v' Just (Right (TyVarSol lvl t)) -> pure (lvl, Right t) + Just (Right (TyVarParam lvl _)) -> + pure (lvl, Right $ Scalar $ TypeVar mempty (qualName orig) []) Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig @@ -303,6 +318,8 @@ subTyVar reason v lvl t = do -- Internal error cases (Just (Right TyVarSol {}), _) -> error $ "Type variable already solved: " <> prettyNameString v + (Just (Right TyVarParam {}), _) -> + error $ "Cannot substitute type parameter: " <> prettyNameString v (Just Left {}, _) -> error $ "Type variable already linked: " <> prettyNameString v (Nothing, _) -> @@ -430,6 +447,8 @@ linkTyVar reason v t = do -- Internal error cases (TyVarSol {}, _) -> alreadySolved + (TyVarParam {}, _) -> + isParam (_, Right t'') -> error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' @@ -439,6 +458,7 @@ linkTyVar reason v t = do unknown = error $ "linkTyVar: Nothing v: " <> prettyNameString v alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v alreadySolved = error $ "Type variable already solved: " <> prettyNameString v + isParam = error $ "Type name is a type parameter: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -490,7 +510,8 @@ solveEq reason orig_t1 orig_t2 = do let flexible v = case M.lookup v tyvars of Just (Left v') -> flexible v' Just (Right (TyVarUnsol lvl _)) -> Just lvl - Just (Right (TyVarSol _ _)) -> Nothing + Just (Right TyVarSol {}) -> Nothing + Just (Right TyVarParam {}) -> Nothing Nothing -> Nothing sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of @@ -552,11 +573,11 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do solveTyVar (_, _) = pure () -solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) -solve constraints tyvars = +solve :: Constraints -> TyParams -> TyVars -> Either TypeError ([VName], Solution) +solve constraints typarams tyvars = second solution . runExcept - . flip execStateT (initialState tyvars) + . flip execStateT (initialState typarams tyvars) . runSolveM $ do mapM_ solveCt constraints diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 56f36c7b9e..cb3082908a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -112,6 +112,7 @@ data TermEnv = TermEnv data TermState = TermState { termConstraints :: Constraints, termTyVars :: TyVars, + termTyParams :: TyParams, termCounter :: !Int, termWarnings :: Warnings, termNameSource :: VNameSource, @@ -192,6 +193,7 @@ runTermM (TermM m) = do TermState { termConstraints = mempty, termTyVars = mempty, + termTyParams = mempty, termWarnings = mempty, termNameSource = src, termCounter = 0, @@ -632,14 +634,25 @@ bindTypes tbinds = localScope extend } bindTypeParams :: [TypeParam] -> TermM a -> TermM a -bindTypeParams tparams = - bind (mapMaybe typeParamIdent tparams) - . bindTypes (mapMaybe typeParamType tparams) +bindTypeParams tparams m = + bind idents . bindTypes types $ do + lvl <- curLevel + modify $ \s -> + s + { termTyParams = + termTyParams s + <> M.fromList (mapMaybe (typeParam lvl) tparams) + } + m where + idents = mapMaybe typeParamIdent tparams + types = mapMaybe typeParamType tparams typeParamType (TypeParamType l v _) = Just (v, TypeAbbr l [] $ RetType [] $ Scalar (TypeVar mempty (qualName v) [])) typeParamType TypeParamDim {} = Nothing + typeParam lvl (TypeParamType _ v loc) = Just (v, (lvl, locOf loc)) + typeParam _ _ = Nothing bindParams :: [TypeParam] -> @@ -1311,6 +1324,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars + typarams <- gets termTyParams artificial <- gets termArtificial debugTraceM 3 $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" @@ -1327,15 +1341,15 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (\(v, t) -> prettyNameString v <> " => " <> prettyString t) (M.toList artificial) ] - onRankSolution retdecl' + onRankSolution retdecl' typarams =<< rankAnalysis1 loc cts tyvars artificial params' body' where - onRankSolution retdecl' ((cts', artificial, tyvars'), params', body'') = do + onRankSolution retdecl' typarams ((cts', artificial, tyvars'), params', body'') = do solution <- bitraverse pure (fmap (second (onArtificial artificial)) . onTySolution params' body'') - $ solve cts' tyvars' + $ solve cts' typarams tyvars' debugTraceM 3 $ unlines [ "## constraints:", @@ -1367,10 +1381,11 @@ checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars + typarams <- gets termTyParams artificial <- gets termArtificial ((cts', _artificial', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' - case solve cts' tyvars' of + case solve cts' typarams tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do e_t <- expType e'' @@ -1386,13 +1401,14 @@ checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints tyvars <- gets termTyVars + typarams <- gets termTyParams artificial <- gets termArtificial (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- forM cts_tyvars' $ \(cts', _artificial', tyvars') -> - bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' + bitraverse pure (traverse (doDefaults mempty)) $ solve cts' typarams tyvars' case (solutions, es') of ([solution], [e'']) -> From 1ace2df74e8032d8b712b07ade141d71bdd7df1c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 14:44:49 +0200 Subject: [PATCH 222/258] A bit more work. --- .../Futhark/TypeChecker/Constraints.hs | 69 ++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 3eb1f9d010..af5a5f2cfd 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -12,9 +12,6 @@ module Language.Futhark.TypeChecker.Constraints TyParams, Solution, solve, - -- To hide warnings - tyVarSolLevel, - scopeViolation, ) where @@ -135,17 +132,12 @@ type TyParams = M.Map TyVar (Level, Loc) data TyVarSol = -- | Has been substituted with this. TyVarSol Level Type - | -- | Is an explicit type parameter in the source program. + | -- | Is an explicit (rigid) type parameter in the source program. TyVarParam Level Loc | -- | Not substituted yet; has this constraint. TyVarUnsol Level TyVarInfo deriving (Show) -tyVarSolLevel :: TyVarSol -> Level -tyVarSolLevel (TyVarSol lvl _) = lvl -tyVarSolLevel (TyVarParam lvl _) = lvl -tyVarSolLevel (TyVarUnsol lvl _) = lvl - newtype SolverState = SolverState { -- | Left means linked to this other type variable. solverTyVars :: M.Map TyVar (Either VName TyVarSol) @@ -178,6 +170,12 @@ lookupTyVar orig = do Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig +setLink :: TyVar -> VName -> SolveM () +setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} + +setInfo :: TyVar -> TyVarSol -> SolveM () +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} + -- | A solution maps a type variable to its substitution. This -- substitution is complete, in the sense there are no right-hand -- sides that contain a type variable. @@ -243,23 +241,38 @@ unifySharedFields reason fs1 fs2 = mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () -scopeViolation :: Reason -> VName -> Type -> SolveM a -scopeViolation reason v tp = +scopeViolation :: Reason -> VName -> Type -> VName -> SolveM a +scopeViolation reason v1 ty v2 = throwError . TypeError (locOf reason) mempty $ "Cannot unify type" - indent 2 (pretty tp) + indent 2 (pretty ty) "with" - <+> dquotes (prettyName v) + <+> dquotes (prettyName v1) <+> "(scope violation)." "This is because" - <+> dquotes (prettyName v) + <+> dquotes (prettyName v2) <+> "is rigidly bound in a deeper scope." +scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () +scopeCheck reason v v_lvl ty = do + mapM_ check $ typeVars ty + where + check ty_v = do + ty_v_info <- gets $ M.lookup ty_v . solverTyVars + case ty_v_info of + Just (Right (TyVarParam ty_v_lvl _)) + | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + Just (Right (TyVarUnsol ty_v_lvl info)) + | ty_v_lvl /= v_lvl -> + setInfo ty_v $ TyVarUnsol v_lvl info + _ -> pure () + -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () -subTyVar reason v lvl t = do +subTyVar reason v v_lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars + scopeCheck reason v v_lvl t case (v_info, t) of (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () @@ -323,23 +336,15 @@ subTyVar reason v lvl t = do (Just Left {}, _) -> error $ "Type variable already linked: " <> prettyNameString v (Nothing, _) -> - error $ "linkTyVar: Nothing v: " <> prettyNameString v - - setInfo v (TyVarSol lvl t) + error $ "subTyVar: Nothing v: " <> prettyNameString v -setLink :: TyVar -> VName -> SolveM () -setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} - -setInfo :: TyVar -> TyVarSol -> SolveM () -setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} + setInfo v (TyVarSol v_lvl t) --- Precondition: 'v' is currently flexible and 't' has no solution. -linkTyVar :: Reason -> VName -> VName -> SolveM () -linkTyVar reason v t = do +-- Precondition: 'v' and 't' are both currently flexible. +unionTyVars :: Reason -> VName -> VName -> SolveM () +unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars (lvl, t') <- lookupTyVar t - let tp = Scalar $ TypeVar NoUniqueness (qualName t) [] - occursCheck reason v tp case (v_info, t') of -- When either is completely unconstrained. @@ -450,12 +455,12 @@ linkTyVar reason v t = do (TyVarParam {}, _) -> isParam (_, Right t'') -> - error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' + error $ "unionTyVars: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' -- Finally insert the actual link. setLink v t where - unknown = error $ "linkTyVar: Nothing v: " <> prettyNameString v + unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v alreadySolved = error $ "Type variable already solved: " <> prettyNameString v isParam = error $ "Type name is a type parameter: " <> prettyNameString v @@ -530,8 +535,8 @@ solveEq reason orig_t1 orig_t2 = do (Just lvl, Nothing) -> subTyVar reason v1 lvl t2' (Nothing, Just lvl) -> subTyVar reason v2 lvl t1' (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> linkTyVar reason v1 v2 - | otherwise -> linkTyVar reason v2 v1 + | lvl1 <= lvl2 -> unionTyVars reason v1 v2 + | otherwise -> unionTyVars reason v2 v1 (Scalar (TypeVar _ (QualName [] v1) []), t2') | Just lvl <- flexible v1 -> subTyVar reason v1 lvl t2' From 5522eb85c0aa1c42e1b38a10d793ee7004a655bf Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:13:30 +0200 Subject: [PATCH 223/258] Check for equality. --- src/Language/Futhark/TypeChecker/Constraints.hs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index af5a5f2cfd..b9a59e90bb 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -575,6 +575,17 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) Right _ -> pure () +solveTyVar (tv, (_, TyVarEql loc)) = do + (_, tv_t) <- lookupTyVar tv + case tv_t of + Left _ -> pure () + Right ty + | orderZero ty -> pure () + | otherwise -> + throwError . TypeError loc mempty $ + "Type" + indent 2 (align (pretty ty)) + "does not support equality (may contain function)." solveTyVar (_, _) = pure () From 43744e5e4f6c0e47bf4faa2135e1149cb5454ff1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:19:21 +0200 Subject: [PATCH 224/258] Detect ambiguous equality type. --- src/Language/Futhark/TypeChecker/Constraints.hs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b9a59e90bb..be19eccec9 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -578,7 +578,10 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do solveTyVar (tv, (_, TyVarEql loc)) = do (_, tv_t) <- lookupTyVar tv case tv_t of - Left _ -> pure () + Left _ -> + throwError . TypeError loc mempty $ + "Type is ambiguous (must be equality type)" + "Add a type annotation to disambiguate the type." Right ty | orderZero ty -> pure () | otherwise -> From 7329e8bbf7400b0ba7dbe5f4301d4fc1e2814ee3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:22:37 +0200 Subject: [PATCH 225/258] Abstraction. --- .../Futhark/TypeChecker/Constraints.hs | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index be19eccec9..c96fa9f442 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -26,7 +26,7 @@ import Data.Maybe import Data.Set qualified as S import Futhark.Util.Pretty import Language.Futhark -import Language.Futhark.TypeChecker.Monad (TypeError (..)) +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..)) import Language.Futhark.TypeChecker.Types (substTyVars) -- | The reason for a type constraint. Used to generate type error @@ -201,11 +201,15 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) +typeError :: Loc -> Notes -> Doc () -> SolveM () +typeError loc notes msg = + throwError $ TypeError loc notes msg + occursCheck :: Reason -> VName -> Type -> SolveM () occursCheck reason v tp = do vars <- gets solverTyVars let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . throwError . TypeError (locOf reason) mempty $ + when (v `S.member` typeVars tp') . typeError (locOf reason) mempty $ "Occurs check: cannot instantiate" <+> prettyName v <+> "with" @@ -220,9 +224,9 @@ unifySharedConstructors :: unifySharedConstructors reason cs1 cs2 = forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> if length ts1 == length ts2 - then zipWithM (solveEq reason) ts1 ts2 + then zipWithM_ (solveEq reason) ts1 ts2 else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructor" indent 2 (pretty (Sum (M.singleton c ts1))) "with type of constructor" @@ -241,9 +245,9 @@ unifySharedFields reason fs1 fs2 = mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () -scopeViolation :: Reason -> VName -> Type -> VName -> SolveM a +scopeViolation :: Reason -> VName -> Type -> VName -> SolveM () scopeViolation reason v1 ty v2 = - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type" indent 2 (pretty ty) "with" @@ -282,7 +286,7 @@ subTyVar reason v v_lvl t = do if t `elem` map (Scalar . Prim) v_pts then pure () else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with" @@ -293,7 +297,7 @@ subTyVar reason v v_lvl t = do if all (`elem` M.keys cs2) (M.keys cs1) then unifySharedConstructors reason cs1 cs2 else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructors" indent 2 (pretty (Sum cs1)) "with type" @@ -301,7 +305,7 @@ subTyVar reason v v_lvl t = do ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), _ ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructors" indent 2 (pretty (Sum cs1)) "with type" @@ -312,7 +316,7 @@ subTyVar reason v v_lvl t = do if all (`elem` M.keys fs2) (M.keys fs1) then unifySharedFields reason fs1 fs2 else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify record type with fields" indent 2 (pretty (Record fs1)) "with record type" @@ -320,7 +324,7 @@ subTyVar reason v v_lvl t = do ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), _ ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify record type with fields" indent 2 (pretty (Record fs1)) "with type" @@ -366,7 +370,7 @@ unionTyVars reason v t = do let pts = L.intersect v_pts t_pts in if null pts then - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be one of" @@ -375,14 +379,14 @@ unionTyVars reason v t = do ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarRecord {} ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarSum {} ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be sum." @@ -397,13 +401,13 @@ unionTyVars reason v t = do ( TyVarUnsol _ TyVarSum {}, Left (TyVarPrim _ pts) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarRecord _ fs) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructors" indent 2 (pretty (Sum cs1)) "with type" @@ -423,13 +427,13 @@ unionTyVars reason v t = do ( TyVarUnsol _ TyVarRecord {}, Left (TyVarPrim _ pts) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarSum _ cs) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify record type" indent 2 (pretty (Record fs1)) "with type" @@ -504,7 +508,7 @@ solveEq reason orig_t1 orig_t2 = do where cannotUnify = do tyvars <- gets solverTyVars - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) "with" @@ -558,7 +562,7 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do (_, tv_t) <- lookupTyVar tv case tv_t of Left _ -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type" <+> prettyName tv <+> "is ambiguous." @@ -570,7 +574,7 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do (_, tv_t) <- lookupTyVar tv case tv_t of Left _ -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type is ambiguous." "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) @@ -579,13 +583,13 @@ solveTyVar (tv, (_, TyVarEql loc)) = do (_, tv_t) <- lookupTyVar tv case tv_t of Left _ -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type is ambiguous (must be equality type)" "Add a type annotation to disambiguate the type." Right ty | orderZero ty -> pure () | otherwise -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type" indent 2 (align (pretty ty)) "does not support equality (may contain function)." From 8a99272d1c1e4d1ca76062ee1e0accf1f6d40f32 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:24:39 +0200 Subject: [PATCH 226/258] Refine. --- src/Language/Futhark/TypeChecker/Constraints.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index c96fa9f442..b85799cc1f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -582,10 +582,11 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do solveTyVar (tv, (_, TyVarEql loc)) = do (_, tv_t) <- lookupTyVar tv case tv_t of - Left _ -> + Left TyVarEql {} -> typeError loc mempty $ "Type is ambiguous (must be equality type)" "Add a type annotation to disambiguate the type." + Left _ -> pure () Right ty | orderZero ty -> pure () | otherwise -> From f7892a191d36a459d616de568ef280a8a7a96e8e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 20:35:01 +0200 Subject: [PATCH 227/258] More fixes. --- .../Futhark/TypeChecker/Constraints.hs | 78 +++++++++------- src/Language/Futhark/TypeChecker/Rank.hs | 7 +- src/Language/Futhark/TypeChecker/Terms2.hs | 89 +++++++++++-------- 3 files changed, 102 insertions(+), 72 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b85799cc1f..b7bd2f5d81 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -11,6 +11,7 @@ module Language.Futhark.TypeChecker.Constraints TyVars, TyParams, Solution, + UnconTyVar, solve, ) where @@ -89,7 +90,7 @@ type Constraints = [Ct] -- that it is the type of. data TyVarInfo = -- | Can be substituted with anything. - TyVarFree Loc + TyVarFree Loc Liftedness | -- | Can only be substituted with these primitive types. TyVarPrim Loc [PrimType] | -- | Must be a record with these fields. @@ -101,14 +102,14 @@ data TyVarInfo deriving (Show, Eq) instance Pretty TyVarInfo where - pretty (TyVarFree _) = "free" + pretty (TyVarFree _ l) = "free" <+> pretty l pretty (TyVarPrim _ pts) = "∈" <+> pretty pts pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs pretty (TyVarEql _) = "equality" instance Located TyVarInfo where - locOf (TyVarFree loc) = loc + locOf (TyVarFree loc _) = loc locOf (TyVarPrim loc _) = loc locOf (TyVarRecord loc _) = loc locOf (TyVarSum loc _) = loc @@ -158,7 +159,7 @@ substTyVar m v = Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing -lookupTyVar :: TyVar -> SolveM (Int, Either TyVarInfo Type) +lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) lookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of @@ -181,7 +182,11 @@ setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solv -- sides that contain a type variable. type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -solution :: SolverState -> ([VName], Solution) +-- | An unconstrained type variable comprises a name and (ironically) +-- a constraint on how it can be instantiated. +type UnconTyVar = (VName, Liftedness) + +solution :: SolverState -> ([UnconTyVar], Solution) solution s = ( mapMaybe unconstrained $ M.toList $ solverTyVars s, M.mapMaybe mkSubst $ solverTyVars s @@ -195,7 +200,7 @@ solution s = mkSubst (Right (TyVarUnsol _ (TyVarPrim _ pts))) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, Right (TyVarUnsol _ (TyVarFree _))) = Just v + unconstrained (v, Right (TyVarUnsol _ (TyVarFree _ l))) = Just (v, l) unconstrained _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} @@ -257,26 +262,11 @@ scopeViolation reason v1 ty v2 = <+> dquotes (prettyName v2) <+> "is rigidly bound in a deeper scope." -scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () -scopeCheck reason v v_lvl ty = do - mapM_ check $ typeVars ty - where - check ty_v = do - ty_v_info <- gets $ M.lookup ty_v . solverTyVars - case ty_v_info of - Just (Right (TyVarParam ty_v_lvl _)) - | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v - Just (Right (TyVarUnsol ty_v_lvl info)) - | ty_v_lvl /= v_lvl -> - setInfo ty_v $ TyVarUnsol v_lvl info - _ -> pure () - -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v v_lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars - scopeCheck reason v v_lvl t case (v_info, t) of (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () @@ -348,22 +338,27 @@ subTyVar reason v v_lvl t = do unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - (lvl, t') <- lookupTyVar t + (t_lvl, t') <- lookupTyVar t case (v_info, t') of + ( TyVarUnsol _ (TyVarFree _ v_l), + Left (TyVarFree t_loc t_l) + ) + | v_l /= t_l -> + setInfo t $ TyVarUnsol t_lvl $ TyVarFree t_loc (min v_l t_l) -- When either is completely unconstrained. (TyVarUnsol _ TyVarFree {}, _) -> pure () ( TyVarUnsol _ info, Left (TyVarFree {}) ) -> - setInfo t (TyVarUnsol lvl info) + setInfo t (TyVarUnsol t_lvl info) -- -- TyVarPrim cases ( TyVarUnsol _ info@TyVarPrim {}, Left TyVarEql {} ) -> - setInfo t (TyVarUnsol lvl info) + setInfo t (TyVarUnsol t_lvl info) ( TyVarUnsol _ (TyVarPrim _ v_pts), Left (TyVarPrim t_loc t_pts) ) -> @@ -375,7 +370,7 @@ unionTyVars reason v t = do indent 2 (pretty v_pts) "with type that must be one of" indent 2 (pretty t_pts) - else setInfo t (TyVarUnsol lvl (TyVarPrim t_loc pts)) + else setInfo t (TyVarUnsol t_lvl (TyVarPrim t_loc pts)) ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarRecord {} ) -> @@ -397,7 +392,7 @@ unionTyVars reason v t = do ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 - setInfo t (TyVarUnsol lvl (TyVarSum loc cs3)) + setInfo t (TyVarUnsol t_lvl (TyVarSum loc cs3)) ( TyVarUnsol _ TyVarSum {}, Left (TyVarPrim _ pts) ) -> @@ -423,7 +418,7 @@ unionTyVars reason v t = do ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 - setInfo t (TyVarUnsol lvl (TyVarRecord loc fs3)) + setInfo t (TyVarUnsol t_lvl (TyVarRecord loc fs3)) ( TyVarUnsol _ TyVarRecord {}, Left (TyVarPrim _ pts) ) -> @@ -557,7 +552,18 @@ solveCt ct = CtEq reason t1 t2 -> solveEq reason t1 t2 CtAM {} -> pure () -- Good vibes only. -solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () +scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () +scopeCheck reason v v_lvl ty = do + mapM_ check $ typeVars ty + where + check ty_v = do + ty_v_info <- gets $ M.lookup ty_v . solverTyVars + case ty_v_info of + Just (Right (TyVarParam ty_v_lvl _)) + | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + _ -> pure () + +solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do (_, tv_t) <- lookupTyVar tv case tv_t of @@ -594,10 +600,18 @@ solveTyVar (tv, (_, TyVarEql loc)) = do "Type" indent 2 (align (pretty ty)) "does not support equality (may contain function)." -solveTyVar (_, _) = - pure () - -solve :: Constraints -> TyParams -> TyVars -> Either TypeError ([VName], Solution) +solveTyVar (tv, (lvl, _)) = do + (_, tv_t) <- lookupTyVar tv + case tv_t of + Right ty -> + scopeCheck (Reason mempty) tv lvl ty + _ -> pure () + +solve :: + Constraints -> + TyParams -> + TyVars -> + Either TypeError ([UnconTyVar], Solution) solve constraints typarams tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index a052d44f54..c00ba3b106 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -160,7 +160,7 @@ addCt (CtAM _ r m f) = do addObj tr addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () -addTyVarInfo _ (_, TyVarFree _) = pure () +addTyVarInfo _ (_, TyVarFree {}) = pure () addTyVarInfo tv (_, TyVarPrim {}) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarRecord {}) = @@ -392,8 +392,11 @@ addRankInfo t = do t' <- newTyVar t old_tyvars <- asks envTyVars let (level, tvinfo) = fromJust $ old_tyvars M.!? t + l = case tvinfo of + TyVarFree _ tvinfo_l -> tvinfo_l + _ -> Unlifted modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo) $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree $ locOf tvinfo) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree (locOf tvinfo) l) $ substTyVars s} class SubstRanks a where substRanks :: (MonadTypeChecker m) => a -> SubstT m a diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cb3082908a..c362e7ea5c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -231,11 +231,15 @@ newTyVarWith desc info = do modify $ \s -> s {termTyVars = M.insert v (lvl, info) $ termTyVars s} pure v -newTyVar :: (Located loc) => loc -> Name -> TermM TyVar -newTyVar loc desc = newTyVarWith desc $ TyVarFree $ locOf loc +newTyVar :: (Located loc) => loc -> Liftedness -> Name -> TermM TyVar +newTyVar loc l desc = newTyVarWith desc $ TyVarFree (locOf loc) l -newType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) -newType loc desc u = tyVarType u <$> newTyVar loc desc +newType :: (Located loc) => loc -> Liftedness -> Name -> u -> TermM (TypeBase dim u) +newType loc l desc u = tyVarType u <$> newTyVar loc l desc + +-- | New type that must be allowed as an array element. +newElemType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) +newElemType loc desc u = tyVarType u <$> newTyVar loc Unlifted desc newTypeWithField :: SrcLoc -> Name -> Name -> Type -> TermM Type newTypeWithField loc desc k t = @@ -404,8 +408,8 @@ instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> case tparam of - TypeParamType _ v _ -> do - v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v + TypeParamType l v _ -> do + v' <- newTyVar loc l $ nameFromString $ takeWhile isAscii $ baseString v pure $ Just (v, (typeParamName tparam, tyVarType NoUniqueness v')) TypeParamDim {} -> pure Nothing @@ -490,13 +494,13 @@ checkPat' (Id name NoInfo loc) (Ascribed t) = do t' <- asStructType t pure $ Id name (Info t') loc checkPat' (Id name NoInfo loc) NoneInferred = do - t <- newType loc "t" Observe + t <- newType loc Lifted "t" Observe pure $ Id name (Info t) loc checkPat' (Wildcard _ loc) (Ascribed t) = do t' <- asStructType t pure $ Wildcard (Info t') loc checkPat' (Wildcard NoInfo loc) NoneInferred = do - t <- newType loc "t" Observe + t <- newType loc Lifted "t" Observe pure $ Wildcard (Info t) loc checkPat' (TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, @@ -505,7 +509,7 @@ checkPat' (TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_tvs <- replicateM (length ps) (newTyVar loc "t") + ps_tvs <- replicateM (length ps) (newTyVar loc Lifted "t") ctEq (Reason (locOf loc)) (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) t TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = @@ -525,7 +529,9 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) L.sort (map fst p_fs) == L.sort (M.keys t_fs) = RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do - p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs + p_fs' <- + traverse (const $ newType loc Lifted "t" NoUniqueness) $ + M.fromList p_fs ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t checkPat' p $ Ascribed $ Observe <$ Scalar (Record p_fs') where @@ -578,7 +584,7 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do - p_t <- newType (srclocOf p) "t" Observe + p_t <- newType (srclocOf p) Lifted "t" Observe checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' ctEq (Reason (locOf loc)) t' t @@ -730,8 +736,8 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do (a, b) <- split $ Scalar t pure (arrayOf s a, arrayOf s b) split ftype' = do - a <- newType loc "arg" NoUniqueness - b <- newType loc "res" Nonunique + a <- newType loc Lifted "arg" NoUniqueness + b <- newType loc Lifted "res" Nonunique ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) @@ -759,7 +765,7 @@ mustHaveFields loc t [f] ve_t = do rt :: Type <- newTypeWithField loc "ft" f ve_t ctEq (Reason (locOf loc)) t rt mustHaveFields loc t (f : fs) ve_t = do - ft <- newType loc "ft" NoUniqueness + ft <- newType loc Lifted "ft" NoUniqueness rt <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t ctEq (Reason (locOf loc)) t rt @@ -844,7 +850,7 @@ checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc checkExp (Hole NoInfo loc) = - Hole <$> (Info <$> newType loc "hole" NoUniqueness) <*> pure loc + Hole <$> (Info <$> newType loc Lifted "hole" NoUniqueness) <*> pure loc checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (TupLit es loc) = @@ -878,7 +884,7 @@ checkExp (ArrayLit es _ loc) = do -- type variables for pathologically large arrays with -- type-unsuffixed integers. Add some special case that handles that -- more efficiently. - et <- newType loc "et" NoUniqueness + et <- newElemType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e e_t <- expType e' @@ -974,7 +980,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e e_t <- expType e' - t2 <- newType loc "t" NoUniqueness + t2 <- newType loc Lifted "t" NoUniqueness t2' <- asStructType t2 let f1 = frameOf e' (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((f1, e_t) NE.:| [(mempty, t2)]) @@ -997,7 +1003,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e e_t <- expType e' - t1 <- newType loc "t" NoUniqueness + t1 <- newType loc Lifted "t" NoUniqueness t1' <- asStructType t1 let f2 = frameOf e' (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((mempty, t1) NE.:| [(f2, e_t)]) @@ -1019,8 +1025,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do loc -- checkExp (ProjectSection fields NoInfo loc) = do - a <- newType loc "a" NoUniqueness - b <- newType loc "b" NoUniqueness + a <- newType loc Lifted "a" NoUniqueness + b <- newType loc Lifted "b" NoUniqueness mustHaveFields loc a fields b ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc @@ -1087,7 +1093,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end - range_t <- newType loc "range" NoUniqueness + range_t <- newElemType loc "range" NoUniqueness range_t' <- asType range_t start_t <- expType start' ctEq (Reason (locOf start')) range_t' (arrayOfRank 1 start_t) @@ -1095,7 +1101,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e - kt <- newType loc "kt" NoUniqueness + kt <- newType loc Lifted "kt" NoUniqueness t <- newTypeWithField loc "t" k kt e_t <- expType e' ctEq (Reason (locOf e')) e_t t @@ -1113,9 +1119,9 @@ checkExp (RecordUpdate src fields ve NoInfo loc) = do -- checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice - index_arg_t <- newType loc "index" NoUniqueness - index_elem_t <- newType loc "index_elem" NoUniqueness - index_res_t <- newType loc "index_res" NoUniqueness + index_arg_t <- newElemType loc "index" NoUniqueness + index_elem_t <- newElemType loc "index_elem" NoUniqueness + index_res_t <- newElemType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (Reason (locOf loc)) index_arg_t $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t @@ -1126,8 +1132,8 @@ checkExp (AppExp (Index e slice loc) _) = do e' <- checkExp e e_t <- expType e' slice' <- checkSlice slice - index_tv <- newTyVar loc "index" - index_elem_t <- newType loc "index_elem" NoUniqueness + index_tv <- newTyVar loc Unlifted "index" + index_elem_t <- newElemType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (Reason (locOf loc)) (tyVarType NoUniqueness index_tv) $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf e')) e_t $ arrayOfRank (length slice) index_elem_t @@ -1140,7 +1146,7 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve ve_t <- expType ve' let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" NoUniqueness + update_elem_t <- newElemType loc "update_elem" NoUniqueness ctEq (Reason (locOf src')) src_t $ arrayOfRank (length slice) update_elem_t ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc @@ -1154,7 +1160,7 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve ve_t <- expType ve' let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" NoUniqueness + update_elem_t <- newElemType loc "update_elem" NoUniqueness ctEq (Reason (locOf loc)) src_t $ arrayOfRank (length slice) update_elem_t ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t bind [dest'] $ do @@ -1170,12 +1176,14 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do e2_t <- expType e2' e3' <- checkExp e3 e3_t <- expType e3' + if_t <- newType loc SizeLifted "if_t" NoUniqueness ctEq (Reason (locOf e1')) e1_t (Scalar (Prim Bool)) - ctEq (Reason (locOf loc)) e2_t e3_t + ctEq (Reason (locOf loc)) e2_t if_t + ctEq (Reason (locOf loc)) e3_t if_t - e2_t' <- asStructType e2_t - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes e2_t' []) + if_t' <- asStructType if_t + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes if_t' []) -- checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e @@ -1205,7 +1213,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr - elem_t <- newType elemp "elem" NoUniqueness + elem_t <- newElemType elemp "elem" NoUniqueness arr_t <- expType arr' elem_t' <- asType elem_t ctEq (Reason (locOf arr')) arr_t $ arrayOfRank 1 elem_t' @@ -1270,13 +1278,16 @@ doDefaults tyvars_at_toplevel substs = do pure $ M.map (substTyVars (`M.lookup` substs')) substs' generalise :: - TypeBase () NoUniqueness -> [VName] -> Solution -> ([TypeParam], [VName]) + TypeBase () NoUniqueness -> + [UnconTyVar] -> + Solution -> + ([TypeParam], [VName]) generalise fun_t unconstrained solution = -- Candidates for let-generalisation are those type variables that -- are used in fun_t. let visible = foldMap expandTyVars $ typeVars fun_t - onTyVar v - | v `S.member` visible = Left $ TypeParamType Unlifted v mempty + onTyVar (v, l) + | v `S.member` visible = Left $ TypeParamType l v mempty | otherwise = Right v in partitionEithers $ map onTyVar unconstrained where @@ -1286,7 +1297,7 @@ generalise fun_t unconstrained solution = _ -> S.singleton v generaliseAndDefaults :: - [VName] -> + [UnconTyVar] -> Solution -> TypeBase () NoUniqueness -> TermM ([TypeParam], M.Map VName (TypeBase () NoUniqueness)) @@ -1354,6 +1365,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines [ "## constraints:", unlines $ map prettyString cts', + "## typarams:", + unlines (map (prettyString . bimap prettyNameString fst) (M.toList typarams)), "## tyvars':", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", @@ -1396,7 +1409,7 @@ checkSingleExp e = runTermM $ do -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> - TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([UnconTyVar], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints From fbafb6db1cc4c53033ad1555bbcc354a3e50a2d5 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 20:59:10 +0200 Subject: [PATCH 228/258] Define before use. --- src/Language/Futhark/TypeChecker/Terms2.hs | 34 +++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index c362e7ea5c..d231cf7263 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -673,23 +673,6 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: - SrcLoc -> - Maybe (QualName VName) -> - (Shape Size, Type) -> - NE.NonEmpty (Shape Size, Type) -> - TermM (Type, NE.NonEmpty AutoMap) -checkApply loc fname (fframe, ftype) args = do - ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args - pure (rt, argts) - where - onArg (i, f_f, f_t) (argframe, argtype) = do - (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) - pure - ( (i + 1, autoFrame am, rt), - am - ) - checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do (a, b) <- split ftype @@ -741,6 +724,23 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) +checkApply :: + SrcLoc -> + Maybe (QualName VName) -> + (Shape Size, Type) -> + NE.NonEmpty (Shape Size, Type) -> + TermM (Type, NE.NonEmpty AutoMap) +checkApply loc fname (fframe, ftype) args = do + ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args + pure (rt, argts) + where + onArg (i, f_f, f_t) (argframe, argtype) = do + (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) + pure + ( (i + 1, autoFrame am, rt), + am + ) + checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where From 4b2a9e2ec47124fb48f718167377dd6b52f4ecaa Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:00:31 +0200 Subject: [PATCH 229/258] Break long line. --- src/Language/Futhark/TypeChecker/Terms2.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d231cf7263..f35aa83de9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -707,7 +707,11 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ] pure ( arrayOf (toShape (SVar m)) b, - AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} + AutoMap + { autoRep = toShape r_var, + autoMap = toShape m_var, + autoFrame = toShape m_var <> fframe + } ) where toSComp (Var (QualName [] x) _ _) = SVar x From d393544492b332d7c187ae863127eeb963c93572 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:08:16 +0200 Subject: [PATCH 230/258] Add missing cases. --- src/Language/Futhark/TypeChecker/Rank.hs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c00ba3b106..6f597ca8f0 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -452,6 +452,22 @@ updAM rank_map e = in AppExp (Apply f' args' loc) res AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res + OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc -> + OpSectionRight + name + t + (updAM rank_map arg) + (Info (pa, t1a), Info (pb, t1b, argext, upd am)) + t2 + loc + OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc -> + OpSectionLeft + name + t + (updAM rank_map arg) + (Info (pa, t1a, argext, upd am), Info (pb, t1b)) + (ret, retext) + loc _ -> runIdentity $ astMap mapper e where dimToRank (Var (QualName [] x) _ _) = From d86eec2eb49e192541cfacd00fb1bc3237bc624f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:22:07 +0200 Subject: [PATCH 231/258] This is OK now. --- tests/record-update6.fut | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/record-update6.fut b/tests/record-update6.fut index 53349ac0ab..fef501e1a2 100644 --- a/tests/record-update6.fut +++ b/tests/record-update6.fut @@ -1,10 +1,9 @@ -- Inference of record in lambda. -- == --- error: Full type of type octnode = {body: i32} -def f (octree: []octnode) (i: i32) = +entry f (octree: []octnode) (i: i32) = map (\n -> if n.body != i then n else n with body = 0) octree From 43c85b30d600178d79b563c7a977b8546e12f24c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:54:35 +0200 Subject: [PATCH 232/258] Correct handling of type annotation. --- src/Language/Futhark/TypeChecker/Terms.hs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index b1dfbf494e..d2f03ecc9c 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -687,13 +687,13 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do params'' <- mapM updateTypes params' - (rettype', rettype_st) <- - case rettype_checked of - Just (te, _, ext) -> - pure (Just te, RetType ext rt') - Nothing -> do - RetType ext ret <- inferReturnSizes params'' $ toRes Nonunique body_t - pure (Nothing, RetType ext ret) + (rettype', rettype_st) <- case rettype_checked of + Just (te, ret, ext) -> do + ret' <- normTypeFully ret + pure (Just te, RetType ext ret') + Nothing -> do + ret <- inferReturnSizes params'' $ toRes Nonunique body_t + pure (Nothing, ret) pure (params'', body', rettype', rettype_st) From 44c54056cfad05ab501f239e2b707f661047d4a8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jul 2024 08:55:28 +0200 Subject: [PATCH 233/258] Simplify. --- .../Futhark/TypeChecker/Constraints.hs | 57 +++++++++++-------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b7bd2f5d81..178c22e2b9 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -171,6 +171,14 @@ lookupTyVar orig = do Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig +-- | Variable must be flexible. +lookupTyVarInfo :: TyVar -> SolveM (Level, TyVarInfo) +lookupTyVarInfo v = do + (lvl, r) <- lookupTyVar v + case r of + Left info -> pure (lvl, info) + Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v + setLink :: TyVar -> VName -> SolveM () setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} @@ -338,11 +346,15 @@ subTyVar reason v v_lvl t = do unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - (t_lvl, t') <- lookupTyVar t + (t_lvl, t_info) <- lookupTyVarInfo t + + -- Insert the link from v to t, and then update the info of t based + -- on the existing info of v and t. + setLink v t - case (v_info, t') of + case (v_info, t_info) of ( TyVarUnsol _ (TyVarFree _ v_l), - Left (TyVarFree t_loc t_l) + TyVarFree t_loc t_l ) | v_l /= t_l -> setInfo t $ TyVarUnsol t_lvl $ TyVarFree t_loc (min v_l t_l) @@ -350,17 +362,17 @@ unionTyVars reason v t = do (TyVarUnsol _ TyVarFree {}, _) -> pure () ( TyVarUnsol _ info, - Left (TyVarFree {}) + TyVarFree {} ) -> setInfo t (TyVarUnsol t_lvl info) -- -- TyVarPrim cases ( TyVarUnsol _ info@TyVarPrim {}, - Left TyVarEql {} + TyVarEql {} ) -> setInfo t (TyVarUnsol t_lvl info) ( TyVarUnsol _ (TyVarPrim _ v_pts), - Left (TyVarPrim t_loc t_pts) + TyVarPrim t_loc t_pts ) -> let pts = L.intersect v_pts t_pts in if null pts @@ -372,14 +384,14 @@ unionTyVars reason v t = do indent 2 (pretty t_pts) else setInfo t (TyVarUnsol t_lvl (TyVarPrim t_loc pts)) ( TyVarUnsol _ (TyVarPrim _ v_pts), - Left TyVarRecord {} + TyVarRecord {} ) -> typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." ( TyVarUnsol _ (TyVarPrim _ v_pts), - Left TyVarSum {} + TyVarSum {} ) -> typeError (locOf reason) mempty $ "Cannot unify type that must be one of" @@ -388,19 +400,19 @@ unionTyVars reason v t = do -- -- TyVarSum cases ( TyVarUnsol _ (TyVarSum _ cs1), - Left (TyVarSum loc cs2) + TyVarSum loc cs2 ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 setInfo t (TyVarUnsol t_lvl (TyVarSum loc cs3)) ( TyVarUnsol _ TyVarSum {}, - Left (TyVarPrim _ pts) + TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarSum _ cs1), - Left (TyVarRecord _ fs) + TyVarRecord _ fs ) -> typeError (locOf reason) mempty $ "Cannot unify type with constructors" @@ -408,25 +420,25 @@ unionTyVars reason v t = do "with type" indent 2 (pretty (Scalar (Record fs))) ( TyVarUnsol _ (TyVarSum _ cs1), - Left (TyVarEql _) + TyVarEql _ ) -> mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases ( TyVarUnsol _ (TyVarRecord _ fs1), - Left (TyVarRecord loc fs2) + TyVarRecord loc fs2 ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 setInfo t (TyVarUnsol t_lvl (TyVarRecord loc fs3)) ( TyVarUnsol _ TyVarRecord {}, - Left (TyVarPrim _ pts) + TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarRecord _ fs1), - Left (TyVarSum _ cs) + TyVarSum _ cs ) -> typeError (locOf reason) mempty $ "Cannot unify record type" @@ -434,18 +446,18 @@ unionTyVars reason v t = do "with type" indent 2 (pretty (Scalar (Sum cs))) ( TyVarUnsol _ (TyVarRecord _ fs1), - Left (TyVarEql _) + TyVarEql _ ) -> mapM_ (mustSupportEql reason) fs1 -- -- TyVarEql cases - (TyVarUnsol _ (TyVarEql _), Left TyVarPrim {}) -> + (TyVarUnsol _ (TyVarEql _), TyVarPrim {}) -> pure () - (TyVarUnsol _ (TyVarEql _), Left TyVarEql {}) -> + (TyVarUnsol _ (TyVarEql _), TyVarEql {}) -> pure () - (TyVarUnsol _ (TyVarEql _), Left (TyVarRecord _ fs)) -> + (TyVarUnsol _ (TyVarEql _), TyVarRecord _ fs) -> mustSupportEql reason $ Scalar $ Record fs - (TyVarUnsol _ (TyVarEql _), Left (TyVarSum _ cs)) -> + (TyVarUnsol _ (TyVarEql _), TyVarSum _ cs) -> mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases @@ -453,11 +465,6 @@ unionTyVars reason v t = do alreadySolved (TyVarParam {}, _) -> isParam - (_, Right t'') -> - error $ "unionTyVars: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' - - -- Finally insert the actual link. - setLink v t where unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v From 2333b7e1a5e13fadc3ee5e03573996c08ef009d1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 14 Jul 2024 22:38:31 +0200 Subject: [PATCH 234/258] Also AUTOMAP in return type annotations. --- src/Language/Futhark/TypeChecker/Rank.hs | 62 ++++++++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 13 ++--- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 6f597ca8f0..24254d7392 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -256,13 +256,19 @@ rankAnalysis1 :: M.Map TyVar Type -> [Pat ParamType] -> Exp -> - m (([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp) -rankAnalysis1 loc cs tyVars artificial params body = do - solutions <- rankAnalysis loc cs tyVars artificial params body + Maybe (TypeExp Exp VName) -> + m + ( ([Ct], M.Map TyVar Type, TyVars), + [Pat ParamType], + Exp, + Maybe (TypeExp Exp VName) + ) +rankAnalysis1 loc cs tyVars artificial params body retdecl = do + solutions <- rankAnalysis loc cs tyVars artificial params body retdecl case solutions of [sol] -> pure sol sols -> do - let (_, _, bodies') = unzip3 sols + let (_, _, bodies', _) = L.unzip4 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -278,10 +284,17 @@ rankAnalysis :: M.Map TyVar Type -> [Pat ParamType] -> Exp -> - m [(([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp)] -rankAnalysis _ [] tyVars artificial params body = - pure [(([], artificial, tyVars), params, body)] -rankAnalysis loc cs tyVars artificial params body = do + Maybe (TypeExp Exp VName) -> + m + [ ( ([Ct], M.Map TyVar Type, TyVars), + [Pat ParamType], + Exp, + Maybe (TypeExp Exp VName) + ) + ] +rankAnalysis _ [] tyVars artificial params body retdecl = + pure [(([], artificial, tyVars), params, body, retdecl)] +rankAnalysis loc cs tyVars artificial params body retdecl = do debugTraceM 3 $ unlines [ "##rankAnalysis", @@ -294,18 +307,21 @@ rankAnalysis loc cs tyVars artificial params body = do cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps let bodys = map (`updAM` body) rank_maps params' = map ((`map` params) . updAMPat) rank_maps - pure $ zip3 cts_tyvars' params' bodys + retdecls = map ((<$> retdecl) . updAMTypeExp) rank_maps + pure $ L.zip4 cts_tyvars' params' bodys retdecls where cs' = foldMap distribAndSplitCnstrs $ foldMap distribAndSplitArrows cs +type RankMap = M.Map VName Int + substRankInfo :: (MonadTypeChecker m) => [Ct] -> M.Map VName Type -> TyVars -> - Map VName Int -> + RankMap -> m ([Ct], M.Map VName Type, TyVars) substRankInfo cs artificial tyVars rankmap = do ((cs', artificial', tyVars'), new_cs, new_tyVars) <- @@ -316,7 +332,7 @@ substRankInfo cs artificial tyVars rankmap = do isCtAM (CtAM {}) = True isCtAM _ = False -runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) +runSubstT :: (MonadTypeChecker m) => TyVars -> RankMap -> SubstT m a -> m (a, [Ct], TyVars) runSubstT tyVars rankmap (SubstT m) = do let env = SubstEnv @@ -344,7 +360,7 @@ newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) data SubstEnv = SubstEnv { envTyVars :: TyVars, - envRanks :: Map VName Int + envRanks :: RankMap } data SubstState = SubstState @@ -443,7 +459,7 @@ instance SubstRanks TyVarInfo where instance SubstRanks (Int, TyVarInfo) where substRanks (lvl, tv) = (lvl,) <$> substRanks tv -updAM :: Map VName Int -> Exp -> Exp +updAM :: RankMap -> Exp -> Exp updAM rank_map e = case e of AppExp (Apply f args loc) res -> @@ -476,15 +492,17 @@ updAM rank_map e = shapeToRank = Shape . foldMap dimToRank upd (AutoMap r m f) = AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) - mapper = - identityMapper - { mapOnExp = pure . updAM rank_map - } + mapper = identityMapper {mapOnExp = pure . updAM rank_map} -updAMPat :: M.Map VName Int -> Pat ParamType -> Pat ParamType +updAMPat :: RankMap -> Pat ParamType -> Pat ParamType updAMPat rank_map p = runIdentity $ astMap m p where - m = - identityMapper - { mapOnExp = pure . updAM rank_map - } + m = identityMapper {mapOnExp = pure . updAM rank_map} + +updAMTypeExp :: + RankMap -> + TypeExp Exp VName -> + TypeExp Exp VName +updAMTypeExp rank_map te = runIdentity $ astMap m te + where + m = identityMapper {mapOnExp = pure . updAM rank_map} diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index f35aa83de9..cd3d460793 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1356,10 +1356,10 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (\(v, t) -> prettyNameString v <> " => " <> prettyString t) (M.toList artificial) ] - onRankSolution retdecl' typarams - =<< rankAnalysis1 loc cts tyvars artificial params' body' + onRankSolution typarams + =<< rankAnalysis1 loc cts tyvars artificial params' body' retdecl' where - onRankSolution retdecl' typarams ((cts', artificial, tyvars'), params', body'') = do + onRankSolution typarams ((cts', artificial, tyvars'), params', body'', retdecl') = do solution <- bitraverse pure @@ -1400,8 +1400,8 @@ checkSingleExp e = runTermM $ do tyvars <- gets termTyVars typarams <- gets termTyParams artificial <- gets termArtificial - ((cts', _artificial', tyvars'), _, e'') <- - rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' + ((cts', _artificial', tyvars'), _, e'', _) <- + rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' Nothing case solve cts' typarams tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do @@ -1421,7 +1421,8 @@ checkSizeExp e = runTermM $ do typarams <- gets termTyParams artificial <- gets termArtificial - (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' + (cts_tyvars', _, es', _) <- + L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' Nothing solutions <- forM cts_tyvars' $ \(cts', _artificial', tyvars') -> From 5f94900f1fc4dcc19a6fa614b495a77f125388f3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 14 Jul 2024 23:53:20 +0200 Subject: [PATCH 235/258] Propagate liftedness properly. --- .../Futhark/TypeChecker/Constraints.hs | 80 +++++++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 8 +- 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 178c22e2b9..ca8f713d0d 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -128,13 +128,13 @@ type Level = Int type TyVars = M.Map TyVar (Level, TyVarInfo) -- | Explicit type parameters. -type TyParams = M.Map TyVar (Level, Loc) +type TyParams = M.Map TyVar (Level, Liftedness, Loc) data TyVarSol = -- | Has been substituted with this. TyVarSol Level Type | -- | Is an explicit (rigid) type parameter in the source program. - TyVarParam Level Loc + TyVarParam Level Liftedness Loc | -- | Not substituted yet; has this constraint. TyVarUnsol Level TyVarInfo deriving (Show) @@ -148,7 +148,7 @@ initialState :: TyParams -> TyVars -> SolverState initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars where f (lvl, info) = Right $ TyVarUnsol lvl info - g (lvl, loc) = Right $ TyVarParam lvl loc + g (lvl, l, loc) = Right $ TyVarParam lvl l loc substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = @@ -159,18 +159,24 @@ substTyVar m v = Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing -lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) -lookupTyVar orig = do +maybeLookupTyVar :: TyVar -> SolveM (Maybe TyVarSol) +maybeLookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of - Nothing -> error $ "Unknown tyvar: " <> prettyNameString v + Nothing -> pure Nothing Just (Left v') -> f v' - Just (Right (TyVarSol lvl t)) -> pure (lvl, Right t) - Just (Right (TyVarParam lvl _)) -> - pure (lvl, Right $ Scalar $ TypeVar mempty (qualName orig) []) - Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) + Just (Right info) -> pure $ Just info f orig +lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) +lookupTyVar orig = + maybe bad unpack <$> maybeLookupTyVar orig + where + bad = error $ "Unknown tyvar: " <> prettyNameString orig + unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig + unpack (TyVarSol lvl t) = (lvl, Right t) + unpack (TyVarUnsol lvl info) = (lvl, Left info) + -- | Variable must be flexible. lookupTyVarInfo :: TyVar -> SolveM (Level, TyVarInfo) lookupTyVarInfo v = do @@ -275,6 +281,11 @@ subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v v_lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars + + -- Set a solution for v, then update info for t in case v has any + -- odd constraints. + setInfo v (TyVarSol v_lvl t) + case (v_info, t) of (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () @@ -340,8 +351,6 @@ subTyVar reason v v_lvl t = do (Nothing, _) -> error $ "subTyVar: Nothing v: " <> prettyNameString v - setInfo v (TyVarSol v_lvl t) - -- Precondition: 'v' and 't' are both currently flexible. unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do @@ -566,10 +575,37 @@ scopeCheck reason v v_lvl ty = do check ty_v = do ty_v_info <- gets $ M.lookup ty_v . solverTyVars case ty_v_info of - Just (Right (TyVarParam ty_v_lvl _)) + Just (Right (TyVarParam ty_v_lvl _ _)) | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v _ -> pure () +-- If a type variable has a liftedness constraint, we propagate that +-- constraint to its solution. The actual checking for correct usage +-- is done later. +liftednessCheck :: Reason -> TyVar -> Liftedness -> Type -> SolveM () +liftednessCheck reason v l (Scalar (TypeVar _ (QualName [] v2) _)) = do + v2_info <- maybeLookupTyVar v2 + case v2_info of + Nothing -> + -- Is an opaque type. + pure () + Just (TyVarSol _ v2_ty) -> + liftednessCheck reason v l v2_ty + Just TyVarParam {} -> pure () + Just (TyVarUnsol lvl (TyVarFree loc v2_l)) + | l /= v2_l -> + setInfo v2 $ TyVarUnsol lvl $ TyVarFree loc (min l v2_l) + Just TyVarUnsol {} -> pure () +liftednessCheck _ _ _ (Scalar Prim {}) = pure () +liftednessCheck _ _ Lifted _ = pure () +liftednessCheck _ _ _ Array {} = pure () +liftednessCheck _ _ _ (Scalar Arrow {}) = pure () +liftednessCheck reason v l (Scalar (Record fs)) = + mapM_ (liftednessCheck reason v l) fs +liftednessCheck reason v l (Scalar (Sum cs)) = + mapM_ (mapM_ $ liftednessCheck reason v l) cs +liftednessCheck _ _ _ (Scalar TypeVar {}) = pure () + solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do (_, tv_t) <- lookupTyVar tv @@ -607,11 +643,23 @@ solveTyVar (tv, (_, TyVarEql loc)) = do "Type" indent 2 (align (pretty ty)) "does not support equality (may contain function)." -solveTyVar (tv, (lvl, _)) = do +solveTyVar (tv, (lvl, TyVarFree loc l)) = do + (_, tv_t) <- lookupTyVar tv + case tv_t of + Right ty -> do + scopeCheck (Reason loc) tv lvl ty + liftednessCheck (Reason loc) tv l ty + _ -> pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do (_, tv_t) <- lookupTyVar tv case tv_t of - Right ty -> - scopeCheck (Reason mempty) tv lvl ty + Right ty + | ty `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + typeError loc mempty $ + "Numeric constant inferred to be of type" + indent 2 (align (pretty ty)) + "which is not possible." _ -> pure () solve :: diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cd3d460793..e58154801e 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -655,9 +655,8 @@ bindTypeParams tparams m = types = mapMaybe typeParamType tparams typeParamType (TypeParamType l v _) = Just (v, TypeAbbr l [] $ RetType [] $ Scalar (TypeVar mempty (qualName v) [])) - typeParamType TypeParamDim {} = - Nothing - typeParam lvl (TypeParamType _ v loc) = Just (v, (lvl, locOf loc)) + typeParamType TypeParamDim {} = Nothing + typeParam lvl (TypeParamType l v loc) = Just (v, (lvl, l, locOf loc)) typeParam _ _ = Nothing bindParams :: @@ -1370,7 +1369,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do [ "## constraints:", unlines $ map prettyString cts', "## typarams:", - unlines (map (prettyString . bimap prettyNameString fst) (M.toList typarams)), + let f (lvl, l, _) = (lvl, l) + in unlines (map (prettyString . bimap prettyNameString f) (M.toList typarams)), "## tyvars':", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", From 041e9e05971bb853edbbb7ea120b435ae913f1cc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 14 Jul 2024 23:58:47 +0200 Subject: [PATCH 236/258] Simplify. --- .../Futhark/TypeChecker/Constraints.hs | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ca8f713d0d..bc458d0616 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -582,29 +582,29 @@ scopeCheck reason v v_lvl ty = do -- If a type variable has a liftedness constraint, we propagate that -- constraint to its solution. The actual checking for correct usage -- is done later. -liftednessCheck :: Reason -> TyVar -> Liftedness -> Type -> SolveM () -liftednessCheck reason v l (Scalar (TypeVar _ (QualName [] v2) _)) = do - v2_info <- maybeLookupTyVar v2 - case v2_info of +liftednessCheck :: Liftedness -> Type -> SolveM () +liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do + v_info <- maybeLookupTyVar v + case v_info of Nothing -> -- Is an opaque type. pure () - Just (TyVarSol _ v2_ty) -> - liftednessCheck reason v l v2_ty + Just (TyVarSol _ v_ty) -> + liftednessCheck l v_ty Just TyVarParam {} -> pure () - Just (TyVarUnsol lvl (TyVarFree loc v2_l)) - | l /= v2_l -> - setInfo v2 $ TyVarUnsol lvl $ TyVarFree loc (min l v2_l) + Just (TyVarUnsol lvl (TyVarFree loc v_l)) + | l /= v_l -> + setInfo v $ TyVarUnsol lvl $ TyVarFree loc (min l v_l) Just TyVarUnsol {} -> pure () -liftednessCheck _ _ _ (Scalar Prim {}) = pure () -liftednessCheck _ _ Lifted _ = pure () -liftednessCheck _ _ _ Array {} = pure () -liftednessCheck _ _ _ (Scalar Arrow {}) = pure () -liftednessCheck reason v l (Scalar (Record fs)) = - mapM_ (liftednessCheck reason v l) fs -liftednessCheck reason v l (Scalar (Sum cs)) = - mapM_ (mapM_ $ liftednessCheck reason v l) cs -liftednessCheck _ _ _ (Scalar TypeVar {}) = pure () +liftednessCheck _ (Scalar Prim {}) = pure () +liftednessCheck Lifted _ = pure () +liftednessCheck _ Array {} = pure () +liftednessCheck _ (Scalar Arrow {}) = pure () +liftednessCheck l (Scalar (Record fs)) = + mapM_ (liftednessCheck l) fs +liftednessCheck l (Scalar (Sum cs)) = + mapM_ (mapM_ $ liftednessCheck l) cs +liftednessCheck _ (Scalar TypeVar {}) = pure () solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do @@ -648,7 +648,7 @@ solveTyVar (tv, (lvl, TyVarFree loc l)) = do case tv_t of Right ty -> do scopeCheck (Reason loc) tv lvl ty - liftednessCheck (Reason loc) tv l ty + liftednessCheck l ty _ -> pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do (_, tv_t) <- lookupTyVar tv From 9834d6866e74620b0328775b1ddf0537cda9bf5d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 00:08:24 +0200 Subject: [PATCH 237/258] Simplify the level stuff. --- .../Futhark/TypeChecker/Constraints.hs | 140 +++++++++--------- 1 file changed, 68 insertions(+), 72 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index bc458d0616..1b66ef9b44 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -132,11 +132,11 @@ type TyParams = M.Map TyVar (Level, Liftedness, Loc) data TyVarSol = -- | Has been substituted with this. - TyVarSol Level Type + TyVarSol Type | -- | Is an explicit (rigid) type parameter in the source program. TyVarParam Level Liftedness Loc | -- | Not substituted yet; has this constraint. - TyVarUnsol Level TyVarInfo + TyVarUnsol TyVarInfo deriving (Show) newtype SolverState = SolverState @@ -147,14 +147,14 @@ newtype SolverState = SolverState initialState :: TyParams -> TyVars -> SolverState initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars where - f (lvl, info) = Right $ TyVarUnsol lvl info + f (_lvl, info) = Right $ TyVarUnsol info g (lvl, l, loc) = Right $ TyVarParam lvl l loc substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = case M.lookup v m of Just (Left v') -> substTyVar m v' - Just (Right (TyVarSol _ t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right (TyVarSol t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' Just (Right TyVarParam {}) -> Nothing Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing @@ -168,21 +168,21 @@ maybeLookupTyVar orig = do Just (Right info) -> pure $ Just info f orig -lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) +lookupTyVar :: TyVar -> SolveM (Either TyVarInfo Type) lookupTyVar orig = maybe bad unpack <$> maybeLookupTyVar orig where bad = error $ "Unknown tyvar: " <> prettyNameString orig unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig - unpack (TyVarSol lvl t) = (lvl, Right t) - unpack (TyVarUnsol lvl info) = (lvl, Left info) + unpack (TyVarSol t) = Right t + unpack (TyVarUnsol info) = Left info -- | Variable must be flexible. -lookupTyVarInfo :: TyVar -> SolveM (Level, TyVarInfo) +lookupTyVarInfo :: TyVar -> SolveM TyVarInfo lookupTyVarInfo v = do - (lvl, r) <- lookupTyVar v + r <- lookupTyVar v case r of - Left info -> pure (lvl, info) + Left info -> pure info Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v setLink :: TyVar -> VName -> SolveM () @@ -206,15 +206,15 @@ solution s = M.mapMaybe mkSubst $ solverTyVars s ) where - mkSubst (Right (TyVarSol _lvl t)) = + mkSubst (Right (TyVarSol t)) = Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t mkSubst (Left v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (Right (TyVarUnsol _ (TyVarPrim _ pts))) = Just $ Left pts + mkSubst (Right (TyVarUnsol (TyVarPrim _ pts))) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, Right (TyVarUnsol _ (TyVarFree _ l))) = Just (v, l) + unconstrained (v, Right (TyVarUnsol (TyVarFree _ l))) = Just (v, l) unconstrained _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} @@ -277,19 +277,19 @@ scopeViolation reason v1 ty v2 = <+> "is rigidly bound in a deeper scope." -- Precondition: 'v' is currently flexible. -subTyVar :: Reason -> VName -> Int -> Type -> SolveM () -subTyVar reason v v_lvl t = do +subTyVar :: Reason -> VName -> Type -> SolveM () +subTyVar reason v t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars -- Set a solution for v, then update info for t in case v has any -- odd constraints. - setInfo v (TyVarSol v_lvl t) + setInfo v (TyVarSol t) case (v_info, t) of - (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> + (Just (Right (TyVarUnsol TyVarFree {})), _) -> pure () - ( Just (Right (TyVarUnsol _ (TyVarPrim _ v_pts))), + ( Just (Right (TyVarUnsol (TyVarPrim _ v_pts))), _ ) -> if t `elem` map (Scalar . Prim) v_pts @@ -300,7 +300,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty v_pts) "with" indent 2 (pretty t) - ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), Scalar (Sum cs2) ) -> if all (`elem` M.keys cs2) (M.keys cs1) @@ -311,7 +311,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Sum cs2)) - ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), _ ) -> typeError (locOf reason) mempty $ @@ -319,7 +319,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty t) - ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), Scalar (Record fs2) ) -> if all (`elem` M.keys fs2) (M.keys fs1) @@ -330,7 +330,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Record fs1)) "with record type" indent 2 (pretty (Record fs2)) - ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), _ ) -> typeError (locOf reason) mempty $ @@ -338,7 +338,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty t) - (Just (Right (TyVarUnsol _ (TyVarEql _))), _) -> + (Just (Right (TyVarUnsol (TyVarEql _))), _) -> mustSupportEql reason t -- -- Internal error cases @@ -355,32 +355,32 @@ subTyVar reason v v_lvl t = do unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - (t_lvl, t_info) <- lookupTyVarInfo t + t_info <- lookupTyVarInfo t -- Insert the link from v to t, and then update the info of t based -- on the existing info of v and t. setLink v t case (v_info, t_info) of - ( TyVarUnsol _ (TyVarFree _ v_l), + ( TyVarUnsol (TyVarFree _ v_l), TyVarFree t_loc t_l ) | v_l /= t_l -> - setInfo t $ TyVarUnsol t_lvl $ TyVarFree t_loc (min v_l t_l) + setInfo t $ TyVarUnsol $ TyVarFree t_loc (min v_l t_l) -- When either is completely unconstrained. - (TyVarUnsol _ TyVarFree {}, _) -> + (TyVarUnsol TyVarFree {}, _) -> pure () - ( TyVarUnsol _ info, + ( TyVarUnsol info, TyVarFree {} ) -> - setInfo t (TyVarUnsol t_lvl info) + setInfo t (TyVarUnsol info) -- -- TyVarPrim cases - ( TyVarUnsol _ info@TyVarPrim {}, + ( TyVarUnsol info@TyVarPrim {}, TyVarEql {} ) -> - setInfo t (TyVarUnsol t_lvl info) - ( TyVarUnsol _ (TyVarPrim _ v_pts), + setInfo t (TyVarUnsol info) + ( TyVarUnsol (TyVarPrim _ v_pts), TyVarPrim t_loc t_pts ) -> let pts = L.intersect v_pts t_pts @@ -391,15 +391,15 @@ unionTyVars reason v t = do indent 2 (pretty v_pts) "with type that must be one of" indent 2 (pretty t_pts) - else setInfo t (TyVarUnsol t_lvl (TyVarPrim t_loc pts)) - ( TyVarUnsol _ (TyVarPrim _ v_pts), + else setInfo t (TyVarUnsol (TyVarPrim t_loc pts)) + ( TyVarUnsol (TyVarPrim _ v_pts), TyVarRecord {} ) -> typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." - ( TyVarUnsol _ (TyVarPrim _ v_pts), + ( TyVarUnsol (TyVarPrim _ v_pts), TyVarSum {} ) -> typeError (locOf reason) mempty $ @@ -408,19 +408,19 @@ unionTyVars reason v t = do "with type that must be sum." -- -- TyVarSum cases - ( TyVarUnsol _ (TyVarSum _ cs1), + ( TyVarUnsol (TyVarSum _ cs1), TyVarSum loc cs2 ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 - setInfo t (TyVarUnsol t_lvl (TyVarSum loc cs3)) - ( TyVarUnsol _ TyVarSum {}, + setInfo t (TyVarUnsol (TyVarSum loc cs3)) + ( TyVarUnsol TyVarSum {}, TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) - ( TyVarUnsol _ (TyVarSum _ cs1), + ( TyVarUnsol (TyVarSum _ cs1), TyVarRecord _ fs ) -> typeError (locOf reason) mempty $ @@ -428,25 +428,25 @@ unionTyVars reason v t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Scalar (Record fs))) - ( TyVarUnsol _ (TyVarSum _ cs1), + ( TyVarUnsol (TyVarSum _ cs1), TyVarEql _ ) -> mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases - ( TyVarUnsol _ (TyVarRecord _ fs1), + ( TyVarUnsol (TyVarRecord _ fs1), TyVarRecord loc fs2 ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 - setInfo t (TyVarUnsol t_lvl (TyVarRecord loc fs3)) - ( TyVarUnsol _ TyVarRecord {}, + setInfo t (TyVarUnsol (TyVarRecord loc fs3)) + ( TyVarUnsol TyVarRecord {}, TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) - ( TyVarUnsol _ (TyVarRecord _ fs1), + ( TyVarUnsol (TyVarRecord _ fs1), TyVarSum _ cs ) -> typeError (locOf reason) mempty $ @@ -454,19 +454,19 @@ unionTyVars reason v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty (Scalar (Sum cs))) - ( TyVarUnsol _ (TyVarRecord _ fs1), + ( TyVarUnsol (TyVarRecord _ fs1), TyVarEql _ ) -> mapM_ (mustSupportEql reason) fs1 -- -- TyVarEql cases - (TyVarUnsol _ (TyVarEql _), TyVarPrim {}) -> + (TyVarUnsol (TyVarEql _), TyVarPrim {}) -> pure () - (TyVarUnsol _ (TyVarEql _), TyVarEql {}) -> + (TyVarUnsol (TyVarEql _), TyVarEql {}) -> pure () - (TyVarUnsol _ (TyVarEql _), TyVarRecord _ fs) -> + (TyVarUnsol (TyVarEql _), TyVarRecord _ fs) -> mustSupportEql reason $ Scalar $ Record fs - (TyVarUnsol _ (TyVarEql _), TyVarSum _ cs) -> + (TyVarUnsol (TyVarEql _), TyVarSum _ cs) -> mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases @@ -529,14 +529,14 @@ solveEq reason orig_t1 orig_t2 = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of Just (Left v') -> flexible v' - Just (Right (TyVarUnsol lvl _)) -> Just lvl - Just (Right TyVarSol {}) -> Nothing - Just (Right TyVarParam {}) -> Nothing - Nothing -> Nothing + Just (Right (TyVarUnsol _)) -> True + Just (Right TyVarSol {}) -> False + Just (Right TyVarParam {}) -> False + Nothing -> False sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (Right (TyVarSol _ t')) -> sub t' + Just (Right (TyVarSol t')) -> sub t' _ -> t sub t = t case (sub t1, sub t2) of @@ -546,18 +546,14 @@ solveEq reason orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (Nothing, Nothing) -> cannotUnify - (Just lvl, Nothing) -> subTyVar reason v1 lvl t2' - (Nothing, Just lvl) -> subTyVar reason v2 lvl t1' - (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> unionTyVars reason v1 v2 - | otherwise -> unionTyVars reason v2 v1 + (False, False) -> cannotUnify + (True, False) -> subTyVar reason v1 t2' + (False, True) -> subTyVar reason v2 t1' + (True, True) -> unionTyVars reason v1 v2 (Scalar (TypeVar _ (QualName [] v1) []), t2') - | Just lvl <- flexible v1 -> - subTyVar reason v1 lvl t2' + | flexible v1 -> subTyVar reason v1 t2' (t1', Scalar (TypeVar _ (QualName [] v2) [])) - | Just lvl <- flexible v2 -> - subTyVar reason v2 lvl t1' + | flexible v2 -> subTyVar reason v2 t1' (t1', t2') -> case unify t1' t2' of Nothing -> cannotUnify Just eqs -> mapM_ solveCt' eqs @@ -589,12 +585,12 @@ liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do Nothing -> -- Is an opaque type. pure () - Just (TyVarSol _ v_ty) -> + Just (TyVarSol v_ty) -> liftednessCheck l v_ty Just TyVarParam {} -> pure () - Just (TyVarUnsol lvl (TyVarFree loc v_l)) + Just (TyVarUnsol (TyVarFree loc v_l)) | l /= v_l -> - setInfo v $ TyVarUnsol lvl $ TyVarFree loc (min l v_l) + setInfo v $ TyVarUnsol $ TyVarFree loc (min l v_l) Just TyVarUnsol {} -> pure () liftednessCheck _ (Scalar Prim {}) = pure () liftednessCheck Lifted _ = pure () @@ -608,7 +604,7 @@ liftednessCheck _ (Scalar TypeVar {}) = pure () solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Left _ -> typeError loc mempty $ @@ -620,7 +616,7 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do Right _ -> pure () solveTyVar (tv, (_, TyVarSum loc cs1)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Left _ -> typeError loc mempty $ @@ -629,7 +625,7 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do indent 2 (pretty (Scalar (Sum cs1))) Right _ -> pure () solveTyVar (tv, (_, TyVarEql loc)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Left TyVarEql {} -> typeError loc mempty $ @@ -644,14 +640,14 @@ solveTyVar (tv, (_, TyVarEql loc)) = do indent 2 (align (pretty ty)) "does not support equality (may contain function)." solveTyVar (tv, (lvl, TyVarFree loc l)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Right ty -> do scopeCheck (Reason loc) tv lvl ty liftednessCheck l ty _ -> pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Right ty | ty `elem` map (Scalar . Prim) pts -> pure () From e5088b58c06274bd4d2b6673a19f8bed5f297c3d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 10:22:50 +0200 Subject: [PATCH 238/258] Improve handling of branches. --- src/Language/Futhark/TypeChecker/Unify.hs | 29 +++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 47e04ee3aa..8bd5b68af7 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -807,10 +807,31 @@ unifyMostCommon :: StructType -> m (StructType, [VName]) unifyMostCommon usage t1 t2 = do - -- We are ignoring the dimensions here, because any mismatches - -- should be turned into fresh size variables. - let allOK _ _ _ _ _ = pure () - unifyWith allOK usage mempty noBreadCrumbs t1 t2 + -- Like 'unifySizes', except we do not fail on mismatches - these + -- are instead turned into fresh existential sizes in + -- 'newDimOnMismatch'. The most annoying thing is that we have to + -- replicate scope checking, because we don't want to link if it + -- would fail. + constraints <- getConstraints + + let varLevel v = fst <$> M.lookup v constraints + expLevel e = + L.foldl' max 0 $ mapMaybe varLevel $ S.toList $ fvVars $ freeInExp e + + onDims bcs bound nonrigid e1 e2 + | Just es <- similarExps e1 e2 = + mapM_ (uncurry $ onDims bcs bound nonrigid) es + onDims bcs _ nonrigid (Var v1 _ _) e2 + | Just lvl1 <- nonrigid (qualLeaf v1), + expLevel e2 < lvl1 = + linkVarToDim usage bcs (qualLeaf v1) lvl1 e2 + onDims bcs _ nonrigid e1 (Var v2 _ _) + | Just lvl2 <- nonrigid (qualLeaf v2), + expLevel e1 < lvl2 = + linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 + onDims _ _ _ _ _ = pure () + + unifyWith onDims usage mempty noBreadCrumbs t1 t2 t1' <- normTypeFully t1 t2' <- normTypeFully t2 newDimOnMismatch (locOf usage) t1' t2' From 104641407905ded29f48818407bf5c66e5457498 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 10:24:49 +0200 Subject: [PATCH 239/258] Supposed to be ambiguous. --- tests/automap/ambiguous0.fut | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/automap/ambiguous0.fut b/tests/automap/ambiguous0.fut index 58a663bf36..8c1ec556c3 100644 --- a/tests/automap/ambiguous0.fut +++ b/tests/automap/ambiguous0.fut @@ -1 +1,4 @@ +-- == +-- error: ambiguous + def ambig (xss : [][]i32) = i64.sum (length xss) From d0bea363dfcaa6e5b726e4dbf64832d84cbccaed Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 11:02:22 +0200 Subject: [PATCH 240/258] Now inferred differently. --- tests/tridag.fut | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/tests/tridag.fut b/tests/tridag.fut index a055dca86a..e8cc6718e8 100644 --- a/tests/tridag.fut +++ b/tests/tridag.fut @@ -34,32 +34,31 @@ -- } -def tridag(nn: i32, - b: *[]f64, d: *[]f64, - a: []f64, c: []f64 ): ([]f64,[]f64) = - if (nn == 1) +def tridag [nn] (b: *[]f64, d: *[nn]f64, + a: []f64, c: []f64 ): ([]f64,[]f64) = + if (nn == 1) --then ( b, map(\f64 (f64 x, f64 y) -> x / y, d, b) ) then (b, [d[0]/b[0]]) - else - let (b,d) = loop((b, d)) for i < (nn-1) do - let xm = a[i+1] / b[i] - let b[i+1] = b[i+1] - xm*c[i] - let d[i+1] = d[i+1] - xm*d[i] in - (b, d) + else + let (b,d) = loop((b, d)) for i < (nn-1) do + let xm = a[i+1] / b[i] + let b[i+1] = b[i+1] - xm*c[i] + let d[i+1] = d[i+1] - xm*d[i] in + (b, d) - let d[nn-1] = d[nn-1] / b[nn-1] in + let d[nn-1] = d[nn-1] / b[nn-1] in - let d = loop(d) for i < (nn-1) do - let k = nn - 2 - i - let d[k] = ( d[k] - c[k]*d[k+1] ) / b[k] in - d - in (b, d) + let d = loop(d) for i < (nn-1) do + let k = nn - 2 - i + let d[k] = ( d[k] - c[k]*d[k+1] ) / b[k] in + d + in (b, d) def main: ([]f64,[]f64) = - let nn = reduce (+) 0 ([1,2,3,4]) - let a = replicate nn 3.33 - let b = map (\x -> f64.i64(x) + 1.0) (iota(nn)) - let c = map (\x -> 1.11*f64.i64(x) + 0.5) (iota(nn)) - let d = map (\x -> 1.01*f64.i64(x) + 0.25) (iota(nn)) - in tridag(i32.i64 nn, b, d, a, c) + let nn = reduce (+) 0 ([1,2,3,4]) + let a = replicate nn 3.33 + let b = map (\x -> f64.i64(x) + 1.0) (iota(nn)) + let c = map (\x -> 1.11*f64.i64(x) + 0.5) (iota(nn)) + let d = map (\x -> 1.01*f64.i64(x) + 0.25) (iota(nn)) + in tridag(b, d, a, c) From 0a07d320d058ed99b21affa891f5779b6360f141 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 22:16:12 +0200 Subject: [PATCH 241/258] Must be more explicit now. --- tests/shapes/polymorphic4.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shapes/polymorphic4.fut b/tests/shapes/polymorphic4.fut index b44af86c34..acab851f67 100644 --- a/tests/shapes/polymorphic4.fut +++ b/tests/shapes/polymorphic4.fut @@ -2,6 +2,6 @@ -- == -- error: do not match -def foo f x : [1]i32 = +def foo (f : (n: i64) -> [n]i32) x : [1]i32 = let r = if true then f x : []i32 else [1i32] in r From 80cfae5c78b307aa45e38a7b28e1a425153f2b8f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 16 Jul 2024 17:46:11 +0200 Subject: [PATCH 242/258] This is OK now. --- tests/sumtypes/coerce1.fut | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/sumtypes/coerce1.fut b/tests/sumtypes/coerce1.fut index eeff92a2a3..b6bfe42f3d 100644 --- a/tests/sumtypes/coerce1.fut +++ b/tests/sumtypes/coerce1.fut @@ -1,5 +1,4 @@ -- == --- error: Ambiguous size.*anonymous size type opt 't = #some t | #none From 66be58ecb9b8011fb0fe99b2e519984da28fabc9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 16 Jul 2024 17:46:25 +0200 Subject: [PATCH 243/258] Fix Constr. --- src/Language/Futhark/TypeChecker/Terms.hs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d2f03ecc9c..06125f214d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -796,6 +796,13 @@ checkExp (AppExp (Loop _ mergepat mergeexp form loopbody loc) _) = do checkExp (Constr name es (Info t) loc) = do t' <- replaceTyVars loc t es' <- mapM checkExp es + case t' of + Scalar (Sum cs) + | Just name_ts <- M.lookup name cs -> + zipWithM_ (unify $ mkUsage loc "inferred variant") name_ts $ + map typeOf es' + _ -> + error $ "checkExp Constr: " <> prettyString t' pure $ Constr name es' (Info t') loc checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e From 7068b88970e7d118f038731f8f5f16853f0d3954 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 16 Jul 2024 23:02:56 +0200 Subject: [PATCH 244/258] Fiddle with liftedness checking. --- src/Language/Futhark/TypeChecker/Terms.hs | 42 +++++++++---------- .../Futhark/TypeChecker/Terms/Loop.hs | 6 +-- .../Futhark/TypeChecker/Terms/Monad.hs | 29 +++++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 9 ++-- src/Language/Futhark/TypeChecker/Unify.hs | 35 ---------------- 5 files changed, 56 insertions(+), 65 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 06125f214d..082e60f1ea 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -383,12 +383,14 @@ checkExp (ArrayLit all_es _ loc) = [] -> do et <- newTypeVar loc "t" t <- arrayOfM loc et (Shape [sizeFromInteger 0 mempty]) + mustBeUnlifted (locOf loc) et pure $ ArrayLit [] (Info t) loc e : es -> do e' <- checkExp e et <- expType e' es' <- mapM (unifies "type of first array element" et <=< checkExp) es t <- arrayOfM loc et (Shape [sizeFromInteger (genericLength all_es) mempty]) + mustBeUnlifted (locOf loc) et pure $ ArrayLit (e' : es') (Info t) loc checkExp (AppExp (Range start maybe_step end loc) _) = do start' <- checkExp start @@ -519,24 +521,6 @@ checkExp (Project k e _ loc) = do | Just kt <- M.lookup k fs -> pure $ Project k e' (Info kt) loc _ -> error $ "checkExp Project: " <> show t -checkExp (AppExp (If e1 e2 e3 loc) _) = do - e1' <- checkExp e1 - e2' <- checkExp e2 - e3' <- checkExp e3 - - let bool = Scalar $ Prim Bool - e1_t <- expType e1' - onFailure (CheckingRequired [bool] e1_t) $ - unify (mkUsage e1' "use as 'if' condition") bool e1_t - - (brancht, retext) <- unifyBranches loc e2' e3' - - zeroOrderType - (mkUsage loc "returning value of this type from 'if' expression") - "type returned from branch" - brancht - - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes brancht retext) checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (QualParens (modname, modnameloc) e loc) = do @@ -804,14 +788,28 @@ checkExp (Constr name es (Info t) loc) = do _ -> error $ "checkExp Constr: " <> prettyString t' pure $ Constr name es' (Info t') loc +checkExp (AppExp (If e1 e2 e3 loc) _) = do + e1' <- checkExp e1 + e2' <- checkExp e2 + e3' <- checkExp e3 + + let bool = Scalar $ Prim Bool + e1_t <- expType e1' + onFailure (CheckingRequired [bool] e1_t) $ + unify (mkUsage e1' "use as 'if' condition") bool e1_t + + (t, retext) <- unifyBranches loc e2' e3' + + mustBeOrderZero (locOf loc) t + + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes t retext) checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e mt <- expType e' (cs', t, retext) <- checkCases mt cs - zeroOrderType - (mkUsage loc "being returned 'match'") - "type returned from pattern match" - t + + mustBeOrderZero (locOf loc) t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes t retext) checkExp (Attr info e loc) = Attr <$> checkAttr info <*> checkExp e <*> pure loc diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index cf447e7408..334c67ed5b 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -136,11 +136,7 @@ checkLoop :: checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do mergeexp' <- checkExp mergeexp known_before <- M.keysSet <$> getConstraints - zeroOrderType - (mkUsage mergeexp "use as loop variable") - "type used as loop variable" - . toStruct - =<< expTypeFully mergeexp' + mustBeOrderZero (locOf mergeexp) =<< expTypeFully mergeexp' -- The handling of dimension sizes is a bit intricate, but very -- similar to checking a function, followed by checking a call to diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 4d9be5deff..36d9c8bade 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -29,6 +29,8 @@ module Language.Futhark.TypeChecker.Terms.Monad replaceTyVars, updateTypes, Names, + mustBeOrderZero, + mustBeUnlifted, -- * Primitive checking unifies, @@ -618,6 +620,33 @@ updateTypes = astMap tv mapOnResRetType = normTypeFully } +mustBeOrderZero :: Loc -> StructType -> TermTypeM () +mustBeOrderZero loc t = do + constraints <- getConstraints + let liftedType v = + case M.lookup v constraints of + Just (_, ParamType Lifted _) -> True + _ -> False + when (not (orderZero t) || any liftedType (typeVars t)) $ + typeError loc mempty $ + textwrap "This expression may not be of function type, but is inferred to be of type" + indent 2 (align (pretty t)) + "which may be a function." + +mustBeUnlifted :: Loc -> StructType -> TermTypeM () +mustBeUnlifted loc t = do + constraints <- getConstraints + let liftedType v = + case M.lookup v constraints of + Just (_, ParamType Lifted _) -> True + Just (_, ParamType SizeLifted _) -> True + _ -> False + when (not (orderZero t) || any liftedType (typeVars t)) $ + typeError loc mempty $ + textwrap "This expression must be of unlifted type, but is inferred to be of type" + indent 2 (align (pretty t)) + "which may be a function or a value with hidden sizes." + --- Basic checking unifies :: T.Text -> StructType -> Exp -> TermTypeM Exp diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e58154801e..8fabb67962 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1191,10 +1191,13 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e e_t <- expType e' - (cs', t) <- checkCases e_t cs - t' <- asStructType t - pure $ AppExp (Match e' cs' loc) (Info $ AppRes t' []) + + match_t <- newType loc SizeLifted "match_t" NoUniqueness + ctEq (Reason (locOf loc)) match_t t + + match_t' <- asStructType match_t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes match_t' []) -- checkExp (AppExp (Loop _ pat arg form body loc) _) = do arg' <- checkExp arg diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 8bd5b68af7..7d0f39bf91 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -14,7 +14,6 @@ module Language.Futhark.TypeChecker.Unify noBreadCrumbs, hasNoBreadCrumbs, dimNotes, - zeroOrderType, arrayElemType, normType, normTypeFully, @@ -665,40 +664,6 @@ linkVarToDim usage bcs vn lvl e = do _ -> modifyConstraints $ M.insert dim' (lvl, c) checkVar _ _ = pure () -zeroOrderTypeWith :: - (MonadUnify m) => - Usage -> - BreadCrumbs -> - StructType -> - m () -zeroOrderTypeWith usage bcs t = do - unless (orderZero t) $ - unifyError usage mempty bcs $ - "Type" indent 2 (pretty t) "found to be functional." - mapM_ mustBeZeroOrder . S.toList . typeVars =<< normType t - where - mustBeZeroOrder vn = do - constraints <- getConstraints - case M.lookup vn constraints of - Just (lvl, NoConstraint _ _) -> - modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage) - Just (_, ParamType Lifted ploc) -> - unifyError usage mempty bcs $ - "Type parameter" - <+> dquotes (prettyName vn) - <+> "at" - <+> pretty (locStr ploc) - <+> "may be a function." - _ -> pure () - --- | Assert that this type must be zero-order. -zeroOrderType :: - (MonadUnify m) => Usage -> T.Text -> StructType -> m () -zeroOrderType usage desc = - zeroOrderTypeWith usage $ breadCrumb bc noBreadCrumbs - where - bc = Matching $ "When checking" <+> textwrap desc - arrayElemTypeWith :: (MonadUnify m, Pretty (Shape dim), Pretty u) => Usage -> From 2807accc420ed887c92df1249fea4e48f549398f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 4 Sep 2024 16:34:16 +0200 Subject: [PATCH 245/258] Work on error message. --- .../Futhark/TypeChecker/Constraints.hs | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1b66ef9b44..16b5650e94 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -481,49 +481,57 @@ unionTyVars reason v t = do isParam = error $ "Type name is a type parameter: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. -unify :: Type -> Type -> Maybe [(Type, Type)] +unify :: Type -> Type -> Either (Doc a) [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) - | pt1 == pt2 = Just [] + | pt1 == pt2 = Right [] unify (Scalar (TypeVar _ (QualName _ v1) targs1)) (Scalar (TypeVar _ (QualName _ v2) targs2)) | v1 == v2 = - Just $ mapMaybe f $ zip targs1 targs2 + Right $ mapMaybe f $ zip targs1 targs2 where f (TypeArgType t1, TypeArgType t2) = Just (t1, t2) f _ = Nothing unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Just [(t1a, t2a), (t1r', t2r')] + Right [(t1a, t2a), (t1r', t2r')] where t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = - Just $ M.elems $ M.intersectionWith (,) fs1 fs2 + Right $ M.elems $ M.intersectionWith (,) fs1 fs2 + | otherwise = + let missing = + filter (`notElem` M.keys fs1) (M.keys fs2) + <> filter (`notElem` M.keys fs2) (M.keys fs1) + in Left $ + "Unshared fields:" <+> commasep (map pretty missing) <> "." unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = fmap concat . forM cs' $ \(ts1, ts2) -> do - guard $ length ts1 == length ts2 - Just $ zip ts1 ts2 + if length ts1 == length ts2 + then Right $ zip ts1 ts2 + else Left mempty where cs' = M.elems $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = - Just [(t1', t2')] -unify _ _ = Nothing + Right [(t1', t2')] +unify _ _ = Left mempty solveEq :: Reason -> Type -> Type -> SolveM () solveEq reason orig_t1 orig_t2 = do solveCt' (orig_t1, orig_t2) where - cannotUnify = do + cannotUnify details = do tyvars <- gets solverTyVars typeError (locOf reason) mempty $ "Cannot unify" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) "with" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) + <> details solveCt' (t1, t2) = do tyvars <- gets solverTyVars @@ -546,7 +554,7 @@ solveEq reason orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify + (False, False) -> cannotUnify mempty (True, False) -> subTyVar reason v1 t2' (False, True) -> subTyVar reason v2 t1' (True, True) -> unionTyVars reason v1 v2 @@ -555,8 +563,8 @@ solveEq reason orig_t1 orig_t2 = do (t1', Scalar (TypeVar _ (QualName [] v2) [])) | flexible v2 -> subTyVar reason v2 t1' (t1', t2') -> case unify t1' t2' of - Nothing -> cannotUnify - Just eqs -> mapM_ solveCt' eqs + Left details -> cannotUnify details + Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () solveCt ct = From d8dbec2b5451be86b02d7c1cab436902f24ec3ad Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:13:39 +0200 Subject: [PATCH 246/258] Fix some mistaken tests. --- tests/shapes/error6.fut | 2 +- tests/shapes/shape_duplicate.fut | 4 ++-- tests/shapes/size-inference2.fut | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/shapes/error6.fut b/tests/shapes/error6.fut index 3fda73dd6e..5c7332d94a 100644 --- a/tests/shapes/error6.fut +++ b/tests/shapes/error6.fut @@ -2,7 +2,7 @@ -- == -- error: "n" -def ap (f: (n: i64) -> [n]i32) (k: i64) : [k]i32 = +def ap (f: (n: i64) -> [n]i64) (k: i64) : [k]i64 = f k def main = ap (\n -> iota (n+1)) 10 diff --git a/tests/shapes/shape_duplicate.fut b/tests/shapes/shape_duplicate.fut index 3bbd5f391f..b29e1e7cbe 100644 --- a/tests/shapes/shape_duplicate.fut +++ b/tests/shapes/shape_duplicate.fut @@ -4,7 +4,7 @@ -- == -- error: do not match -def f [n][m] ((_, elems: [n]i32): (i32,[m]i32)): i32 = +def f [n][m] ((_, elems: [n]i64): (i64,[m]i64)): i64 = n + m + elems[0] -def main (x: i32, y: []i32): i32 = f (x, y) +def main (x: i64, y: []i64): i64 = f (x, y) diff --git a/tests/shapes/size-inference2.fut b/tests/shapes/size-inference2.fut index b6f59d4a9a..2804383f72 100644 --- a/tests/shapes/size-inference2.fut +++ b/tests/shapes/size-inference2.fut @@ -2,4 +2,4 @@ -- == -- error: Sizes.*do not match -def main [n] (xs: [n]i32) : [n]i32 = iota (length xs) +def main [n] (xs: [n]i32) : [n]i64 = iota (length xs) From 90d36781ff5c5e7da0fbca5bbc5c1f6c8caa9b5b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:15:18 +0200 Subject: [PATCH 247/258] Linebreak. --- src/Language/Futhark/TypeChecker/Constraints.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 16b5650e94..351e738f33 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -531,7 +531,7 @@ solveEq reason orig_t1 orig_t2 = do indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) "with" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) - <> details + details solveCt' (t1, t2) = do tyvars <- gets solverTyVars From bc804c1cda489692196b6cba14c0adfba612500e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:19:36 +0200 Subject: [PATCH 248/258] Fix another test. --- tests/shapes/error4.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shapes/error4.fut b/tests/shapes/error4.fut index b842bdf44a..cf75bfe897 100644 --- a/tests/shapes/error4.fut +++ b/tests/shapes/error4.fut @@ -2,7 +2,7 @@ -- == -- error: Sizes.*"n".*do not match -def f (g: (n: i64) -> [n]i32) (l: i64): i32 = +def f (g: (n: i64) -> [n]i64) (l: i64): i64 = (g l)[0] def main = f (\n : []i64 -> iota (n+1)) From 2a0cc26e09032980080c2e041c551e9a7b3651db Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:41:23 +0200 Subject: [PATCH 249/258] Fix more error messages. --- tests/issue1787.fut | 2 +- tests/issue514.fut | 2 +- tests/types/inference-error4.fut | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/issue1787.fut b/tests/issue1787.fut index 90cb01dd72..ed4aef3fae 100644 --- a/tests/issue1787.fut +++ b/tests/issue1787.fut @@ -1,5 +1,5 @@ -- == --- error: found to be functional +-- error: function type entry main: i32 -> i32 -> i32 = ((true, (.0)), (false, (.1))) diff --git a/tests/issue514.fut b/tests/issue514.fut index 2f70eca04f..057d69b71a 100644 --- a/tests/issue514.fut +++ b/tests/issue514.fut @@ -1,4 +1,4 @@ -- == --- error: issue514.fut:4:26-36 +-- error: issue514.fut:4:13-22 def main = (2.0 + 3.0) / (2 + 3i32) diff --git a/tests/types/inference-error4.fut b/tests/types/inference-error4.fut index 809b98302a..0ff781f33a 100644 --- a/tests/types/inference-error4.fut +++ b/tests/types/inference-error4.fut @@ -1,6 +1,6 @@ -- If something is used in a loop, it cannot later be inferred as a -- function. -- == --- error: functional +-- error: function type def f x = (loop x = x for i < 10 do x, x 2) From b77f1c2d742dace00425ea2746c6869966eaaec4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 10:57:28 +0200 Subject: [PATCH 250/258] Less weird. --- tests/ascription0.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ascription0.fut b/tests/ascription0.fut index 5aff8c054a..8c3a50e026 100644 --- a/tests/ascription0.fut +++ b/tests/ascription0.fut @@ -3,6 +3,6 @@ -- == -- error: match -def main(x: i32, y:i32): i32 = +def main(x: i32, y:i32): (bool,bool) = let (((a): i32), b: i32) : (bool,bool) = (x,y) in (a,b) From dee11882979616b2a8e7e8189ce3932914c0410f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 11:12:30 +0200 Subject: [PATCH 251/258] Introduce breadcrumbs in constraint solver. --- futhark.cabal | 1 + .../Futhark/TypeChecker/Constraints.hs | 66 ++++++++------- src/Language/Futhark/TypeChecker/Error.hs | 79 ++++++++++++++++++ .../Futhark/TypeChecker/Terms/Monad.hs | 1 + src/Language/Futhark/TypeChecker/Unify.hs | 83 ++++--------------- 5 files changed, 132 insertions(+), 98 deletions(-) create mode 100644 src/Language/Futhark/TypeChecker/Error.hs diff --git a/futhark.cabal b/futhark.cabal index 054e5e2917..bcd12b7a12 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -421,6 +421,7 @@ library Language.Futhark.TypeChecker Language.Futhark.TypeChecker.Consumption Language.Futhark.TypeChecker.Constraints + Language.Futhark.TypeChecker.Error Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 351e738f33..9c9e7869c9 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -27,6 +27,7 @@ import Data.Maybe import Data.Set qualified as S import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..)) import Language.Futhark.TypeChecker.Types (substTyVars) @@ -237,13 +238,14 @@ occursCheck reason v tp = do unifySharedConstructors :: Reason -> + BreadCrumbs -> M.Map Name [Type] -> M.Map Name [Type] -> SolveM () -unifySharedConstructors reason cs1 cs2 = +unifySharedConstructors reason bcs cs1 cs2 = forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> if length ts1 == length ts2 - then zipWithM_ (solveEq reason) ts1 ts2 + then zipWithM_ (solveEq reason bcs) ts1 ts2 else typeError (locOf reason) mempty $ "Cannot unify type with constructor" @@ -254,12 +256,13 @@ unifySharedConstructors reason cs1 cs2 = unifySharedFields :: Reason -> + BreadCrumbs -> M.Map Name Type -> M.Map Name Type -> SolveM () -unifySharedFields reason fs1 fs2 = +unifySharedFields reason bcs fs1 fs2 = forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_f, (ts1, ts2)) -> - solveEq reason ts1 ts2 + solveEq reason bcs ts1 ts2 mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () @@ -277,8 +280,8 @@ scopeViolation reason v1 ty v2 = <+> "is rigidly bound in a deeper scope." -- Precondition: 'v' is currently flexible. -subTyVar :: Reason -> VName -> Type -> SolveM () -subTyVar reason v t = do +subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () +subTyVar reason bcs v t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars @@ -304,7 +307,7 @@ subTyVar reason v t = do Scalar (Sum cs2) ) -> if all (`elem` M.keys cs2) (M.keys cs1) - then unifySharedConstructors reason cs1 cs2 + then unifySharedConstructors reason bcs cs1 cs2 else typeError (locOf reason) mempty $ "Cannot unify type with constructors" @@ -323,7 +326,7 @@ subTyVar reason v t = do Scalar (Record fs2) ) -> if all (`elem` M.keys fs2) (M.keys fs1) - then unifySharedFields reason fs1 fs2 + then unifySharedFields reason bcs fs1 fs2 else typeError (locOf reason) mempty $ "Cannot unify record type with fields" @@ -352,8 +355,8 @@ subTyVar reason v t = do error $ "subTyVar: Nothing v: " <> prettyNameString v -- Precondition: 'v' and 't' are both currently flexible. -unionTyVars :: Reason -> VName -> VName -> SolveM () -unionTyVars reason v t = do +unionTyVars :: Reason -> BreadCrumbs -> VName -> VName -> SolveM () +unionTyVars reason bcs v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars t_info <- lookupTyVarInfo t @@ -411,7 +414,7 @@ unionTyVars reason v t = do ( TyVarUnsol (TyVarSum _ cs1), TyVarSum loc cs2 ) -> do - unifySharedConstructors reason cs1 cs2 + unifySharedConstructors reason bcs cs1 cs2 let cs3 = cs1 <> cs2 setInfo t (TyVarUnsol (TyVarSum loc cs3)) ( TyVarUnsol TyVarSum {}, @@ -437,7 +440,7 @@ unionTyVars reason v t = do ( TyVarUnsol (TyVarRecord _ fs1), TyVarRecord loc fs2 ) -> do - unifySharedFields reason fs1 fs2 + unifySharedFields reason bcs fs1 fs2 let fs3 = fs1 <> fs2 setInfo t (TyVarUnsol (TyVarRecord loc fs3)) ( TyVarUnsol TyVarRecord {}, @@ -481,7 +484,7 @@ unionTyVars reason v t = do isParam = error $ "Type name is a type parameter: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. -unify :: Type -> Type -> Either (Doc a) [(Type, Type)] +unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Right [] unify @@ -490,16 +493,19 @@ unify | v1 == v2 = Right $ mapMaybe f $ zip targs1 targs2 where - f (TypeArgType t1, TypeArgType t2) = Just (t1, t2) + f (TypeArgType t1, TypeArgType t2) = Just (mempty, (t1, t2)) f _ = Nothing unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Right [(t1a, t2a), (t1r', t2r')] + Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] where t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = - Right $ M.elems $ M.intersectionWith (,) fs1 fs2 + Right $ + map (first matchingField) $ + M.toList $ + M.intersectionWith (,) fs1 fs2 | otherwise = let missing = filter (`notElem` M.keys fs1) (M.keys fs2) @@ -510,19 +516,19 @@ unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = fmap concat . forM cs' $ \(ts1, ts2) -> do if length ts1 == length ts2 - then Right $ zip ts1 ts2 + then Right $ zipWith (curry (mempty,)) ts1 ts2 else Left mempty where cs' = M.elems $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = - Right [(t1', t2')] + Right [(mempty, (t1', t2'))] unify _ _ = Left mempty -solveEq :: Reason -> Type -> Type -> SolveM () -solveEq reason orig_t1 orig_t2 = do - solveCt' (orig_t1, orig_t2) +solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () +solveEq reason obcs orig_t1 orig_t2 = do + solveCt' (obcs, (orig_t1, orig_t2)) where cannotUnify details = do tyvars <- gets solverTyVars @@ -533,7 +539,7 @@ solveEq reason orig_t1 orig_t2 = do indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) details - solveCt' (t1, t2) = do + solveCt' (bcs, (t1, t2)) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of Just (Left v') -> flexible v' @@ -554,22 +560,22 @@ solveEq reason orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify mempty - (True, False) -> subTyVar reason v1 t2' - (False, True) -> subTyVar reason v2 t1' - (True, True) -> unionTyVars reason v1 v2 + (False, False) -> cannotUnify $ pretty bcs + (True, False) -> subTyVar reason bcs v1 t2' + (False, True) -> subTyVar reason bcs v2 t1' + (True, True) -> unionTyVars reason bcs v1 v2 (Scalar (TypeVar _ (QualName [] v1) []), t2') - | flexible v1 -> subTyVar reason v1 t2' + | flexible v1 -> subTyVar reason bcs v1 t2' (t1', Scalar (TypeVar _ (QualName [] v2) [])) - | flexible v2 -> subTyVar reason v2 t1' + | flexible v2 -> subTyVar reason bcs v2 t1' (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify details + Left details -> cannotUnify $ pretty bcs details Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () solveCt ct = case ct of - CtEq reason t1 t2 -> solveEq reason t1 t2 + CtEq reason t1 t2 -> solveEq reason mempty t1 t2 CtAM {} -> pure () -- Good vibes only. scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Error.hs b/src/Language/Futhark/TypeChecker/Error.hs new file mode 100644 index 0000000000..d4fbc70aad --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Error.hs @@ -0,0 +1,79 @@ +-- | Fundamental facilities for constructing type error messages. +module Language.Futhark.TypeChecker.Error + ( -- * Breadcrumbs + BreadCrumbs, + hasNoBreadCrumbs, + matchingField, + matchingConstructor, + matchingTypes, + matching, + ) +where + +import Futhark.Util.Pretty +import Language.Futhark + +-- | A piece of information that describes what process the type +-- checker currently performing. This is used to give better error +-- messages for unification errors. +data BreadCrumb + = MatchingTypes StructType StructType + | MatchingFields [Name] + | MatchingConstructor Name + | Matching (Doc ()) + +instance Pretty BreadCrumb where + pretty (MatchingTypes t1 t2) = + "When matching type" + indent 2 (pretty t1) + "with" + indent 2 (pretty t2) + pretty (MatchingFields fields) = + "When matching types of record field" + <+> dquotes (mconcat $ punctuate "." $ map pretty fields) + <> dot + pretty (MatchingConstructor c) = + "When matching types of constructor" <+> dquotes (pretty c) <> dot + pretty (Matching s) = + unAnnotate s + +-- | Unification failures can occur deep down inside complicated types +-- (consider nested records). We leave breadcrumbs behind us so we can +-- report the path we took to find the mismatch. When combining +-- breadcrumbs with the 'Semigroup' instance, put the innermost +-- breadcrumbs to the left. +newtype BreadCrumbs = BreadCrumbs [BreadCrumb] + +instance Semigroup BreadCrumbs where + BreadCrumbs (MatchingFields xs : bcs1) <> BreadCrumbs (MatchingFields ys : bcs2) = + BreadCrumbs $ MatchingFields (ys <> xs) : bcs1 <> bcs2 + BreadCrumbs bcs1 <> BreadCrumbs bcs2 = + BreadCrumbs $ bcs1 <> bcs2 + +instance Monoid BreadCrumbs where + mempty = BreadCrumbs [] + +-- | Is the path empty? +hasNoBreadCrumbs :: BreadCrumbs -> Bool +hasNoBreadCrumbs (BreadCrumbs []) = True +hasNoBreadCrumbs _ = False + +-- | Matching a record field. +matchingField :: Name -> BreadCrumbs +matchingField f = BreadCrumbs [MatchingFields [f]] + +-- | Matching two types. +matchingTypes :: StructType -> StructType -> BreadCrumbs +matchingTypes t1 t2 = BreadCrumbs [MatchingTypes t1 t2] + +-- | Matching a constructor. +matchingConstructor :: Name -> BreadCrumbs +matchingConstructor c = BreadCrumbs [MatchingConstructor c] + +-- | Matching anything. +matching :: Doc () -> BreadCrumbs +matching d = BreadCrumbs [Matching d] + +instance Pretty BreadCrumbs where + pretty (BreadCrumbs []) = mempty + pretty (BreadCrumbs bcs) = line <> stack (map pretty bcs) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 36d9c8bade..ba30adbaeb 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -66,6 +66,7 @@ import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints (TyVar) +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod, stateNameSource) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 7d0f39bf91..d7898df2c1 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -10,9 +10,6 @@ module Language.Futhark.TypeChecker.Unify MonadUnify (..), Rigidity (..), RigidSource (..), - BreadCrumbs, - noBreadCrumbs, - hasNoBreadCrumbs, dimNotes, arrayElemType, normType, @@ -33,57 +30,10 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad hiding (BoundV) import Language.Futhark.TypeChecker.Types --- | A piece of information that describes what process the type --- checker currently performing. This is used to give better error --- messages for unification errors. -data BreadCrumb - = MatchingTypes StructType StructType - | MatchingFields [Name] - | MatchingConstructor Name - | Matching (Doc ()) - -instance Pretty BreadCrumb where - pretty (MatchingTypes t1 t2) = - "When matching type" - indent 2 (pretty t1) - "with" - indent 2 (pretty t2) - pretty (MatchingFields fields) = - "When matching types of record field" - <+> dquotes (mconcat $ punctuate "." $ map pretty fields) - <> dot - pretty (MatchingConstructor c) = - "When matching types of constructor" <+> dquotes (pretty c) <> dot - pretty (Matching s) = - unAnnotate s - --- | Unification failures can occur deep down inside complicated types --- (consider nested records). We leave breadcrumbs behind us so we --- can report the path we took to find the mismatch. -newtype BreadCrumbs = BreadCrumbs [BreadCrumb] - --- | An empty path. -noBreadCrumbs :: BreadCrumbs -noBreadCrumbs = BreadCrumbs [] - --- | Is the path empty? -hasNoBreadCrumbs :: BreadCrumbs -> Bool -hasNoBreadCrumbs (BreadCrumbs xs) = null xs - --- | Drop a breadcrumb on the path behind you. -breadCrumb :: BreadCrumb -> BreadCrumbs -> BreadCrumbs -breadCrumb (MatchingFields xs) (BreadCrumbs (MatchingFields ys : bcs)) = - BreadCrumbs $ MatchingFields (ys ++ xs) : bcs -breadCrumb bc (BreadCrumbs bcs) = - BreadCrumbs $ bc : bcs - -instance Pretty BreadCrumbs where - pretty (BreadCrumbs []) = mempty - pretty (BreadCrumbs bcs) = line <> stack (map pretty bcs) - -- | A usage that caused a type constraint. data Usage = Usage (Maybe T.Text) Loc deriving (Show) @@ -387,7 +337,7 @@ unifyWith onDims usage = subunify False ) | tn == arg_tn, length targs == length arg_targs -> do - let bcs' = breadCrumb (Matching "When matching type arguments.") bcs + let bcs' = matching "When matching type arguments." <> bcs zipWithM_ (unifyTypeArg bcs') targs arg_targs ( Scalar (TypeVar _ (QualName [] v1) []), Scalar (TypeVar _ (QualName [] v2) []) @@ -439,13 +389,13 @@ unifyWith onDims usage = subunify False subunify (not ord) bound - (breadCrumb (Matching "When matching parameter types.") bcs) + (matching "When matching parameter types." <> bcs) a1 a2 subunify ord bound' - (breadCrumb (Matching "When matching return types.") bcs) + (matching "When matching return types." <> bcs) (toStruct b1') (toStruct b2') @@ -511,7 +461,7 @@ unifySizes usage bcs _ _ e1 e2 = do -- | Unifies two types. unify :: (MonadUnify m) => Usage -> StructType -> StructType -> m () -unify usage = unifyWith (unifySizes usage) usage mempty noBreadCrumbs +unify usage = unifyWith (unifySizes usage) usage mempty mempty occursCheck :: (MonadUnify m) => @@ -597,14 +547,13 @@ linkVarToType usage bound bcs vn lvl tp_unnorm = do <> " used as size(s) would go out of scope." let unliftedBcs unlifted_usage = - breadCrumb - ( Matching $ - "When verifying that" - <+> dquotes (prettyName vn) - <+> textwrap "is not instantiated with a function type, due to" - <+> pretty unlifted_usage + matching + ( "When verifying that" + <+> dquotes (prettyName vn) + <+> textwrap "is not instantiated with a function type, due to" + <+> pretty unlifted_usage ) - bcs + <> bcs case snd <$> M.lookup vn constraints of Just (NoConstraint Unlifted unlift_usage) -> do @@ -699,9 +648,7 @@ arrayElemType :: TypeBase dim u -> m () arrayElemType usage desc = - arrayElemTypeWith usage $ breadCrumb bc noBreadCrumbs - where - bc = Matching $ "When checking" <+> textwrap desc + arrayElemTypeWith usage $ matching $ "When checking" <+> textwrap desc unifySharedFields :: (MonadUnify m) => @@ -714,7 +661,7 @@ unifySharedFields :: m () unifySharedFields onDims usage bound bcs fs1 fs2 = forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (t1, t2)) -> - unifyWith onDims usage bound (breadCrumb (MatchingFields [f]) bcs) t1 t2 + unifyWith onDims usage bound (matchingField f <> bcs) t1 t2 unifySharedConstructors :: (MonadUnify m) => @@ -731,7 +678,7 @@ unifySharedConstructors onDims usage bound bcs cs1 cs2 = where unifyConstructor c f1 f2 | length f1 == length f2 = do - let bcs' = breadCrumb (MatchingConstructor c) bcs + let bcs' = matchingConstructor c <> bcs zipWithM_ (unifyWith onDims usage bound bcs') f1 f2 | otherwise = unifyError usage mempty bcs $ @@ -796,7 +743,7 @@ unifyMostCommon usage t1 t2 = do linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 onDims _ _ _ _ _ = pure () - unifyWith onDims usage mempty noBreadCrumbs t1 t2 + unifyWith onDims usage mempty mempty t1 t2 t1' <- normTypeFully t1 t2' <- normTypeFully t2 newDimOnMismatch (locOf usage) t1' t2' From e098b6d7077bc7fa63fed0df4423f530e613ca05 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 11:32:13 +0200 Subject: [PATCH 252/258] A bit more work. --- src/Language/Futhark/TypeChecker/Constraints.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 9c9e7869c9..de304a934c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -261,8 +261,8 @@ unifySharedFields :: M.Map Name Type -> SolveM () unifySharedFields reason bcs fs1 fs2 = - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_f, (ts1, ts2)) -> - solveEq reason bcs ts1 ts2 + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> + solveEq reason (matchingField f <> bcs) ts1 ts2 mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () @@ -355,6 +355,9 @@ subTyVar reason bcs v t = do error $ "subTyVar: Nothing v: " <> prettyNameString v -- Precondition: 'v' and 't' are both currently flexible. +-- +-- The purpose of this function is to combine the partial knowledge we +-- may have about these two type variables. unionTyVars :: Reason -> BreadCrumbs -> VName -> VName -> SolveM () unionTyVars reason bcs v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars From 7210e8f93cb89906207159c4558c2b4cb5a73537 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 5 Oct 2024 09:50:22 +0200 Subject: [PATCH 253/258] Better error for tuple mismatches. --- src/Language/Futhark/TypeChecker/Constraints.hs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index de304a934c..51fe1140b0 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -509,6 +509,15 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) map (first matchingField) $ M.toList $ M.intersectionWith (,) fs1 fs2 + | Just n1 <- length <$> areTupleFields fs1, + Just n2 <- length <$> areTupleFields fs2, + n1 /= n2 = + Left $ + "Tuples have" + <+> pretty n1 + <+> "and" + <+> pretty n2 + <+> "elements respectively." | otherwise = let missing = filter (`notElem` M.keys fs1) (M.keys fs2) From 67cdaf99ccea330f88b9f0ebf1133738842ee1f0 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 5 Oct 2024 10:55:41 +0200 Subject: [PATCH 254/258] Better reasons. --- .../Futhark/TypeChecker/Constraints.hs | 64 +++++++++++++------ src/Language/Futhark/TypeChecker/Terms2.hs | 12 +++- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 51fe1140b0..2727340a1f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -28,19 +28,9 @@ import Data.Set qualified as S import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Error -import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..)) +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote) import Language.Futhark.TypeChecker.Types (substTyVars) --- | The reason for a type constraint. Used to generate type error --- messages. -newtype Reason = Reason - { reasonLoc :: Loc - } - deriving (Eq, Ord, Show) - -instance Located Reason where - locOf = reasonLoc - type SVar = VName -- | A shape component. `SDim` is a single dimension of unspecified @@ -68,6 +58,22 @@ type Type = TypeBase SComp NoUniqueness toType :: TypeBase Size u -> TypeBase SComp u toType = first (const SDim) +-- | The reason for a type constraint. Used to generate type error +-- messages. +data Reason + = -- | No particular reason. + Reason Loc + | -- | Arising from pattern match. + ReasonPatMatch Loc (PatBase NoInfo VName ParamType) Type + | -- | Arising from explicit ascription. + ReasonAscription Loc Type Type + deriving (Show) + +instance Located Reason where + locOf (Reason l) = l + locOf (ReasonPatMatch l _ _) = l + locOf (ReasonAscription l _ _) = l + data Ct = CtEq Reason Type Type | CtAM Reason SVar SVar (Shape SComp) @@ -542,14 +548,32 @@ solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () solveEq reason obcs orig_t1 orig_t2 = do solveCt' (obcs, (orig_t1, orig_t2)) where - cannotUnify details = do + cannotUnify notes bcs t1 t2 = do tyvars <- gets solverTyVars - typeError (locOf reason) mempty $ - "Cannot unify" - indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) - "with" - indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) - details + case reason of + ReasonPatMatch loc pat value_t -> + typeError loc notes . stack $ + [ "Pattern", + indent 2 $ align $ pretty pat, + "cannot match value of type", + indent 2 $ align $ pretty value_t + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonAscription loc expected actual -> + typeError loc notes . stack $ + [ "Expression does not have expected type from type ascription.", + "Expected:" <+> align (pretty expected), + "Actual: " <+> align (pretty actual) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty (substTyVars (substTyVar tyvars) t1)), + "with", + indent 2 (pretty (substTyVars (substTyVar tyvars) t2)) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] solveCt' (bcs, (t1, t2)) = do tyvars <- gets solverTyVars @@ -572,7 +596,7 @@ solveEq reason obcs orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify $ pretty bcs + (False, False) -> cannotUnify mempty bcs t1 t2 (True, False) -> subTyVar reason bcs v1 t2' (False, True) -> subTyVar reason bcs v2 t1' (True, True) -> unionTyVars reason bcs v1 v2 @@ -581,7 +605,7 @@ solveEq reason obcs orig_t1 orig_t2 = do (t1', Scalar (TypeVar _ (QualName [] v2) [])) | flexible v2 -> subTyVar reason bcs v2 t1' (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify $ pretty bcs details + Left details -> cannotUnify (aNote details) bcs t1' t2' Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 8fabb67962..79013c727e 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -510,7 +510,10 @@ checkPat' (TuplePat ps loc) (Ascribed t) <*> pure loc | otherwise = do ps_tvs <- replicateM (length ps) (newTyVar loc Lifted "t") - ctEq (Reason (locOf loc)) (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) t + ctEq + (ReasonPatMatch (locOf loc) (TuplePat ps loc) (toStruct t)) + (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) + t TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc @@ -551,7 +554,10 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do case maybe_outer_t of Ascribed outer_t -> do - ctEq (Reason (locOf loc)) st' outer_t + ctEq + (ReasonAscription (locOf loc) (toStruct st') (toStruct outer_t)) + st' + outer_t PatAscription <$> checkPat' p (Ascribed st') <*> pure t' @@ -1238,7 +1244,7 @@ checkExp (Ascript e te loc) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te e_t <- expType e' st' <- asType st - ctEq (Reason (locOf e')) e_t st' + ctEq (ReasonAscription (locOf e') (toStruct st') (toStruct e_t)) e_t st' pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e From 59ddb21561918612f8a5fa21cc900a9eb9e49d0d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 6 Oct 2024 10:24:46 +0200 Subject: [PATCH 255/258] Constructor match. --- src/Language/Futhark/TypeChecker/Constraints.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 2727340a1f..4835c3fd2b 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -251,7 +251,7 @@ unifySharedConstructors :: unifySharedConstructors reason bcs cs1 cs2 = forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> if length ts1 == length ts2 - then zipWithM_ (solveEq reason bcs) ts1 ts2 + then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 else typeError (locOf reason) mempty $ "Cannot unify type with constructor" @@ -532,12 +532,12 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) "Unshared fields:" <+> commasep (map pretty missing) <> "." unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = - fmap concat . forM cs' $ \(ts1, ts2) -> do + fmap concat . forM cs' $ \(c, (ts1, ts2)) -> do if length ts1 == length ts2 - then Right $ zipWith (curry (mempty,)) ts1 ts2 + then Right $ zipWith (curry (matchingConstructor c,)) ts1 ts2 else Left mempty where - cs' = M.elems $ M.intersectionWith (,) cs1 cs2 + cs' = M.toList $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = From 946ae65868937a25f988c053ecb28cafa14c3cc4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 6 Oct 2024 10:51:40 +0200 Subject: [PATCH 256/258] More reasons. --- .../Futhark/TypeChecker/Constraints.hs | 87 ++++++++++++------- src/Language/Futhark/TypeChecker/Terms2.hs | 8 +- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 4835c3fd2b..1075c16823 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -59,7 +59,7 @@ toType :: TypeBase Size u -> TypeBase SComp u toType = first (const SDim) -- | The reason for a type constraint. Used to generate type error --- messages. +-- messages. The expected type is always the first one. data Reason = -- | No particular reason. Reason Loc @@ -67,12 +67,14 @@ data Reason ReasonPatMatch Loc (PatBase NoInfo VName ParamType) Type | -- | Arising from explicit ascription. ReasonAscription Loc Type Type - deriving (Show) + | ReasonRetType Loc Type Type + deriving (Eq, Show) instance Located Reason where locOf (Reason l) = l locOf (ReasonPatMatch l _ _) = l locOf (ReasonAscription l _ _) = l + locOf (ReasonRetType l _ _) = l data Ct = CtEq Reason Type Type @@ -227,6 +229,12 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) +-- Try to substitute as much information as we have. +enrichType :: Type -> SolveM Type +enrichType t = do + s <- get + pure $ substTyVars (substTyVar (solverTyVars s)) t + typeError :: Loc -> Notes -> Doc () -> SolveM () typeError loc notes msg = throwError $ TypeError loc notes msg @@ -285,6 +293,50 @@ scopeViolation reason v1 ty v2 = <+> dquotes (prettyName v2) <+> "is rigidly bound in a deeper scope." +cannotUnify :: + Reason -> + Notes -> + BreadCrumbs -> + Type -> + Type -> + SolveM () +cannotUnify reason notes bcs t1 t2 = do + t1' <- enrichType t1 + t2' <- enrichType t2 + case reason of + ReasonPatMatch loc pat value_t -> + typeError loc notes . stack $ + [ "Pattern", + indent 2 $ align $ pretty pat, + "cannot match value of type", + indent 2 $ align $ pretty value_t + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonAscription loc expected actual -> + typeError loc notes . stack $ + [ "Expression does not have expected type from type ascription.", + "Expected:" <+> align (pretty expected), + "Actual: " <+> align (pretty actual) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonRetType loc expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ "Function body does not have expected type.", + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty t1'), + "with", + indent 2 (pretty t2') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () subTyVar reason bcs v t = do @@ -548,33 +600,6 @@ solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () solveEq reason obcs orig_t1 orig_t2 = do solveCt' (obcs, (orig_t1, orig_t2)) where - cannotUnify notes bcs t1 t2 = do - tyvars <- gets solverTyVars - case reason of - ReasonPatMatch loc pat value_t -> - typeError loc notes . stack $ - [ "Pattern", - indent 2 $ align $ pretty pat, - "cannot match value of type", - indent 2 $ align $ pretty value_t - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - ReasonAscription loc expected actual -> - typeError loc notes . stack $ - [ "Expression does not have expected type from type ascription.", - "Expected:" <+> align (pretty expected), - "Actual: " <+> align (pretty actual) - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - Reason loc -> - typeError loc notes . stack $ - [ "Cannot unify", - indent 2 (pretty (substTyVars (substTyVar tyvars) t1)), - "with", - indent 2 (pretty (substTyVars (substTyVar tyvars) t2)) - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - solveCt' (bcs, (t1, t2)) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of @@ -596,7 +621,7 @@ solveEq reason obcs orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify mempty bcs t1 t2 + (False, False) -> cannotUnify reason mempty bcs t1 t2 (True, False) -> subTyVar reason bcs v1 t2' (False, True) -> subTyVar reason bcs v2 t1' (True, True) -> unionTyVars reason bcs v1 v2 @@ -605,7 +630,7 @@ solveEq reason obcs orig_t1 orig_t2 = do (t1', Scalar (TypeVar _ (QualName [] v2) [])) | flexible v2 -> subTyVar reason bcs v2 t1' (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify (aNote details) bcs t1' t2' + Left details -> cannotUnify reason (aNote details) bcs t1' t2' Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 79013c727e..40e26b7107 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -840,9 +840,9 @@ checkRetDecl body Nothing = (,Nothing) <$> expType body checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te body_t <- expType body - st' <- asType st - ctEq (Reason (locOf body)) body_t st' - pure (second (const NoUniqueness) st', Just te') + st' <- toStruct <$> asType st + ctEq (ReasonRetType (locOf body) st' body_t) st' body_t + pure (st', Just te') checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- @@ -1372,7 +1372,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bitraverse pure (fmap (second (onArtificial artificial)) . onTySolution params' body'') - $ solve cts' typarams tyvars' + $ solve (reverse cts') typarams tyvars' debugTraceM 3 $ unlines [ "## constraints:", From fba57820538e155f5a06a6f7b8b75dfb076c661d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 13 Oct 2024 18:55:48 +0200 Subject: [PATCH 257/258] Function application reason. --- .../Futhark/TypeChecker/Constraints.hs | 35 +++++++++-- src/Language/Futhark/TypeChecker/Terms2.hs | 59 ++++++++++++------- 2 files changed, 66 insertions(+), 28 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1075c16823..6ed3b64a84 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -68,6 +68,7 @@ data Reason | -- | Arising from explicit ascription. ReasonAscription Loc Type Type | ReasonRetType Loc Type Type + | ReasonApply Loc (Maybe (QualName VName)) Exp Type Type deriving (Eq, Show) instance Located Reason where @@ -75,6 +76,7 @@ instance Located Reason where locOf (ReasonPatMatch l _ _) = l locOf (ReasonAscription l _ _) = l locOf (ReasonRetType l _ _) = l + locOf (ReasonApply l _ _ _ _) = l data Ct = CtEq Reason Type Type @@ -304,6 +306,14 @@ cannotUnify reason notes bcs t1 t2 = do t1' <- enrichType t1 t2' <- enrichType t2 case reason of + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty t1'), + "with", + indent 2 (pretty t2') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] ReasonPatMatch loc pat value_t -> typeError loc notes . stack $ [ "Pattern", @@ -328,14 +338,27 @@ cannotUnify reason notes bcs t1 t2 = do "Actual: " <+> align (pretty actual') ] <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - Reason loc -> + ReasonApply loc f e expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual typeError loc notes . stack $ - [ "Cannot unify", - indent 2 (pretty t1'), - "with", - indent 2 (pretty t2') + [ header, + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + where + header = + case f of + Nothing -> + "Cannot apply function to" + <+> dquotes (shorten $ group $ pretty e) + <> " (invalid type)." + Just fname -> + "Cannot apply" + <+> dquotes (pretty fname) + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> " (invalid type)." -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 40e26b7107..89b5320a67 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -678,8 +678,13 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) -checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do +checkApplyOne :: + SrcLoc -> + (Maybe (QualName VName), Int) -> + (Shape Size, Type) -> + (Maybe Exp, Shape Size, Type) -> + TermM (Type, AutoMap) +checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" @@ -689,7 +694,11 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m)) a ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) - ctEq (Reason (locOf loc)) lhs rhs + let reason = case arg of + Just arg' -> + ReasonApply (locOf loc) (fst fname) arg' lhs rhs + Nothing -> Reason (locOf loc) + ctEq reason lhs rhs debugTraceM 3 $ unlines [ "## checkApplyOne", @@ -737,14 +746,14 @@ checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> - NE.NonEmpty (Shape Size, Type) -> + NE.NonEmpty (Maybe Exp, Shape Size, Type) -> TermM (Type, NE.NonEmpty AutoMap) checkApply loc fname (fframe, ftype) args = do ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args pure (rt, argts) where - onArg (i, f_f, f_t) (argframe, argtype) = do - (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) + onArg (i, f_f, f_t) arg = do + (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) arg pure ( (i + 1, autoFrame am, rt), am @@ -943,21 +952,17 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - (args', argts') <- - NE.unzip - <$> forM - args - ( \(_, arg) -> do - arg' <- checkExp arg - arg_t <- expType arg' - pure (arg', (frameOf arg', arg_t)) - ) + (args', apply_args) <- + fmap NE.unzip . forM args $ \(_, arg) -> do + arg' <- checkExp arg + arg_t <- expType arg' + pure (arg', (Just arg', frameOf arg', arg_t)) fe_t <- expType fe' - (rt, ams) <- checkApply loc fname (frameOf fe', fe_t) argts' + (rt, ams) <- checkApply loc fname (frameOf fe', fe_t) apply_args rt' <- asStructType rt - pure $ - AppExp (Apply fe' (NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args') loc) $ - Info (AppRes rt' []) + let args'' = + NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args' + pure $ AppExp (Apply fe' args'' loc) $ Info (AppRes rt' []) where fname = case fe of @@ -975,7 +980,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do loc (Just op) (mempty, ftype) - ((frameOf e1', e1_t) NE.:| [(frameOf e2', e2_t)]) + ((Just e1', frameOf e1', e1_t) NE.:| [(Just e2', frameOf e2', e2_t)]) rt' <- asStructType rt let (am1 NE.:| [am2]) = ams @@ -992,7 +997,12 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do t2 <- newType loc Lifted "t" NoUniqueness t2' <- asStructType t2 let f1 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((f1, e_t) NE.:| [(mempty, t2)]) + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, optype) + ((Just e', f1, e_t) NE.:| [(Nothing, mempty, t2)]) rt' <- asStructType rt let (am1 NE.:| _) = ams @@ -1015,7 +1025,12 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do t1 <- newType loc Lifted "t" NoUniqueness t1' <- asStructType t1 let f2 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((mempty, t1) NE.:| [(f2, e_t)]) + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, optype) + ((Nothing, mempty, t1) NE.:| [(Just e', f2, e_t)]) rt' <- asStructType rt let (_ NE.:| [am2]) = ams t2 <- asStructType e_t From 29293af45b6d9e33eddf5ef5bf53ca5c1d39ab3d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 13 Oct 2024 19:17:39 +0200 Subject: [PATCH 258/258] Branch reasons. --- src/Language/Futhark/TypeChecker/Constraints.hs | 10 ++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 6 +++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 6ed3b64a84..c8ceb92763 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -69,6 +69,7 @@ data Reason ReasonAscription Loc Type Type | ReasonRetType Loc Type Type | ReasonApply Loc (Maybe (QualName VName)) Exp Type Type + | ReasonBranches Loc Type Type deriving (Eq, Show) instance Located Reason where @@ -77,6 +78,7 @@ instance Located Reason where locOf (ReasonAscription l _ _) = l locOf (ReasonRetType l _ _) = l locOf (ReasonApply l _ _ _ _) = l + locOf (ReasonBranches l _ _) = l data Ct = CtEq Reason Type Type @@ -359,6 +361,14 @@ cannotUnify reason notes bcs t1 t2 = do <+> "to" <+> dquotes (align $ shorten $ group $ pretty e) <> " (invalid type)." + ReasonBranches loc former latter -> do + former' <- enrichType former + latter' <- enrichType latter + typeError loc notes . stack $ + [ "Branches differ in type.", + "Former:" <+> pretty former', + "Latter:" <+> pretty latter' + ] -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 89b5320a67..41594c4282 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -810,7 +810,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq (Reason (locOf c)) c_t cs_t + ctEq (ReasonBranches (locOf c) c_t cs_t) c_t cs_t pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -1203,8 +1203,8 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do if_t <- newType loc SizeLifted "if_t" NoUniqueness ctEq (Reason (locOf e1')) e1_t (Scalar (Prim Bool)) - ctEq (Reason (locOf loc)) e2_t if_t - ctEq (Reason (locOf loc)) e3_t if_t + ctEq (ReasonBranches (locOf loc) e2_t e3_t) e2_t if_t + ctEq (ReasonBranches (locOf loc) e2_t e3_t) e3_t if_t if_t' <- asStructType if_t pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes if_t' [])