Skip to content

Commit

Permalink
Support ordered vs. unordered statements
Browse files Browse the repository at this point in the history
Summary:
This is an internal change only; nothing is exposed to clients
yet. I'm refactoring the internals to support explicitly ordered
vs. unordered lists of statements. Overall it's cleaner and removes a
level of nesting that was quite confusing.

Reviewed By: donsbot

Differential Revision: D62647670

fbshipit-source-id: 521e32ae610e54031f64d8c306dd007670560fe2
  • Loading branch information
Simon Marlow authored and facebook-github-bot committed Sep 27, 2024
1 parent 4c204d9 commit f9ab0b0
Show file tree
Hide file tree
Showing 52 changed files with 396 additions and 328 deletions.
4 changes: 2 additions & 2 deletions glean/db/Glean/Database/Schema/ComputeIds.hs
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ resolveQuery
:: RefToIdEnv
-> Query_ PredicateRef TypeRef
-> Query_ PredicateId TypeId
resolveQuery env (SourceQuery head stmts) =
SourceQuery (refsToIds env <$> head) (refsToIds env <$> stmts)
resolveQuery env (SourceQuery head stmts ord) =
SourceQuery (refsToIds env <$> head) (refsToIds env <$> stmts) ord

-- Serialize a definition to produce its hash. Note that the hash
-- includes the PredicateRef/TypeRef: two predicates or types are the
Expand Down
19 changes: 11 additions & 8 deletions glean/db/Glean/Query/Expand.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ expandDerivedPredicateCall
-> TypecheckedQuery -- ^ query from the derived predicate
-> F TcQuery
expandDerivedPredicateCall PredicateDetails{..} key val QueryWithInfo{..} = do
(TcQuery _ keyDef maybeValDef stmts) <-
(TcQuery _ keyDef maybeValDef stmts ord) <-
instantiateWithFreshVariables qiQuery qiNumVars

let
Expand All @@ -55,19 +55,21 @@ expandDerivedPredicateCall PredicateDetails{..} key val QueryWithInfo{..} = do
x <- fresh predicateKeyType
y <- fresh predicateValueType
return $
TcQuery predicateKeyType (Ref (MatchVar x)) (Just (Ref (MatchVar y))) $
stmts ++ [
TcQuery predicateKeyType (Ref (MatchVar x)) (Just (Ref (MatchVar y)))
(stmts ++ [
TcStatement predicateKeyType (Ref (MatchBind x)) keyDef,
TcStatement predicateKeyType (Ref (MatchVar x)) key',
TcStatement predicateValueType (Ref (MatchBind y)) valDef,
TcStatement predicateValueType (Ref (MatchVar y)) val' ]
TcStatement predicateValueType (Ref (MatchVar y)) val' ])
ord
Nothing -> do
x <- fresh predicateKeyType
return $
TcQuery predicateKeyType (Ref (MatchVar x)) Nothing $
stmts ++ [
TcQuery predicateKeyType (Ref (MatchVar x)) Nothing
(stmts ++ [
TcStatement predicateKeyType (Ref (MatchBind x)) keyDef,
TcStatement predicateKeyType (Ref (MatchVar x)) key' ]
TcStatement predicateKeyType (Ref (MatchVar x)) key' ])
ord

-- | Make a fresh instance of a query where none of the variables
-- clash with existing variables. We know the maximum variable in the
Expand All @@ -80,11 +82,12 @@ instantiateWithFreshVariables query numVars = do
put state { flNextVar = base + numVars }
return $ instantiateQuery base query
where
instantiateQuery base (TcQuery ty head maybeVal stmts) =
instantiateQuery base (TcQuery ty head maybeVal stmts ord) =
TcQuery ty
(instantiatePat base head)
(fmap (instantiatePat base) maybeVal)
(map (instantiateStmt base) stmts)
ord

instantiateStmt base (TcStatement ty lhs rhs) =
TcStatement ty (instantiatePat base lhs) (instantiatePat base rhs)
Expand Down
89 changes: 49 additions & 40 deletions glean/db/Glean/Query/Flatten.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ flatten rec dbSchema ver deriveStored QueryWithInfo{..} =
fmap fst $ flip runStateT state $ do
(flattened, returnType) <- do
flat <- flattenQuery qiQuery `catchError` flattenFailure
captureKey ver dbSchema flat (case qiQuery of TcQuery ty _ _ _ -> ty)
captureKey ver dbSchema flat (case qiQuery of TcQuery ty _ _ _ _ -> ty)
nextVar <- gets flNextVar
return $ QueryWithInfo flattened nextVar returnType
where
Expand All @@ -75,29 +75,34 @@ flattenQuery query = do
return (FlatQuery head' maybeVal (flattenStmtGroups stmts'))

flattenQuery' :: TcQuery -> F ([Statements], Expr, Maybe Expr)
flattenQuery' (TcQuery ty head Nothing stmts) = do
flattenQuery' (TcQuery ty head Nothing stmts ord) = do
stmts' <- mapM flattenStatement stmts
pats <- flattenPattern head
case pats of
[(stmts,head')] -> return (stmts' ++ [stmts], head', Nothing)
_many -> do
[(stmts,head')]
| Ordered <- ord -> return (stmts' ++ [stmts], head', Nothing)
| otherwise -> return ([mconcat stmts' <> stmts], head', Nothing)
_many -> do -- TODO: ord
-- If there are or-patterns on the LHS, then we have
-- P1 | P2 | ... where stmts
-- so we will generate
-- X where stmts; X = P1 | P2 | ...
v <- fresh ty
let
alts =
[ flattenStmtGroups [stmts] ++
[ singletonGroup
(FlatStatement ty (Ref (MatchBind v)) (TermGenerator head)) ]
[ flattenStmtGroups [stmts `thenStmt`
FlatStatement ty (Ref (MatchBind v)) (TermGenerator head)]
| (stmts, head) <- pats ]
return
( stmts' ++ [mempty `thenStmt` FlatDisjunction alts]
( case ord of
Ordered ->
stmts' ++ [mempty `thenStmt` FlatDisjunction alts]
Unordered ->
[mconcat stmts' <> mempty `thenStmt` FlatDisjunction alts]
, Ref (MatchVar v)
, Nothing
)
flattenQuery' (TcQuery ty head (Just val) stmts) = do
flattenQuery' (TcQuery ty head (Just val) stmts _ord {- TODO -}) = do
stmts' <- mapM flattenStatement stmts
pats <- flattenPattern head
vals <- flattenPattern val
Expand All @@ -111,9 +116,8 @@ flattenQuery' (TcQuery ty head (Just val) stmts) = do
v <- fresh ty
let
alts =
[ flattenStmtGroups [stmts] ++
[ singletonGroup (FlatStatement ty lhsPair
(TermGenerator (Tuple [head,val]))) ]
[ flattenStmtGroups [stmts `thenStmt`
FlatStatement ty lhsPair (TermGenerator (Tuple [head,val]))]
| (stmts, head, val) <- many ]
lhsPair = Tuple [Ref (MatchBind k), Ref (MatchBind v)]
return
Expand All @@ -122,7 +126,6 @@ flattenQuery' (TcQuery ty head (Just val) stmts) = do
, Just (Ref (MatchVar v))
)


flattenStatement :: TcStatement -> F Statements
flattenStatement (TcStatement ty lhs rhs) = do
rgens <- flattenSeqGenerators rhs
Expand Down Expand Up @@ -174,7 +177,7 @@ flattenSeqGenerators (Ref (MatchExt (Typed ty match))) = case match of
[ do
(stmts, gen) <- flattenFactGen pid range kpat vpat
return
( kstmts <> vstmts <> floatGroups (flattenStmtGroups stmts),
( kstmts <> vstmts <> floatGroup (flattenStmtGroups stmts),
gen )
| (kstmts, kpat) <- kpats
, (vstmts, vpat) <- vpats ]
Expand All @@ -186,13 +189,13 @@ flattenSeqGenerators (Ref (MatchExt (Typed ty match))) = case match of
return [(stmts, SetElementGenerator ty pat') | (stmts,pat') <- r ]
TcQueryGen query -> do
(stmts, term, _) <- flattenQuery' query
return [(floatGroups (flattenStmtGroups stmts), TermGenerator term)]
return [(floatGroup (flattenStmtGroups stmts), TermGenerator term)]
TcAll query -> do
(stmts, term, _) <- flattenQuery' query
var <- fresh ty
return
[ (Statements [FlatAllStatement var term
(concatMap unStatements stmts)]
(flattenStmtGroups stmts)]
,TermGenerator (Ref (MatchVar var)))]
TcNegation stmts -> do
stmts' <- flattenStmtGroups <$> mapM flattenStatement stmts
Expand Down Expand Up @@ -424,7 +427,7 @@ valid expressions and lifting out won't affect performance.

-- | A set of statements. The statements in the set will be reordered
-- by the Reorder pass later.
newtype Statements = Statements { unStatements :: [FlatStatement] }
newtype Statements = Statements { _unStatements :: [FlatStatement] }

instance Semigroup Statements where
Statements s1 <> Statements s2 = Statements (s2 <> s1)
Expand Down Expand Up @@ -454,9 +457,10 @@ irrelevant _ = False
-- Statements. This is used when we need to retain the ordering
-- between some statements, but allow the whole sequence to be
-- reordered with respect to other statements around it.
floatGroups :: [FlatStatementGroup] -> Statements
floatGroups [] = Statements []
floatGroups g = Statements [grouping g]
floatGroup :: FlatStatementGroup -> Statements
floatGroup (FlatStatementGroup [] _) = Statements []
floatGroup (FlatStatementGroup s Unordered) = Statements s
floatGroup g@(FlatStatementGroup _ Ordered) = Statements [grouping g]
-- Note: we nest groups by using FlatDisjunction with a single
-- alternative. This is so that this set of groups may be reordered with
-- respect to other statements/groups at the same level. For this to
Expand All @@ -466,16 +470,18 @@ floatGroups g = Statements [grouping g]
flattenStmts :: Statements -> [FlatStatement]
flattenStmts (Statements s) = reverse s

disjunction :: [[FlatStatementGroup]] -> FlatStatement
disjunction [[x :| []]] = x
disjunction :: [FlatStatementGroup] -> FlatStatement
disjunction [FlatStatementGroup [x] _] = x
disjunction groups = FlatDisjunction groups

flattenStmtGroups :: [Statements] -> [FlatStatementGroup]
flattenStmtGroups :: [Statements] -> FlatStatementGroup
flattenStmtGroups [one] = FlatStatementGroup (flattenStmts one) Unordered
flattenStmtGroups stmtss =
concat [ mkGroup x xs | x:xs <- map flattenStmts stmtss ]
FlatStatementGroup (mapMaybe (mkGroup . flattenStmts) stmtss) Ordered
where
mkGroup (FlatDisjunction [g]) [] = g -- flatten unnecessary nesting
mkGroup x xs = [x :| xs]
mkGroup [] = Nothing
mkGroup [one] = Just one
mkGroup more = Just (FlatDisjunction [FlatStatementGroup more Unordered])

singleTerm :: a -> F [(Statements, a)]
singleTerm t = return [(mempty, t)]
Expand Down Expand Up @@ -541,7 +547,8 @@ captureKey
-> FlatQuery
-> Type
-> F (FlatQuery, Type)
captureKey ver dbSchema (FlatQuery pat Nothing stmts) ty
captureKey ver dbSchema
(FlatQuery pat Nothing (FlatStatementGroup stmts ord)) ty
| Angle.PredicateTy pidRef@(PidRef pid _) <- ty = do
let
-- look for $result = pred pat
Expand All @@ -564,7 +571,7 @@ captureKey ver dbSchema (FlatQuery pat Nothing stmts) ty
keyExpr = Ref (MatchVar keyVar)
valExpr = maybe (Tuple []) (Ref . MatchVar) maybeValVar
in
(singletonGroup stmt' , Just (keyExpr, valExpr))
(stmt' :| [], Just (keyExpr, valExpr))

captureStmt fidVar keyVar@(Var keyTy _ _) maybeValVar
(FlatStatement ty lhs (DerivedFactGenerator pid kexpr vexpr))
Expand Down Expand Up @@ -602,7 +609,7 @@ captureKey ver dbSchema (FlatQuery pat Nothing stmts) ty
]
, Just (keyExpr, valExpr))
captureStmt _ _ _ other =
(singletonGroup other, Nothing)
(other :| [], Nothing)

PredicateDetails{..} <- case lookupPid pid dbSchema of
Nothing -> throwError "internal: captureKey"
Expand All @@ -615,22 +622,23 @@ captureKey ver dbSchema (FlatQuery pat Nothing stmts) ty
(stmts', captured) =
case pat of
Ref (MatchVar (Var _ v _)) ->
let k :: [NonEmpty (NonEmpty FlatStatement, Maybe (Pat, Pat))]
k = fmap (fmap (captureStmt v keyVar maybeValVar)) stmts
(stmtss, captured) = unzip $ map NonEmpty.unzip k
let k :: [(NonEmpty FlatStatement, Maybe (Pat, Pat))]
k = map (captureStmt v keyVar maybeValVar) stmts
(stmtss, captured) = unzip k

conc :: NonEmpty (NonEmpty FlatStatement)
-> NonEmpty FlatStatement
conc ((x :| xs) :| ys) = x :| (xs ++ concatMap NonEmpty.toList ys)
conc :: [NonEmpty FlatStatement] -> [FlatStatement]
conc ((x :| xs) : ys) = x : (xs ++ concatMap NonEmpty.toList ys)
conc [] = []
in
(map conc stmtss, captured)
(conc stmtss, captured)
_other -> (stmts, [])

returnTy = tupleSchema [ty, predicateKeyType, predicateValueType]

case catMaybes (concatMap NonEmpty.toList captured) of
case catMaybes captured of
[(key, val)] ->
return (FlatQuery (RTS.Tuple [pat, key, val]) Nothing stmts', returnTy)
return (FlatQuery (RTS.Tuple [pat, key, val]) Nothing
(FlatStatementGroup stmts' ord), returnTy)
_ -> do
pat' <- case pat of
RTS.Ref MatchWild{} -> RTS.Ref . MatchVar <$> fresh ty
Expand All @@ -643,7 +651,8 @@ captureKey ver dbSchema (FlatQuery pat Nothing stmts) ty
(Ref (MatchBind keyVar))
(Ref (maybe (MatchWild predicateValueType) MatchBind maybeValVar))
SeekOnAllFacts)
return (query (stmts' ++ [singletonGroup lookup]), returnTy)
group = grouping (FlatStatementGroup stmts' ord)
return (query (FlatStatementGroup [group, lookup] Ordered), returnTy)

| otherwise = do
-- We have
Expand Down Expand Up @@ -684,7 +693,7 @@ captureKey ver dbSchema (FlatQuery pat Nothing stmts) ty

return
( FlatQuery result Nothing
(stmts ++ [singletonGroup resultStmt1, singletonGroup resultStmt2])
(FlatStatementGroup (stmts ++ [resultStmt1, resultStmt2]) ord)
, retTy )

captureKey _ _ (FlatQuery _ Just{} _) _ =
Expand Down
Loading

0 comments on commit f9ab0b0

Please sign in to comment.