Skip to content

Commit

Permalink
Use cypher params in _query
Browse files Browse the repository at this point in the history
This commit changes `_query()` function in Neo4jStore and adds test that
shows how the exploit could be abused in previous versions
  • Loading branch information
LilDojd committed Nov 16, 2024
1 parent 99034dc commit 10456b9
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
34 changes: 15 additions & 19 deletions alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,26 +468,21 @@ def _get_node(
) -> Union[Node, Tuple[Node, Subgraph]]:
"""
If `return_subgraph = True`, also return subgraph for gufe object.
"""
qualname = scoped_key.qualname

properties = {"_scoped_key": str(scoped_key)}
prop_string = ", ".join(
"{}: '{}'".format(key, value) for key, value in properties.items()
)

prop_string = f" {{{prop_string}}}"
# Safety: qualname comes from GufeKey which is validated
qualname = scoped_key.qualname
parameters = {"scoped_key": str(scoped_key)}

q = f"""
MATCH (n:{qualname}{prop_string})
MATCH (n:{qualname} {{ _scoped_key: $scoped_key }})
"""

if return_subgraph:
q += """
OPTIONAL MATCH p = (n)-[r:DEPENDS_ON*]->(m)
WHERE NOT (m)-[:DEPENDS_ON]->()
RETURN n,p
RETURN n, p
"""
else:
q += """
Expand All @@ -497,10 +492,12 @@ def _get_node(
nodes = set()
subgraph = Subgraph()

for record in self.execute_query(q).records:
result = self.execute_query(q, parameters_=parameters)

for record in result.records:
node = record_data_to_node(record["n"])
nodes.add(node)
if return_subgraph and record["p"] is not None:
if return_subgraph and record.get("p") is not None:
subgraph = subgraph | subgraph_from_path_record(record["p"])
else:
subgraph = node
Expand All @@ -521,8 +518,8 @@ def _query(
self,
*,
qualname: str,
additional: Dict = None,
key: GufeKey = None,
additional: Optional[Dict] = None,
key: Optional[GufeKey] = None,
scope: Scope = Scope(),
return_gufe=False,
):
Expand All @@ -532,9 +529,8 @@ def _query(
"_project": scope.project,
}

for k, v in list(properties.items()):
if v is None:
properties.pop(k)
# Remove None values from properties
properties = {k: v for k, v in properties.items() if v is not None}

if key is not None:
properties["_gufe_key"] = str(key)
Expand All @@ -547,7 +543,7 @@ def _query(
prop_string = ""
else:
prop_string = ", ".join(
"{}: '{}'".format(key, value) for key, value in properties.items()
"{}: ${}".format(key, key) for key in properties.keys()
)

prop_string = f" {{{prop_string}}}"
Expand All @@ -568,7 +564,7 @@ def _query(
"""

with self.transaction() as tx:
res = tx.run(q).to_eager_result()
res = tx.run(q, **properties).to_eager_result()

nodes = list()
subgraph = Subgraph()
Expand Down
46 changes: 46 additions & 0 deletions alchemiscale/tests/integration/storage/test_statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,52 @@ def test_query_transformations(self, n4js, network_tyk2, multiple_scopes):
== 1
)

def test_query_transformations_exploit(self, n4js, multiple_scopes, network_tyk2):
# This test is to show that common cypher exploits are mitigated by using parameters

an = network_tyk2

n4js.assemble_network(an, multiple_scopes[0])
n4js.assemble_network(an, multiple_scopes[1])

malicious_name = """'})
WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n
RETURN n
UNION
MATCH (m) DETACH DELETE m
WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n
RETURN n
UNION
CREATE (mark:InjectionMark {_scoped_key: 'InjectionMark-12345-test-testcamp-testproj'})
WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n // """
try:
n4js.query_transformations(name=malicious_name)
except AttributeError as e:
# With old _query, AttributeError would be thrown AFTER the transaction has finished, and the database is already corrupted
assert "'dict' object has no attribute 'labels'" in str(e)
assert len(n4js.query_transformations(scope=multiple_scopes[0])) == 0

mark = n4js._query(qualname="InjectionMark")
assert len(mark) == 0

assert len(n4js.query_transformations()) == len(network_tyk2.edges) * 2
assert len(n4js.query_transformations(scope=multiple_scopes[0])) == len(
network_tyk2.edges
)

assert (
len(n4js.query_transformations(name="lig_ejm_31_to_lig_ejm_50_complex"))
== 2
)
assert (
len(
n4js.query_transformations(
scope=multiple_scopes[0], name="lig_ejm_31_to_lig_ejm_50_complex"
)
)
== 1
)

def test_query_chemicalsystems(self, n4js, network_tyk2, multiple_scopes):
an = network_tyk2

Expand Down

0 comments on commit 10456b9

Please sign in to comment.