Skip to content

Commit

Permalink
Infer the left-hand-side of dot
Browse files Browse the repository at this point in the history
Summary:
This example should work:

```
cxx1.RecordDeclaration R where R.name.name = "vector"
```

but it currently fails because the type of R is not known at the time
we typecheck `R.name`.  This diff fills in the missing inference
functionality. It required adding the ability to "demote" an
expression from a predicate type to its key type, which is the dual of
the "promote" operator we already had.

With this change I'm happy enough with dot syntax that I think we can
advertise it more widely.

Reviewed By: josefs

Differential Revision: D62499221

fbshipit-source-id: c1eec9af52457c9b246d3794df3a501955333ec4
  • Loading branch information
Simon Marlow authored and facebook-github-bot committed Sep 11, 2024
1 parent 4c3fd7b commit 0ba07de
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 32 deletions.
2 changes: 2 additions & 0 deletions glean/db/Glean/Query/Expand.hs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ instantiateWithFreshVariables query numVars = do
TcAltSelect (Typed ty (instantiatePat base pat)) field
instantiateTcTerm base (TcPromote ty pat) =
TcPromote ty (instantiatePat base pat)
instantiateTcTerm base (TcDemote ty pat) =
TcDemote ty (instantiatePat base pat)
instantiateTcTerm base (TcStructPat fs) =
TcStructPat [(n, instantiatePat base p) | (n,p) <- fs]

Expand Down
22 changes: 13 additions & 9 deletions glean/db/Glean/Query/Flatten.hs
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,19 @@ flattenPattern pat = case pat of
-- pat.field ==> X where { field = X } = pat
Ref (MatchExt (Typed ty (TcFieldSelect (Typed recTy pat) name))) -> do
r <- flattenPattern pat
let sel v =
[ if name == n then v else Ref (MatchWild ty)
| Angle.RecordTy fields <- [derefType recTy]
, Angle.FieldDef n ty <- fields
]
forM r $ \(stmts, p) -> do
v <- Ref . MatchVar <$> fresh ty
let stmt = FlatStatement recTy (Tuple (sel v)) (TermGenerator p)
return (stmts `thenStmt` stmt, v)
case derefType recTy of
Angle.RecordTy fields -> do
let sel v =
[ if name == n then v else Ref (MatchWild ty)
| Angle.FieldDef n ty <- fields
]
forM r $ \(stmts, p) -> do
v <- Ref . MatchVar <$> fresh ty
let stmt = FlatStatement recTy (Tuple (sel v)) (TermGenerator p)
return (stmts `thenStmt` stmt, v)
_other ->
throwError $ "internal: TcFieldSelect: " <>
Text.pack (show (displayDefault recTy))

-- pat.field? ==> X where { field = X } = pat
Ref (MatchExt (Typed ty (TcAltSelect (Typed sumTy pat) name))) -> do
Expand Down
2 changes: 2 additions & 0 deletions glean/db/Glean/Query/Prune.hs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ prune hasFacts (QueryWithInfo q _ t) = do
p' <- prunePat p
return $ Ref $ MatchExt $ Typed ty $ TcAltSelect (Typed ty' p') f
TcPromote _ p -> prunePat p
TcDemote _ p -> prunePat p
TcStructPat{} -> error "prune: TcStructPat"

type R a = State S a
Expand Down Expand Up @@ -244,6 +245,7 @@ renumberVars ty q =
p' <- renamePat p
return $ TcAltSelect (Typed ty p') f
TcPromote ty p -> TcPromote ty <$> renamePat p
TcDemote ty p -> TcDemote ty <$> renamePat p
TcStructPat fs -> fmap TcStructPat $ forM fs $ \(n,p) ->
(n,) <$> renamePat p

Expand Down
61 changes: 46 additions & 15 deletions glean/db/Glean/Query/Typecheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
{- TODO
- implement mutable type variables to speed up type inference
- cleanup:
- split up into separate files
- merge inferExpr and typecheckPattern?
-}

Expand Down Expand Up @@ -389,7 +388,11 @@ inferExpr ctx pat = case pat of
return ((name, expr'), (name, ty))
let (fields', types) = unzip pairs
x <- freshTyVarInt
let ty = HasTy (Map.fromList types) (length fields > 1) x
let
must_be_rec
| length fields > 1 = Just True
| otherwise = Nothing
ty = HasTy (Map.fromList types) must_be_rec x
promote (sourcePatSpan pat)
(Ref (MatchExt (Typed ty (TcStructPat fields')))) ty

Expand All @@ -399,7 +402,7 @@ inferExpr ctx pat = case pat of

Enum _ name -> do
x <- freshTyVarInt
let ty = HasTy (Map.singleton name unit) False x
let ty = HasTy (Map.singleton name unit) Nothing x
promote (sourcePatSpan pat)
(Ref (MatchExt (Typed ty (TcStructPat [(name, RTS.Tuple [])])))) ty

Expand Down Expand Up @@ -459,11 +462,15 @@ fieldSelect src ty pat fieldName sum = do
"?' not '." <> pretty fieldName <> "'"
MaybeTy elemTy ->
fieldSelect src (lowerMaybe elemTy) pat fieldName sum
TyVar{} ->
prettyErrorIn src $ nest 4 $ vcat [
"cannot determine the type of the left-hand-side of '.'",
"please add a type signature."
]
TyVar{} -> do
x <- freshTyVarInt
fieldTy <- freshTyVar
let recTy = HasTy (Map.singleton fieldName fieldTy) (Just (not sum)) x
-- allow the lhs to be a predicate:
fn <- demoteTo (sourcePatSpan src) ty' recTy
let sel | sum = TcAltSelect (Typed recTy (fn pat)) fieldName
| otherwise = TcFieldSelect (Typed recTy (fn pat)) fieldName
return (Ref (MatchExt (Typed fieldTy sel)), fieldTy)
_other ->
err $ "expression is not a " <> if sum then "union type" else "record"

Expand Down Expand Up @@ -991,6 +998,7 @@ tcQueryDeps q = Set.fromList $ map getRef (overQuery q)
TcFieldSelect (Typed _ p) _ -> overPat p
TcAltSelect (Typed _ p) _ -> overPat p
TcPromote _ p -> overPat p
TcDemote _ p -> overPat p
TcStructPat fs -> foldMap overPat (map snd fs)

data UseOfNegation
Expand Down Expand Up @@ -1047,6 +1055,7 @@ tcTermUsesNegation = \case
TcFieldSelect (Typed _ p) _ -> tcPatUsesNegation p
TcAltSelect (Typed _ p) _ -> tcPatUsesNegation p
TcPromote _ p -> tcPatUsesNegation p
TcDemote _ p -> tcPatUsesNegation p
TcStructPat fs -> firstJust tcPatUsesNegation (map snd fs)

{-
Expand Down Expand Up @@ -1159,6 +1168,28 @@ promoteTo s t u = do
addErrSpan s $ unify t u
return id

-- | dual to promoteTo. The destination type cannot be a TyVar.
demoteTo :: IsSrcSpan s => s -> Type -> Type -> T (TcPat -> TcPat)
demoteTo _ _ TyVar{} = error "demote: TyVar"
demoteTo s t@(TyVar x) u = do
subst <- gets tcSubst
case IntMap.lookup x subst of
Just t -> demoteTo s t u
Nothing -> do
addPromote s u (TyVar x)
return (Ref . MatchExt . Typed u . TcDemote t)
demoteTo _ (PredicateTy (PidRef p _)) (PredicateTy (PidRef q _))
| p == q = return id
demoteTo s t@(PredicateTy (PidRef _ ref)) u = do
PredicateDetails{..} <- getPredicateDetails ref
addErrSpan s $ unify predicateKeyType u
return (Ref . MatchExt . Typed u . TcDemote t)
demoteTo s t u = do
-- the source type (t) is not a predicate or a tyvar, so it must be
-- the key type and we can unify directly.
addErrSpan s $ unify t u
return id

addPromote :: IsSrcSpan s => s -> Type -> Type -> T ()
addPromote span from to =
modify $ \s ->
Expand All @@ -1167,13 +1198,6 @@ addPromote span from to =
resolvePromote :: T ()
resolvePromote = do
promotes <- gets tcPromote
whenDebug $ liftIO $ hPutStrLn stderr $ show $ vcat
[ "promotes: "
, vcat
[ displayDefault from <> " -> " <> displayDefault to
| (from,to,_) <- promotes
]
]
let
resolve
:: [(Type,Type,Some IsSrcSpan)]
Expand All @@ -1199,6 +1223,13 @@ resolvePromote = do

loop [] = return ()
loop promotes = do
whenDebug $ liftIO $ hPutStrLn stderr $ show $ vcat
[ "promotes: "
, vcat
[ displayDefault from <> " -> " <> displayDefault to
| (from,to,_) <- promotes
]
]
resolved <- resolve promotes False
if length resolved < length promotes
then loop resolved
Expand Down
3 changes: 3 additions & 0 deletions glean/db/Glean/Query/Typecheck/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ data TcTerm
-- - A == B, or
-- - B = P, where P : A for some predicate P
-- Turns into either nothing or TcFactGen after typechecking
| TcDemote Type TcPat
| TcStructPat [(FieldName, TcPat)]
-- An unresolved pattern matching a record or sum type.
deriving Show
Expand Down Expand Up @@ -110,6 +111,8 @@ instance Display TcTerm where
hsep (display opts op : map (displayAtom opts) args)
display opts (TcPromote _ pat) =
"^" <> displayAtom opts pat
display opts (TcDemote _ pat) =
"" <> displayAtom opts pat
display opts (TcStructPat fs) =
cat [ nest 2 $ cat [ "{", fields fs], "}"]
where
Expand Down
32 changes: 28 additions & 4 deletions glean/db/Glean/Query/Typecheck/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ unify (TyVar x) (TyVar y) | x == y = return ()
unify (TyVar x) t = extend x t
unify t (TyVar x) = extend x t

unify (HasTy a ra x) (HasTy b rb y) = do
mapM_ (uncurry unify) $ Map.intersectionWith (,) a b
unify a@(HasTy fa ra x) b@(HasTy fb rb y) = do
rec <- case (ra,rb) of
(Just x, Just y) | x /= y -> unifyError a b
(Nothing, _) -> return rb
_otherwise -> return ra
mapM_ (uncurry unify) $ Map.intersectionWith (,) fa fb
z <- freshTyVarInt
let all = HasTy (Map.union a b) (ra || rb) z
let all = HasTy (Map.union fa fb) rec z
extend x all
extend y all

unify a@(HasTy _ (Just False) _) b@RecordTy{} =
unifyError a b
unify a@(HasTy m _ x) b@(RecordTy fs) = do
forM_ fs $ \(FieldDef f ty) ->
case Map.lookup f m of
Expand All @@ -87,7 +93,9 @@ unify a@(HasTy m _ x) b@(RecordTy fs) = do
unifyError a b
extend x (RecordTy fs)

unify a@(HasTy m False x) b@(SumTy fs) = do
unify a@(HasTy _ (Just True) _) b@SumTy{} =
unifyError a b
unify a@(HasTy m _ x) b@(SumTy fs) = do
forM_ fs $ \(FieldDef f ty) ->
case Map.lookup f m of
Nothing -> return ()
Expand Down Expand Up @@ -222,6 +230,21 @@ zonkTcPat p = case p of
(TcFactGen pidRef e' vpat SeekOnAllFacts))))
_ ->
return e'
Ref (MatchExt (Typed ty (TcDemote inner e))) -> do
ty' <- zonkType ty
inner' <- zonkType inner
e' <- zonkTcPat e
case (ty', inner') of
(TyVar{}, _) -> error "zonkMatch: tyvar"
(_, TyVar{}) -> error "zonkMatch: tyvar"
(PredicateTy (PidRef _ ref), PredicateTy (PidRef _ ref'))
| ref == ref' -> return e'
(_other, PredicateTy (PidRef _ ref)) -> do
PredicateDetails{..} <- getPredicateDetails ref
return (Ref (MatchExt (Typed ty'
(TcDeref inner' predicateValueType e'))))
_ ->
return e'
Ref (MatchExt (Typed ty (TcStructPat fs))) -> do
ty' <- zonkType ty
case ty' of
Expand Down Expand Up @@ -295,6 +318,7 @@ zonkTcTerm t = case t of
<*> pure f
TcElements p -> TcElements <$> zonkTcPat p
TcPromote{} -> error "zonkTcTerm: TcPromote" -- handled in zonkTcPat
TcDemote{} -> error "zonkTcTerm: TcPromote" -- handled in zonkTcPat
TcStructPat{} -> error "zonkTcTerm: TcStructPat" -- handled in zonkTcPat

zonkTcStatement :: TcStatement -> T TcStatement
Expand Down
12 changes: 8 additions & 4 deletions glean/hs/Glean/Angle/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -397,14 +397,17 @@ data Type_ pref tref

-- These are used during typechecking only
| TyVar {-# UNPACK #-}!Int
| HasTy (Map Name (Type_ pref tref)) !Bool {-# UNPACK #-}!Int
| HasTy (Map Name (Type_ pref tref)) !(Maybe Bool) {-# UNPACK #-}!Int
-- HasTy { field:type .. } R X
-- Constrains X to be a record or sum type containing at least
-- the given fields/types. X can only be instantiated
-- with a type containing a superset of those fields: either
-- a bigger HasTy or a RecordTy/SumTy.
-- B is True if the type must be a RecordTy, otherwise it
-- can be either a RecordTy or a SumTy.
-- R is
-- Just True -> type must be a record
-- Just False -> type must be a sum type
-- Nothing -> type can beeither a RecordTy or a SumTy
-- can be
deriving (Eq, Show, Functor, Foldable, Generic)

instance (Binary pref, Binary tref) => Binary (Type_ pref tref)
Expand Down Expand Up @@ -713,7 +716,8 @@ instance (Display pref, Display tref) => Display (Type_ pref tref) where
[ nest 2 $ vsep $ "{" : punctuate sepr (map doField (Map.toList m))
, "..." <> pretty x <> "}" ]
where
sepr = if rec then "," else "|"
sepr | Just False <- rec = "|"
| otherwise = ","
doField (n, ty) = pretty n <> " : " <> display opts ty

displayAtom opts t = case t of
Expand Down
21 changes: 21 additions & 0 deletions glean/test/tests/Angle/AngleTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,27 @@ angleDotTest = dbTestCase $ \env repo -> do
Left (BadQuery x) -> "type error" `Text.isInfixOf` x
_ -> False

-- infer the lhs of the dot
r <- runQuery env repo $ angle @Glean.Test.Predicate
"X where X.sum_.c?.nat = 42; X : glean.test.Predicate"
assertEqual "infer lhs" 1 (length r)

-- ensure that we catch wrong usage of .field?
r <- try $ runQuery env repo $ angleData @Text
"X where X.sum_.c.nat = 42; X : glean.test.Predicate"
print r
assertBool "infer lhs error 1" $
case r of
Left (BadQuery x) -> "type error" `Text.isInfixOf` x
_ -> False

r <- try $ runQuery env repo $ angleData @Text
"X where X.sum_.c?.nat? = 42; X : glean.test.Predicate"
print r
assertBool "infer lhs error 2" $
case r of
Left (BadQuery x) -> "type error" `Text.isInfixOf` x
_ -> False

-- if statements
angleIfThenElse :: (forall a . Query a -> Query a) -> Test
Expand Down

0 comments on commit 0ba07de

Please sign in to comment.