diff --git a/package.yaml b/package.yaml index 3545f1d..4ae853b 100644 --- a/package.yaml +++ b/package.yaml @@ -33,7 +33,14 @@ library: - monad-loops - sparse-linear-algebra ghc-options: - - -Wall + - -Wunused-imports + - -Wname-shadowing + - -Wincomplete-patterns + - -Wtype-defaults + - -Wunused-local-binds + - -Wunused-matches + - -Wunused-top-binds + - -Werror executables: zar-exe: diff --git a/src/Cotree.hs b/src/Cotree.hs index 196f9b9..9c198ba 100644 --- a/src/Cotree.hs +++ b/src/Cotree.hs @@ -3,15 +3,9 @@ module Cotree where -import Data.List (sortBy) -import Data.Maybe (fromMaybe) - import Datatypes -import ListTree -import Nat import Sexp import Tree -import Util -- | Greatest fixed point / final TreeF-coalgebra type Cotree a = Fix (TreeF a) diff --git a/src/Dep.hs b/src/Dep.hs index a74fa29..e423a6f 100644 --- a/src/Dep.hs +++ b/src/Dep.hs @@ -8,7 +8,6 @@ import Data.List (intersect, nub, union) import Data.Maybe (fromJust, fromMaybe) import Lang import Symtab (Id(..)) -import Util (debug) -- Compute dependencies of variables in a command (possibly a sequence -- of commands). @@ -35,7 +34,7 @@ init_deps (Assign (x, _) e) = [(Id x, id_of_name <$> fvs e)] init_deps (Sample (x, _) e) = [(Id x, id_of_name <$> fvs e)] init_deps (Seq c1 c2) = union_deps (init_deps c1) (init_deps c2) init_deps (Ite _ c1 c2) = union_deps (init_deps c1) (init_deps c2) -init_deps (While e c) = init_deps c +init_deps (While _ c) = init_deps c init_deps _ = [] -- Compute transitive closure (iterate until fixed point). @@ -45,12 +44,12 @@ iter_deps deps = where deps' = f deps (fst <$> deps) f :: [(Id, [Id])] -> [Id] -> [(Id, [Id])] - f deps (x:xs) = - let ys = fromJust $ lookup x deps - ys_deps = nub $ concat $ fromMaybe [] . flip lookup deps <$> ys + f deps0 (x:xs) = + let ys = fromJust $ lookup x deps0 + ys_deps = nub $ concat $ fromMaybe [] . flip lookup deps0 <$> ys in - f (upd_deps x (union ys_deps) deps) xs - f deps [] = deps + f (upd_deps x (union ys_deps) deps0) xs + f deps0 [] = deps0 -- Collect variables that are directly assigned random values. diff --git a/src/Distributions.hs b/src/Distributions.hs index 1cbf781..04859a6 100644 --- a/src/Distributions.hs +++ b/src/Distributions.hs @@ -1,7 +1,6 @@ module Distributions where import Data.Ratio -import Data.Typeable -- import Lang import Tree diff --git a/src/IOInterp.hs b/src/IOInterp.hs index 8bca905..1d64e7f 100644 --- a/src/IOInterp.hs +++ b/src/IOInterp.hs @@ -12,7 +12,6 @@ import Data.Maybe (fromMaybe) import Data.Typeable import System.Random -import Classes import Lang hiding (Com, Env, Exp, St, Val, interp) import qualified Lang (Com, Env, Exp, St, Val) @@ -42,8 +41,11 @@ eval env (EVar x) st = eval env (EUnop u e) st = case (u, eval env e st) of (UNot, VBool b) -> VBool $ not b + (UNot, _) -> error "IOInterp:eval: ill-typed UNot" (UFst, VPair x _) -> x + (UFst, _) -> error "IOInterp:eval: ill-typed UFst" (USnd, VPair _ y) -> y + (USnd, _) -> error "IOInterp:eval: ill-typed USnd" eval env (EBinop b e1 e2) st = case (b, eval env e1 st, eval env e2 st) of @@ -62,6 +64,7 @@ eval env (EBinop b e1 e2) st = (BLt, VRational r1, VRational r2) -> VBool $ r1 < r2 (BLt, VInteger i1, VInteger i2) -> VBool $ i1 < i2 (BLt, VFloat f1, VFloat f2) -> VBool $ f1 < f2 + (_, _, _ ) -> error "IOInterp:eval: ill-typed EBinop" eval env (EPair e1 e2) st = VPair (eval env e1 st) (eval env e2 st) diff --git a/src/IORepr.hs b/src/IORepr.hs index d5c58b1..56e1ae1 100644 --- a/src/IORepr.hs +++ b/src/IORepr.hs @@ -6,17 +6,14 @@ module IORepr (primitives) where import Control.Monad.Identity -import Data.Proxy import System.Random import Classes import IOInterp import Lang hiding (Env, Exp, SomeVal, SomeTypeVal, Val) -import qualified Lang as L (Env, Exp, SomeVal, SomeTypeVal(..), Val) +import qualified Lang as L (Exp, SomeTypeVal(..), Val) -type Env = L.Env Identity IO type Exp = L.Exp Identity IO -type SomeVal = L.SomeVal Identity IO type SomeTypeVal = L.SomeTypeVal Identity IO type Val = L.Val Identity IO @@ -45,7 +42,7 @@ bernoulli_prim = VPrim f -- IO actions are never equal. instance Eq a => Eq (IO a) where - f == g = False + _ == _ = False -- Trivial show instance. instance Show a => Show (IO a) where show _ = "IO" diff --git a/src/Inference.hs b/src/Inference.hs index 4aed9aa..072210e 100644 --- a/src/Inference.hs +++ b/src/Inference.hs @@ -5,15 +5,12 @@ module Inference where import Control.Monad -import Data.Bifunctor (bimap, second) +import Data.Bifunctor (second) import Data.List (sum) import Data.Maybe (fromMaybe) -import Data.Typeable import Classes -import Lang -import Tree -import Util (counts, debug) +import Util (counts) -- | Any g with a Sample instance admits sampling based inference. diff --git a/src/Lang.hs b/src/Lang.hs index 5cb2ffd..fa22eeb 100644 --- a/src/Lang.hs +++ b/src/Lang.hs @@ -22,19 +22,11 @@ module Lang where -import Control.Monad.Identity -import Control.Monad.Reader -import Control.Monad.State hiding (get) -import qualified Control.Monad.State as S (get) - -import Data.Bifunctor (first) import Data.Proxy import Data.Typeable import Classes -import Distributions import Symtab (Id(..)) -import Util (debug, mapJoin) -- Dummy instances for arrow type indices. instance Eq (a -> b) where @@ -273,7 +265,7 @@ instance Eq (Exp m g a) where ENil == ENil = True EUniform e1 == EUniform e2 = e1 == e2 -- TODO finish - _ == _ = False + _ == _ = error "internal error in Lang:Eq (Exp ...)" -- instance Show a => Show (Exp m g a) where -- show (EVal v) = "(EVal " ++ show v ++ ")" @@ -299,6 +291,7 @@ instance Show a => Show (Exp m g a) where show (EUnop u e) = "EUnop " ++ show u ++ " " ++ show e show (EBinop b e1 e2) = "EBinop " ++ show b ++ " " ++ show e1 ++ " " ++ show e2 + show (EPair e1 e2) = "EPair " ++ show e1 ++ " " ++ show e2 show ENil = "ENil" show (ECons e1 e2) = "ECons " ++ show e1 ++ " " ++ show e2 show (EDestruct l z f) = -- "(EDestruct " ++ show l ++ " " ++ show z ++ ")" @@ -308,7 +301,7 @@ instance Show a => Show (Exp m g a) where show (ECom _ c) = "ECom " ++ show c show (ECond b e1 e2) = "ECond " ++ show b ++ " " ++ show e1 ++ " " ++ show e2 - show (EPrim f) = "EPrim" + show (EPrim _) = "EPrim" show (EUniform l) = "EUniform " ++ show l @@ -378,7 +371,8 @@ fvs = go [] go bound (EDestruct l z f) = go bound l ++ go bound z ++ go bound f go bound (EApp e1 e2) = go bound e1 ++ go bound e2 go bound (ELam x body) = go (SomeName x : bound) body - go bound (ECom args com) = + --note(jgs): fvs(ECom _ com) ignores com? + go bound (ECom args _) = concatMap (\(SomeNameExp _ e) -> go bound e) args go bound (ECond b e1 e2) = go bound b ++ go bound e1 ++ go bound e2 go _ _ = [] @@ -462,14 +456,17 @@ vlist_nth n (VCons hd tl) | n < 0 = error "vlist_nth: negative index" | n == 0 = hd | otherwise = vlist_nth (n-1) tl +vlist_nth _ _ = error "internal error in Lang:vlist_nth; please report" vlist_length :: Val m g [a] -> Int vlist_length VNil = 0 vlist_length (VCons _ tl) = 1 + vlist_length tl +vlist_length _ = error "internal error in Lang:vlist_length; please report" vlist_list :: Val m g [a] -> [Val m g a] vlist_list VNil = [] vlist_list (VCons x xs) = x : vlist_list xs +vlist_list _ = error "internal error in Lang:vlist_list; please report" ------------------------------------------------------------------------ @@ -482,5 +479,5 @@ class (Typeable m, AllF g) => Repr m g | g -> m where -- Initial environment containing primitives. initEnv :: Repr m g => Env m g -initEnv = (\(x, SomeTypeVal t v) -> +initEnv = (\(x, SomeTypeVal _ v) -> SomeNameExp (x, Proxy) (EVal v)) <$> primitives diff --git a/src/LinEq.hs b/src/LinEq.hs index 3185891..050d73c 100644 --- a/src/LinEq.hs +++ b/src/LinEq.hs @@ -5,7 +5,6 @@ import Data.Maybe (catMaybes, maybeToList) import Sexp import Tree -import Util (debug) -- Boolean-valued trees with mandatory labels at all split nodes. data LTree = @@ -80,7 +79,7 @@ equations_of_ltree _ = [] lookup_term :: Maybe Var -> [Term] -> Maybe Coeff lookup_term (Just x) ((c, Just y) : terms) = if x == y then Just c else lookup_term (Just x) terms -lookup_term Nothing ((c, Nothing) : terms) = Just c +lookup_term Nothing ((c, Nothing) : _) = Just c lookup_term x (_ : terms) = lookup_term x terms lookup_term _ [] = Nothing @@ -94,7 +93,7 @@ remove_term = go [] go acc (Just x) (tm@(c, Just y) : terms) = if x == y then Just (c, acc ++ terms) else go (tm:acc) (Just x) terms - go acc Nothing (tm@(c, Nothing) : terms) = Just (c, acc ++ terms) + go acc Nothing ((c, Nothing) : terms) = Just (c, acc ++ terms) go acc x (tm : terms) = go (tm:acc) x terms go _ _ [] = Nothing @@ -132,6 +131,8 @@ solve_equations :: [Equation] -> Equation solve_equations = go . sort where go :: [Equation] -> Equation + + go [] = error "internal error in LinEq:solve_equations" --note(jgs): fix go [eq] = simplify_equation eq go (Equation (x, terms) : eqs) = go $ simplify_equation . subst_equation x terms <$> eqs diff --git a/src/ListTree.hs b/src/ListTree.hs index 849e597..3f6dc13 100644 --- a/src/ListTree.hs +++ b/src/ListTree.hs @@ -1,10 +1,12 @@ module ListTree where -import Data.Bifunctor (second) -import Data.List (nub, sort) -import Sexp +--note(jgs): redundant imports +--import Data.Bifunctor (second) +--import Data.List (nub, sort) +--import Sexp import Tree -import Util +--import Util + -- List form of trees. type TreeL a = [[Tree a]] diff --git a/src/Parser.hs b/src/Parser.hs index 67b732b..e7410d2 100644 --- a/src/Parser.hs +++ b/src/Parser.hs @@ -1,10 +1,7 @@ module Parser where -import Control.Monad (void) - import Control.Monad.Combinators.Expr -- from parser-combinators import qualified Data.List.NonEmpty as NonEmpty -import Data.Maybe (fromMaybe) import Data.Ratio import Data.Set (singleton) import Data.Void @@ -15,7 +12,6 @@ import qualified Text.Megaparsec.Char.Lexer as L import Untyped import Symtab (Id(..)) import Token -import Util (debug) parens :: Parser a -> Parser a @@ -41,7 +37,7 @@ bool = choice rational :: Parser Rational rational = do num <- integer - symbol "/" + _ <- symbol "/" denom <- integer return $ num % denom @@ -369,12 +365,12 @@ dist = L.indentBlock scn $ do keyword "dist" dist_nm <- ident args <- parens $ commaSep func_arg - dist_ty <- symbol "->" >> ty + dty <- symbol "->" >> ty symbol ":" return $ L.IndentSome Nothing (\coms -> return $ Dist { dist_name = dist_nm - , dist_type = dist_ty + , dist_type = dty , dist_args = args , dist_body = mkSeq coms }) com @@ -390,9 +386,9 @@ prog = L.nonIndented scn (L.indentBlock scn p) where p = do funcs_dists <- many $ choice [Left <$> func, Right <$> dist] - com <- main + c <- main eof - return $ L.IndentNone (funcs_dists, com) + return $ L.IndentNone (funcs_dists, c) -- Main parsing function called from the outside. diff --git a/src/Sample.hs b/src/Sample.hs index 1b0b71a..8c155dd 100644 --- a/src/Sample.hs +++ b/src/Sample.hs @@ -85,7 +85,7 @@ mcmc_sample cz@(ct, ws) n = cata alg ct s <- mcmc_sample (zipUp cz) (n - 1) return (x:s) Nothing -> error "mcmc_sample: out of bits" - alg (SplitF s1 s2) = do + alg (SplitF _ _) = do case ws of [] -> do bit <- gets headMaybe diff --git a/src/SparseLinAlg.hs b/src/SparseLinAlg.hs index 8184c8f..a42c3f6 100644 --- a/src/SparseLinAlg.hs +++ b/src/SparseLinAlg.hs @@ -1,12 +1,11 @@ module SparseLinAlg (solve_tree) where -import Data.Bifunctor (bimap, second) +import Data.Bifunctor (bimap) import Data.List (sort) import Data.Maybe (fromJust) import Data.Sparse.SpMatrix import Data.Sparse.SpVector import Numeric.LinearAlgebra.Sparse -import System.IO.Unsafe (unsafePerformIO) import LinEq (Coeff, Var, Equation(..), remove_term, combine_terms, ltree_of_tree, equations_of_ltree) import Sexp @@ -34,7 +33,7 @@ mateq_of_equation (Equation (x, tms)) = constraint_matrix :: [MatEq] -> SpMatrix Rational constraint_matrix eqs = let l = concat $ f <$> zip [0..] eqs in - debug ("l: " ++ show ((\(x, y, z) -> (x, y, fromRational z)) <$> l)) $ + debug ("l: " ++ show ((\(x, y, z) -> (x, y, fromRational z :: Double)) <$> l)) $ fromListSM (n, n) l where n = length eqs @@ -77,8 +76,8 @@ solve_system_gmres mat rhs = -- let _ = unsafePerformIO $ prd mat in -- mat <\> (fromListDenseSV n ) mat <\> rhs - where - n = nrows mat -- should also be the length of the rhs vector + -- where + -- n = nrows mat -- should also be the length of the rhs vector solve_tree :: Tree Bool -> IO (SpVector Double) -- solve_tree = solve_system_gmres . tree_constraint_matrix diff --git a/src/Symtab.hs b/src/Symtab.hs index d771cfc..3267246 100644 --- a/src/Symtab.hs +++ b/src/Symtab.hs @@ -1,3 +1,4 @@ + -- | This module defines a type for identifiers along with an abstract -- datatype for maps indexed by them. @@ -16,14 +17,6 @@ import qualified Data.Map as Map newtype Id = Id { unId :: String } deriving (Eq, Ord) -intToId :: Int -> Id -intToId 0 = Id "a" -intToId 1 = Id "b" -intToId _ = Id "c" - --- instance Arbitrary Id where --- arbitrary = intToId <$> choose (0, 2) - assocGet :: Id -> [(Id, a)] -> Maybe a assocGet _ [] = Nothing assocGet x ((y, v) : ys) = if x == y then Just v else assocGet x ys @@ -36,14 +29,14 @@ assocSet nm x ((nm', x'):ys) = -- Update the value associated with an Id. assocUpdate :: Id -> (a -> a) -> [(Id, a)] -> [(Id, a)] -assocUpdate nm f [] = error $ "assocUpdate: " ++ show nm ++ " not found" +assocUpdate nm _ [] = error $ "assocUpdate: " ++ show nm ++ " not found" assocUpdate nm f ((nm', x):ys) = if nm == nm' then (nm, f x) : ys else (nm', x) : assocUpdate nm f ys assocIndex :: Id -> [(Id, a)] -> Maybe Int assocIndex nm ((x, _):xs) = if nm == x then Just 0 else (+ 1) <$> assocIndex nm xs -assocIndex nm [] = Nothing +assocIndex _ [] = Nothing -- A Symtab maps Ids to values of some type type Symtab a = Map.Map Id a diff --git a/src/Token.hs b/src/Token.hs index 2816bad..1c95286 100644 --- a/src/Token.hs +++ b/src/Token.hs @@ -1,7 +1,6 @@ module Token where import Control.Monad (void) -import Control.Monad.Combinators.Expr -- from parser-combinators import Data.Void import Text.Megaparsec import Text.Megaparsec.Char diff --git a/src/Tree.hs b/src/Tree.hs index 012dfc1..86ada1a 100644 --- a/src/Tree.hs +++ b/src/Tree.hs @@ -1,3 +1,4 @@ + {-# LANGUAGE DeriveFoldable, DeriveFunctor, DeriveTraversable, TupleSections #-} {-# LANGUAGE StandaloneDeriving, DeriveAnyClass #-} @@ -5,10 +6,9 @@ module Tree where import Control.Monad import Data.Bifunctor -import Data.List (nub, sort, (\\)) +import Data.List (nub) import Data.Maybe (fromMaybe) -import Classes import Datatypes import Sexp import Util @@ -182,7 +182,7 @@ reduce_whole (Split (Just lbl) t1 t2) = if all (== lbl) hs then case ls of [] -> error "reduce_hole: no leaves or outgoing holes" - (x:xs) | allEq ls -> (Leaf x, (lbl, Leaf x) : ps) + (x:_) | allEq ls -> (Leaf x, (lbl, Leaf x) : ps) --note(jgs): We ignore 'xs' here? _ -> (Split (Just lbl) t1' t2', ps) else (Split (Just lbl) t1' t2', ps) @@ -293,8 +293,11 @@ compatible_for_swap :: Eq a => Tree a -> Path -> Path -> Bool compatible_for_swap t p1 p2 = let t1 = get_subtree p1 t t2 = get_subtree p2 t - t1_labels = path_labels p1 t - t2_labels = path_labels p2 t in + --t1_labels = path_labels p1 t + --t2_labels = path_labels p2 t in + --note(jgs): is this right? + t1_labels = path_labels p1 t1 + t2_labels = path_labels p2 t2 in setEq t1_labels t2_labels -- compatible_for_swap :: Eq a => Tree a -> Path -> Path -> Bool @@ -309,12 +312,12 @@ group_dupes :: Eq a => Tree a -> Tree a group_dupes t = foldl f t [0 .. depth t] where f :: Eq a => Tree a -> Int -> Tree a - f t n = - let subtrees = at_depth t n in - if length subtrees <= 2 then t else + f t0 n = + let subtrees = at_depth t0 n in + if length subtrees <= 2 then t0 else case dupes [] subtrees of - Just (p1, p2) -> swap_subtrees (sibling_path p1) p2 t - Nothing -> t + Just (p1, p2) -> swap_subtrees (sibling_path p1) p2 t0 + Nothing -> t0 dupes :: Eq a => [(Tree a, Path)] -> [(Tree a, Path)] -> Maybe (Path, Path) dupes _ [] = Nothing dupes seen ((t', p):rest) = case lookup t' seen of @@ -340,8 +343,9 @@ swap_subtrees p1 p2 t = canon :: (Eq a, Show a) => Tree a -> Tree a canon t = let - (t1, ps) = reduce_whole t - t2 = apply_patches ps t1 + -- note(jgs): + -- (t1, ps) = reduce_whole t + -- t2 = apply_patches ps t1 t3 = group_dupes t -- TODO: fix (t4, ps') = reduce t3 t5 = apply_patches ps' t4 @@ -363,8 +367,8 @@ expand t = go t t where go :: Tree a -> Tree a -> Tree a go _ (Leaf x) = Leaf x - go t (Split n t1 t2) = Split n (go t t1) (go t t2) - go t (Hole _) = t + go t0 (Split n t1 t2) = Split n (go t0 t1) (go t0 t2) + go t0 (Hole _) = t0 -- Check if a tree is in canonical form (not necessarily the case that -- canon would have no effect). diff --git a/src/TreeInterp.hs b/src/TreeInterp.hs index bc92f8c..b4b8504 100644 --- a/src/TreeInterp.hs +++ b/src/TreeInterp.hs @@ -19,7 +19,6 @@ import Dep import Distributions import Lang hiding (Com, Exp, St, Val, interp) import qualified Lang (Com, Exp, St, Val) -import Sexp import Tree import Util @@ -51,11 +50,12 @@ instance MonadReader InterpEnv InterpM where runInterpM :: InterpEnv -> InterpState -> InterpM a -> (a, InterpState) runInterpM env s (InterpM f) = runIdentity $ runStateT (runReaderT f env) s -evalInterpM :: InterpEnv -> InterpState -> InterpM a -> a -evalInterpM env s f = fst $ runInterpM env s f +--note(jgs): unused +-- evalInterpM :: InterpEnv -> InterpState -> InterpM a -> a +-- evalInterpM env s f = fst $ runInterpM env s f -execInterpM :: InterpEnv -> InterpState -> InterpM a -> InterpState -execInterpM env s f = snd $ runInterpM env s f +-- execInterpM :: InterpEnv -> InterpState -> InterpM a -> InterpState +-- execInterpM env s f = snd $ runInterpM env s f -- Set up type synonyms. @@ -92,7 +92,7 @@ eval (EVar x) st = case envGet x env of Just e -> eval e st Nothing -> - let (x', proxy) = x + let (_, proxy) = x ty = typeOf proxy in error $ "eval: unbound variable: " ++ show x ++ " of type " ++ show ty ++ ".\nst: " ++ show st @@ -178,7 +178,7 @@ eval (EApp f e) st = do v <- eval e st case f' of VLam x body -> eval (subst x (EVal v) body) st - VPrim f -> f v >>= flip eval st + VPrim f0 -> f0 v >>= flip eval st eval (ECom args com) st = do st' <- mapM (\(SomeNameExp x e) -> SomeNameVal x <$> eval e st) args diff --git a/src/TreeRepr.hs b/src/TreeRepr.hs index 5932f8b..d75bf97 100644 --- a/src/TreeRepr.hs +++ b/src/TreeRepr.hs @@ -4,20 +4,16 @@ module TreeRepr () where -import Data.Proxy - import Classes import Distributions import TreeInterp -import Lang hiding (Env, Exp, SomeVal, SomeTypeVal, Val) -import qualified Lang as L (Env, Exp, SomeVal, SomeTypeVal(..), Val) +import Lang hiding (Exp, SomeTypeVal, Val) +import qualified Lang as L (Exp, SomeTypeVal(..), Val) import LinEq (infer') import Sample import Tree -type Env = L.Env InterpM Tree type Exp = L.Exp InterpM Tree -type SomeVal = L.SomeVal InterpM Tree type SomeTypeVal = L.SomeTypeVal InterpM Tree type Val = L.Val InterpM Tree diff --git a/src/Tycheck.hs b/src/Tycheck.hs index 38429f5..b762ba1 100644 --- a/src/Tycheck.hs +++ b/src/Tycheck.hs @@ -1,3 +1,4 @@ + {-# LANGUAGE GADTs, RankNTypes, StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE KindSignatures #-} @@ -16,14 +17,12 @@ import Text.Megaparsec (SourcePos) import Text.Megaparsec.Error -import Classes import Lang as L hiding (SomeExp, SomeNameExp, SomeVal) -import qualified Lang as L (SomeNameExp(..), SomeVal(..)) +import qualified Lang as L (SomeNameExp(..)) import Parser (parse) import Symtab (Id(..), Symtab) -import qualified Symtab as S (add, empty, get, fromList) +import qualified Symtab as S (add, get, fromList) import qualified Untyped as U -import Util (debug) data SomeExp m g where SomeExp :: forall m g a. (Repr m g, Eq a, Show a, Typeable a) => @@ -96,6 +95,11 @@ val_of_lit (U.LRational r) = SomeTypeVal TRational $ L.VRational r val_of_lit (U.LFloat f) = SomeTypeVal TFloat $ L.VFloat f val_of_lit (U.LBool b) = SomeTypeVal TBool $ L.VBool b val_of_lit (U.LInteger i) = SomeTypeVal TInteger $ L.VInteger i +val_of_lit (U.LPair _ _) = error "internal error: LPair unimplemented in Tycheck:val_of_lit" +--note(jgs): fix + {-let SomeTypeVal t1 v1 = val_of_lit l1 + SomeTypeVal t2 v2 = val_of_lit l2 + in SomeTypeVal (TPair t1 t2) (VPair v1 v2)-} -- It seems like this shouldn't be necessary, but when we use the @@ -237,7 +241,7 @@ tycheckExp p (U.EBinop pos binop e1 e2) = do U.BPair -> return $ SomeExp (TPair t1 t2) $ L.EPair e1' e2' -tycheckExp p (U.ELam pos (Id x) t e) = +tycheckExp p (U.ELam _ (Id x) t e) = case tycheckType p t of SomeType t' -> do SomeExp s e' <- local (S.add (Id x) (SomeType t')) $ tycheckExp p e @@ -255,10 +259,10 @@ tycheckExp p (U.ECall pos e1 args) = where go :: Repr m g => Proxy g -> [U.Exp SourcePos] -> TycheckM m g (SomeExp m g) - go p [] = tycheckExp p e1 - go p (arg:args) = do - SomeExp t e <- go p args -- tycheckExp p (U.ECall pos e1 args) - SomeExp t' arg' <- tycheckExp p arg + go p0 [] = tycheckExp p0 e1 + go p0 (arg:args0) = do + SomeExp t e <- go p0 args0 -- tycheckExp p (U.ECall pos e1 args) + SomeExp t' arg' <- tycheckExp p0 arg case t of TArrow a b -> case typeEq t' a of @@ -390,16 +394,13 @@ tycheckCom (U.CIte pos e c1 c2) = do show t2 _ -> typeError pos $ "expected Bool, got " ++ show t --- tycheckCom (U.CFlip pos c1 c2) = --- SomeCom <$> pure Flip <*> tycheckCom c1 <*> tycheckCom c2 - tycheckCom (U.CObserve pos e) = do SomeExp t e' <- tycheckExp Proxy e case t of TBool -> return $ SomeCom TSt $ Observe e' _ -> typeError pos $ "expected Bool, got " ++ show t -tycheckCom (U.CReturn pos e) = do +tycheckCom (U.CReturn _ e) = do SomeExp t e' <- tycheckExp Proxy e return $ SomeCom (TExp t) $ Return e' @@ -460,8 +461,8 @@ tycheckFunction (U.Function { U.function_name = Id f_nm go ty ((x, x_ty):xs) body = do body' <- go ty xs body case (x_ty, body') of - (SomeType x_ty', SomeNameExp _ f_ty body'') -> - return $ SomeNameExp (f_nm, Proxy) (TArrow x_ty' f_ty) $ + (SomeType x_ty', SomeNameExp _ f_ty0 body'') -> + return $ SomeNameExp (f_nm, Proxy) (TArrow x_ty' f_ty0) $ ELam (nameOfType x x_ty') body'' @@ -507,14 +508,14 @@ tycheckProg = go com' <- tycheckCom com return ([], com') go (x:xs) com = do - SomeNameExp (x, _) t e <- either tycheckFunction tycheckDist x - (es, com') <- local (S.add (Id x) $ SomeType t) $ go xs com - return (SomeNameExp (x, Proxy) t e : es, com') + SomeNameExp (x0, _) t e <- either tycheckFunction tycheckDist x + (es, com') <- local (S.add (Id x0) $ SomeType t) $ go xs com + return (SomeNameExp (x0, Proxy) t e : es, com') -- Build initial context from the primitives list. initCtx :: Repr m g => [(String, SomeTypeVal m g)] -> Context m g initCtx prims = - S.fromList $ (\(x, SomeTypeVal t v) -> (Id x, SomeType t)) <$> prims + S.fromList $ (\(x, SomeTypeVal t _) -> (Id x, SomeType t)) <$> prims tycheck :: Repr m g => diff --git a/src/Untyped.hs b/src/Untyped.hs index ff8a0f8..4a3fb99 100644 --- a/src/Untyped.hs +++ b/src/Untyped.hs @@ -94,7 +94,6 @@ data Com a = | CIte a (Exp a) (Com a) (Maybe (Com a)) | CReturn a (Exp a) -- Derived commands: - | CFlip a (Com a) (Com a) | CObserve a (Exp a) | CWhile a (Exp a) (Com a) deriving Show @@ -104,8 +103,11 @@ data_of_exp (ELit x _) = x data_of_exp (EVar x _) = x data_of_exp (EUnop x _ _) = x data_of_exp (EBinop x _ _ _) = x +data_of_exp (ELam x _ _ _) = x data_of_exp (ECall x _ _) = x -data_of_exp (ENil x _ ) = x +data_of_exp (ENil x _) = x +data_of_exp (EDestruct x _ _ _) = x +data_of_exp (ECond x _ _ _) = x data_of_com :: Com a -> a data_of_com (CSkip x) = x @@ -115,7 +117,6 @@ data_of_com (CSample x _ _) = x data_of_com (CSeq x _ _) = x data_of_com (CIte x _ _ _) = x data_of_com (CReturn x _) = x -data_of_com (CFlip x _ _) = x data_of_com (CObserve x _) = x data_of_com (CWhile x _ _) = x @@ -168,9 +169,9 @@ instance ToSexp (Exp a) where toSexp (ECall _ e es) = "(ECall " ++ toSexp e ++ " " ++ toSexp es ++ ")" toSexp (ENil _ t) = "(ENil " ++ toSexp t ++ ")" toSexp (EDestruct _ e1 e2 e3) = - "(EDestruct " ++ toSexp e1 ++ " " ++ toSexp e2 ++ " " ++ toSexp e2 ++ ")" + "(EDestruct " ++ toSexp e1 ++ " " ++ toSexp e2 ++ " " ++ toSexp e3 ++ ")" toSexp (ECond _ e1 e2 e3) = - "(ECond " ++ toSexp e1 ++ " " ++ toSexp e2 ++ " " ++ toSexp e2 ++ ")" + "(ECond " ++ toSexp e1 ++ " " ++ toSexp e2 ++ " " ++ toSexp e3 ++ ")" instance ToSexp (Com a) where toSexp (CSkip _) = "CSkip" @@ -180,7 +181,6 @@ instance ToSexp (Com a) where toSexp (CSeq _ c1 c2) = "(CSeq " ++ toSexp c1 ++ " " ++ toSexp c2 ++ ")" toSexp (CIte _ e c1 c2) = "(CIte " ++ toSexp e ++ " " ++ toSexp c1 ++ " " ++ toSexp c2 ++ ")" - toSexp (CFlip _ c1 c2) = "(CSeq " ++ toSexp c1 ++ " " ++ toSexp c2 ++ ")" toSexp (CObserve _ e) = "(CObserve " ++ toSexp e ++ ")" toSexp (CReturn _ e) = "(CReturn " ++ toSexp e ++ ")" toSexp (CWhile _ e c) = "(CWhile " ++ toSexp e ++ " " ++ toSexp c ++ ")" diff --git a/src/Util.hs b/src/Util.hs index 4e2a36d..a859d11 100644 --- a/src/Util.hs +++ b/src/Util.hs @@ -37,11 +37,11 @@ classify :: Eq a => [a] -> [[a]] classify = classifyBy (==) classifyBy :: (a -> a -> Bool) -> [a] -> [[a]] -classifyBy (==) [] = [] -classifyBy (==) (x:xs) = (x:filter (== x) xs) - : classifyBy (==) (filter (/= x) xs) +classifyBy _ [] = [] +classifyBy eq (x:xs) = (x:filter (eq x) xs) + : classifyBy eq (filter (neq x) xs) where - x /= y = not (x == y) + neq x1 x2 = not (eq x1 x2) counts :: Eq a => [a] -> [(a,Int)] counts = map headLength . classify