Skip to content

Commit

Permalink
Fix scoping of the all construct
Browse files Browse the repository at this point in the history
Summary:
The scoping of the `all` construct for sets is more akin to negation and `if`, in the sense that variables that only appears inside `all` cannot be moved outside. This was wrong in the optimizing and reordering steps.
This diff fixes the scoping problems.

Reviewed By: kbojarczuk

Differential Revision: D68720217

fbshipit-source-id: a31aacd3f00798cbc34cefdacc3f70c6a1b0e415
  • Loading branch information
Simon Marlow authored and facebook-github-bot committed Feb 11, 2025
1 parent 28911ca commit 91e40a4
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 46 deletions.
6 changes: 4 additions & 2 deletions glean/db/Glean/Query/Flatten/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ data FlatStatement
= FlatStatement Type Pat Generator
-- ^ A simple statement: P = gen
| FlatAllStatement Var Pat FlatStatementGroup
-- ^ Similar to a vanilla statement, but the result is a set
-- containing the results of computing the statements.
-- ^ An all() statement.
-- @FlatAllStatement X P S@ means @X = all (P where S)@
-- Like disjunction and negation, Variables occurring in
-- @P where S@ are not considered bound.
| FlatNegation FlatStatementGroup
-- ^ The negation of a series of statements
| FlatDisjunction [FlatStatementGroup]
Expand Down
25 changes: 14 additions & 11 deletions glean/db/Glean/Query/Opt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,9 @@ instance Apply FlatStatement where
Ref (MatchVar x) -> return x
Ref (MatchBind x) -> return x
_ -> error "apply: FlatAllStatement"
g' <- optStmts g
FlatAllStatement v' <$> apply e <*> apply g'
enclose (stmtGroupScope g . termScope e) $ do
g' <- optStmts g
FlatAllStatement v' <$> apply e <*> pure g'
apply (FlatNegation stmts) = do
-- assumptions arising inside the negation are not true outside of it.
stmts' <- optStmtsEnclosed stmts
Expand All @@ -252,7 +253,7 @@ instance Apply FlatStatement where
-- are not true outside of it. However, those arising from the condition
-- are true in the 'then' case.
(cond', then') <-
enclose cond $ do
enclose (stmtGroupScope cond) $ do
cond' <- optStmts cond
then' <- optStmtsEnclosed then_
return (cond', then')
Expand All @@ -263,7 +264,7 @@ instance Apply FlatStatement where
else FlatConditional cond' then' else'

optStmtsEnclosed :: FlatStatementGroup -> U FlatStatementGroup
optStmtsEnclosed stmts = enclose stmts $ optStmts stmts
optStmtsEnclosed stmts = enclose (stmtGroupScope stmts) $ optStmts stmts

-- If a sequence of statements is found to be false, then we place
-- a falseStmt sentinel at the beginning. We don't actually remove
Expand Down Expand Up @@ -332,11 +333,11 @@ applyVar var@(Var _ v _) = do
-- or a negated query.
--
-- Does not add substitutions or new variables to the parent scope.
enclose :: FlatStatementGroup -> U a -> U a
enclose (FlatStatementGroup ord) u = do
enclose :: (VarSet -> VarSet) -> U a -> U a
enclose innerScope u = do
state0 <- get
-- set the outer scope to be the current scope
let scope = foldr ordStmtScope (optCurrentScope state0) ord
let scope = innerScope (optCurrentScope state0)
modify $ \s ->
s { optCurrentScope = scope, optOuterScope = optCurrentScope state0 }
a <- u
Expand Down Expand Up @@ -542,14 +543,16 @@ queryScope (FlatQuery key maybeVal (FlatStatementGroup ord)) =
where
s = termScope key IntSet.empty

stmtGroupScope :: FlatStatementGroup -> VarSet -> VarSet
stmtGroupScope (FlatStatementGroup g) r = foldr ordStmtScope r g

ordStmtScope :: Ordered FlatStatement -> VarSet -> VarSet
ordStmtScope = stmtScope . unOrdered

stmtScope :: FlatStatement -> VarSet -> VarSet
stmtScope (FlatStatement _ lhs rhs) r = termScope lhs (genScope rhs r)
stmtScope (FlatAllStatement v e (FlatStatementGroup ord)) r =
addToCurrentScope v $! termScope e $!
foldr ordStmtScope r ord
stmtScope (FlatAllStatement v _ (FlatStatementGroup _)) r =
addToCurrentScope v $! r
stmtScope (FlatNegation _) r = r
stmtScope (FlatDisjunction [FlatStatementGroup ord]) r =
foldr ordStmtScope r ord
Expand Down Expand Up @@ -723,7 +726,7 @@ filterStmt :: FlatStatement -> U FlatStatement
filterStmt stmt = case stmt of
FlatStatement{} -> return stmt
FlatAllStatement v e stmts ->
FlatAllStatement v e <$> filterGroup stmts
FlatAllStatement v e <$> filterGroupEnclosed stmts
FlatNegation stmts -> FlatNegation <$> filterGroupEnclosed stmts
FlatDisjunction [stmts] -> grouping <$> filterGroup stmts
FlatDisjunction stmtss ->
Expand Down
69 changes: 36 additions & 33 deletions glean/db/Glean/Query/Reorder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ reorder dbSchema QueryWithInfo{..} =

reorderQuery :: FlatQuery -> R CgQuery
reorderQuery (FlatQuery pat _ stmts) =
withScopeFor [stmts] $ do
withScopeFor (scopeVars stmts) $ do
stmts' <- reorderGroup stmts
(extra, pat') <- resolved `catchError` \e ->
maybeBindUnboundPredicate e resolved
Expand Down Expand Up @@ -212,11 +212,10 @@ reorderGroup g = do
-- | Define a new scope.
-- Adds all variables local to the statements to the scope at the start and
-- remove them from the scope in the end.
withScopeFor :: [FlatStatementGroup] -> R a -> R a
withScopeFor stmts act = do
withScopeFor :: VarSet -> R a -> R a
withScopeFor stmtsVars act = do
Scope outerScope bound <- gets roScope
let stmtsVars = foldMap scopeVars stmts
locals = IntSet.filter (`IntSet.notMember` outerScope) stmtsVars
let locals = IntSet.filter (`IntSet.notMember` outerScope) stmtsVars

modify $ \s -> s { roScope = Scope (outerScope <> locals) bound }
res <- act
Expand All @@ -226,28 +225,29 @@ withScopeFor stmts act = do
without (Scope scope bound) x =
Scope (IntSet.difference scope x)
(IntMap.filterWithKey (\v _ -> v `IntSet.notMember` x) bound)
-- | All variables that appear in the scope these statements are in.
-- Does not include variables local to sub-scopes such as those that only
-- appear:
-- - inside a negated subquery
-- - in some but not all branches of a disjunction
-- - in only one of 'else' or (condition + 'then') clauses of an if stmt
scopeVars :: FlatStatementGroup -> VarSet
scopeVars (FlatStatementGroup ord) =
foldMap (stmtScope . unOrdered) ord
where
stmtScope = \case
FlatNegation{} -> mempty
s@FlatStatement{} -> vars s
s@FlatAllStatement{} -> vars s
-- only count variables that appear in all branches of the disjunction
FlatDisjunction [] -> mempty
FlatDisjunction (s:ss) ->
foldr (IntSet.intersection . scopeVars) (scopeVars s) ss
FlatConditional cond then_ else_ ->
IntSet.intersection
(scopeVars cond <> scopeVars then_)
(scopeVars else_)

-- | All variables that appear in the scope these statements are in.
-- Does not include variables local to sub-scopes such as those that only
-- appear:
-- - inside a negated subquery
-- - in some but not all branches of a disjunction
-- - in only one of 'else' or (condition + 'then') clauses of an if stmt
scopeVars :: FlatStatementGroup -> VarSet
scopeVars (FlatStatementGroup ord) =
foldMap (stmtScope . unOrdered) ord
where
stmtScope = \case
FlatNegation{} -> mempty
s@FlatStatement{} -> vars s
FlatAllStatement{} -> mempty
-- only count variables that appear in all branches of the disjunction
FlatDisjunction [] -> mempty
FlatDisjunction (s:ss) ->
foldr (IntSet.intersection . scopeVars) (scopeVars s) ss
FlatConditional cond then_ else_ ->
IntSet.intersection
(scopeVars cond <> scopeVars then_)
(scopeVars else_)

{-
Note [Optimising statement groups]
Expand Down Expand Up @@ -931,18 +931,20 @@ toCgStatement stmt = case stmt of
lhs' <- fixVars IsPat lhs
return [CgStatement lhs' gen']
FlatAllStatement v e g -> do
stmts <- reorderGroup g
e' <- fixVars IsExpr e
cg <- withScopeFor (scopeVars g <> vars e) $ do
stmts <- reorderGroup g
e' <- fixVars IsExpr e
return [CgAllStatement v e' stmts]
bindVar v
return [CgAllStatement v e' stmts]
return cg
FlatNegation stmts -> do
stmts' <-
withinNegation $
withScopeFor [stmts] $
withScopeFor (scopeVars stmts) $
reorderGroup stmts
return [CgNegation stmts']
FlatDisjunction [stmts] ->
withScopeFor [stmts] $ reorderGroup stmts
withScopeFor (scopeVars stmts) $ reorderGroup stmts
FlatDisjunction groups -> do
cg <- map runIdentity <$> intersectBindings (map Identity groups)
return [CgDisjunction cg]
Expand All @@ -960,7 +962,8 @@ toCgStatement stmt = case stmt of
initialScope <- gets roScope
results <- forM groups $ \tgroup -> do
modify $ \state -> state { roScope = initialScope }
tstmts <- withScopeFor (toList tgroup) $ traverse reorderGroup tgroup
tstmts <- withScopeFor (foldMap scopeVars (toList tgroup)) $
traverse reorderGroup tgroup
newScope <- gets roScope
return (tstmts, allBound newScope)

Expand Down
4 changes: 4 additions & 0 deletions glean/test/tests/Angle/SetTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ setSemanticsTest = TestList
, TestLabel "multiple set results" $ dbTestCase $ \env repo -> do
r <- runQuery_ env repo $ angleData @(Set Nat) [s| all X where X= (1|2) |]
assertEqual "results" 2 (length r)
, TestLabel "predicate in all" $ dbTestCase $ \env repo -> do
[set] <- runQuery_ env repo $ angleData @(Set Glean.Test.Predicate)
[s| all (glean.test.Predicate _) |]
assertEqual "angle - set matching" 4 (size set)
]

setLimitTest :: Test
Expand Down

0 comments on commit 91e40a4

Please sign in to comment.