-
Notifications
You must be signed in to change notification settings - Fork 3
/
Unify.hs
150 lines (128 loc) · 5.99 KB
/
Unify.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
--------------------------------------------------------------------------
-- Unification and matching
--------------------------------------------------------------------------
module Unify ( subsume
, unify
, matchfun
, matchfunN
) where
import Data.IORef
import PPrint
import Types
import Subst
import Operations
--------------------------------------------------------------------------
-- Function match
--------------------------------------------------------------------------
-- | Match a type against a function type.
--
-- The type is instantiated first, then:
--
-- * @arg -> res@ yields @arg@ (passed through 'subst') together with @res@;
-- * a bound unification variable is matched again through its binding;
-- * an unbound unification variable is bound to @a -> r@, where @a@ and @r@
--   are fresh type variables created at that variable's lambda-rank, and
--   @(a, r)@ is returned;
-- * anything else fails with an "applying a non-function" error.
matchfun :: Type -> Infer (Type,Type)
matchfun tp
  = instantiate tp >>= matchRho
  where
    matchRho rho
      = case rho of
          TApp (TApp (TCon "->") argTp) resTp
            -> do argTp' <- subst argTp
                  return (argTp', resTp)
          TVar (TypeVar _ (Uni ref rankRef))
            -> readIORef ref >>= maybe (splitFresh ref rankRef) matchfun
          _ -> failure ("applying a non-function: " ++ show rho)

    -- Bind an unbound unification variable to a fresh function type
    -- (argument variable created before the result variable).
    splitFresh ref rankRef
      = do rank  <- readIORef rankRef
           argTp <- freshTVar rank
           resTp <- freshTVar rank
           writeIORef ref (Just (mkFun argTp resTp))
           return (argTp, resTp)
-- | Match a type against an n-ary function type @a1 -> ... -> an -> res@.
--
-- The first argument comes from 'matchfun'. After that, arguments are peeled
-- off the result type while fewer than @n@ have been collected, following
-- bound unification variables along the way. Collection stops early at an
-- unbound unification variable or a quantified type, so fewer than @n@
-- arguments may be returned. Any other non-function result fails. At least
-- one argument is always returned. The result pairs the collected argument
-- types (in order) with the remaining result type.
matchfunN :: Type -> Int -> Infer ([Type],Type)
matchfunN tp n
  = do (arg, res) <- matchfun tp
       go 1 [arg] res
  where
    -- Invariant: count == length acc; acc holds the arguments in reverse.
    go count acc t
      | count >= n = finish
      | otherwise
      = case t of
          TApp (TApp (TCon "->") targ) tres
            -> go (count + 1) (targ : acc) tres
          TVar (TypeVar _ (Uni ref _))
            -> readIORef ref >>= maybe finish (go count acc)
          Forall _ _
            -> finish
          _ -> failure ("n-ary function match failed: " ++ show (pretty (foldr1 mkFun (reverse (t:acc)))))
      where
        finish = return (reverse acc, t)
--------------------------------------------------------------------------
-- Subsumption
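-- "subsume tp1 tp2" succeeds when tp2 can be instantiated to some type tp3
-- that unifies with the skolemized tp1, i.e. tp2 is at least as polymorphic
-- as tp1. The unifying substitution S (with S(tp1) = S(tp3)) is not
-- returned: it is applied in place by binding unification variables.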
--------------------------------------------------------------------------
-- | Check that @tp2@ is at least as polymorphic as @tp1@.
--
-- @tp1@ is skolemized (presumably its quantified variables become fresh
-- skolem constants, returned as the first component) and @tp2@ is
-- instantiated. The two resulting types are then unified, which binds
-- unification variables in place. Fails if unification fails, or if any of
-- the fresh skolems afterwards occurs among the free skolems of @tp1@ or
-- @tp2@.
subsume :: Type -> Type -> Infer ()
subsume tp1 tp2
  = do (skolems, rho1) <- skolemize tp1
       rho2 <- instantiate tp2
       unify rho1 rho2
       -- the fresh skolems must not escape into either type
       free1 <- freeSkolems tp1
       free2 <- freeSkolems tp2
       check (skolems `disjoint` (free1 `union` free2)) notPolymorphic
  where
    notPolymorphic
      = "type is not polymorphic enough in subsume:\n type1: " ++ show tp1 ++ "\n type2: " ++ show tp2
--------------------------------------------------------------------------
-- unification
--------------------------------------------------------------------------
-- | Unify two types, binding unification variables in place.
--
-- Alternatives are tried in order:
--
-- 1. equal constructors;
-- 2. identical type variables;
-- 3. applications, componentwise;
-- 4. a unification variable on either side;
-- 5. two quantified types with the same number of binders.
--
-- For case 5, the same fresh skolems are substituted for both binder lists
-- and the bodies are unified. The skolems must not escape into either type.
-- Any other pair is reported through 'failure'.
unify :: Type -> Type -> Infer ()
unify t1 t2
  = case (t1, t2) of
      (TCon c1, TCon c2) | c1 == c2
        -> return ()
      (TVar a, TVar b) | a == b
        -> return ()
      (TApp f1 x1, TApp f2 x2)
        -> unify f1 f2 >> unify x1 x2
      (TVar a, _) | isUni (tvFlavour a)
        -> unifyVar a t2
      (_, TVar b) | isUni (tvFlavour b)
        -> unifyVar b t1
      -- this case assumes that types are in normal form
      (Forall bs1 body1, Forall bs2 body2) | length bs1 == length bs2
        -> unifyForall bs1 body1 bs2 body2
      _ -> failure ("cannot unify types:\n type1: " ++ show (pretty t1) ++ "\n type2: " ++ show (pretty t2))
  where
    unifyForall bs1 body1 bs2 body2
      = do sks <- freshSkolems (length bs1)
           let skTys = map TVar sks
           rho1 <- subNew bs1 skTys |-> body1
           rho2 <- subNew bs2 skTys |-> body2
           unify rho1 rho2
           -- check for escaping skolems
           sk1 <- freeSkolems t1
           sk2 <- freeSkolems t2
           check (sks `disjoint` (sk1 `union` sk2))
                 ("type is not polymorphic enough in unify:\n type1: " ++ show (pretty t1) ++ "\n type2: " ++ show (pretty t2))
-- | Unify the unification variable @tv@ with the type @tp2@.
--
-- * If @tv@ is already bound, its binding is unified with @tp2@.
-- * If @tp2@ is @tv@ itself, there is nothing to do.
-- * If @tp2@ is another unification variable that is bound, @tv@ is unified
--   with that binding. The chain is not shortened, since @tv@ could occur
--   inside the binding.
-- * Otherwise @tv@ is bound to @tp2@ by writing its IORef:
--   - an occurs check is done first, unless @tp2@ is an unbound unification
--     variable;
--   - after binding, every unbound unification variable in @tp2@ has its
--     lambda-rank lowered to at most that of @tv@ (via 'adjustRank').
--
-- Only variables with a 'Uni' flavour can be unified. Any other variable is
-- reported through 'failure' instead of crashing with a pattern-match error.
unifyVar :: TypeVar -> Type -> Infer ()
unifyVar tv@(TypeVar _ (Uni ref1 rref1)) tp2
  = do mtp1 <- readIORef ref1
       case mtp1 of
         Just tp1 -> unify tp1 tp2
         Nothing  -> bindUnbound
  where
    -- tv is unbound from here on
    bindUnbound
      = case tp2 of
          TVar v2 | v2 == tv
            -> return ()   -- binding tv to itself would create a cycle
          TVar (TypeVar _ (Uni ref2 _))
            -> do mtp2 <- readIORef ref2
                  case mtp2 of
                    Just tp3 -> unify (TVar tv) tp3 -- note: we can't shorten here since tv could be an element of tp3
                    Nothing  -> bindTo               -- distinct unbound variable: no occurs check needed
          _ -> do tvs <- freeTvs tp2
                  check (not (tv `elem` tvs)) ("infinite type: " ++ show tv ++ " and " ++ show tp2) -- occurs check
                  bindTo

    -- Bind tv to tp2, then lower the lambda-ranks of the unification
    -- variables in tp2 so none exceeds tv's rank.
    bindTo
      = do writeIORef ref1 (Just tp2)
           rank1 <- readIORef rref1
           adjustRank rank1 tp2
unifyVar tv _
  = failure ("unifyVar: not a unifiable type variable: " ++ show tv)
-- | Adjust the lambda-rank of the unifiable variables in a type.
--
-- Every unbound unification variable reachable in the type has its rank
-- lowered to @rank@ if it is currently higher. Bindings are followed, and
-- the walk looks inside applications and under quantifiers. All other
-- types are left untouched.
adjustRank :: Rank -> Type -> IO ()
adjustRank rank (TVar (TypeVar _ (Uni ref rankRef)))
  = readIORef ref >>= maybe lower (adjustRank rank)
  where
    lower = do current <- readIORef rankRef
               onlyIf (current > rank) (writeIORef rankRef rank)
adjustRank rank (Forall _ rho) = adjustRank rank rho
adjustRank rank (TApp t1 t2)   = adjustRank rank t1 >> adjustRank rank t2
adjustRank _    _              = return ()