Skip to content

Commit

Permalink
Merge pull request #22 from linw1995/feature/var_parent
Browse files Browse the repository at this point in the history
Add a new context variable "parent".
  • Loading branch information
linw1995 authored Sep 23, 2020
2 parents ac47f66 + 3336e86 commit 6bcbc60
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 53 deletions.
119 changes: 66 additions & 53 deletions jsonpath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import weakref

from abc import abstractmethod
from contextlib import suppress
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Expand All @@ -31,6 +32,7 @@


var_root: ContextVar[Any] = ContextVar("root")
var_parent: ContextVar[Union[List[Any], Dict[str, Any]]] = ContextVar("parent")
T_SELF_VALUE = Union[Tuple[int, Any], Tuple[str, Any]]
var_self: ContextVar[T_SELF_VALUE] = ContextVar("self")
var_finding: ContextVar[bool] = ContextVar("finding", default=False)
Expand Down Expand Up @@ -76,23 +78,52 @@ class JSONPathFindError(JSONPathError):
"""


def _dfs_find(
expr: Optional["Expr"], elements: List[Any], rv: List[Any]
) -> None:
@contextmanager
def temporary_set(
context_var: ContextVar[Any], value: Any
) -> Generator[None, None, None]:
"""
Set the context variable temporarily via the 'with' statement.
>>> var_boo = ContextVar("boo")
>>> with temporary_set(var_boo, True):
... assert var_boo.get() is True
>>> var_boo.get()
Traceback (most recent call last):
...
LookupError: ...
"""
token = context_var.set(value)
try:
yield
finally:
context_var.reset(token)


def _dfs_find(expr: "Expr", elements: List[Any], rv: List[Any]) -> None:
"""
use DFS to find all target elements.
the next expr finds in the result found by the current expr.
"""
if expr is None:
rv.extend(elements)
else:
for element in elements:
try:
_dfs_find(
expr.get_next(), expr.find(element), rv,
)
except JSONPathFindError:
pass
next_expr = expr.get_next()
for element in elements:
try:
found_elements = expr.find(element)
except JSONPathFindError:
continue

if not found_elements:
continue

if next_expr is None:
# collect all found elements if there is no next expr.
rv.extend(found_elements)
continue

with temporary_set(var_parent, element):
_dfs_find(
next_expr, found_elements, rv,
)


class ExprMeta(type):
Expand Down Expand Up @@ -137,12 +168,11 @@ def find(self: "Expr", element: Any) -> List[Any]:
# but only the first time finding can set the root element.
token_root = var_root.set(element)

token_finding = var_finding.set(True)
try:
_dfs_find(begin, [element], rv)
with temporary_set(var_finding, True):
_dfs_find(begin, [element], rv)
return rv
finally:
var_finding.reset(token_finding)
if token_root:
var_root.reset(token_root)

Expand Down Expand Up @@ -518,18 +548,15 @@ def find(self, element: Any) -> List[Any]:

for item in items:
# save the current item into var_self for Self()
token_self = var_self.set(item)
# set var_finding False to
# start new finding process for the nested expr: self.idx
token_finding = var_finding.set(False)
_, value = item
try:
with temporary_set(var_self, item), temporary_set(
var_finding, False
):
_, value = item
rv = self.expr.find(value)
if rv and rv[0]:
filtered_items.append(value)
finally:
var_finding.reset(token_finding)
var_self.reset(token_self)
return filtered_items


Expand Down Expand Up @@ -573,8 +600,7 @@ def find(self, element: List[Any]) -> Any:
if isinstance(element, list):
# set var_finding False to start new finding process for
# the nested expr: self.start, self.end and self.step
token_finding = var_finding.set(False)
try:
with temporary_set(var_finding, False):
start = (
self.start.find(element)
if isinstance(self.start, Expr)
Expand All @@ -590,8 +616,6 @@ def find(self, element: List[Any]) -> Any:
if isinstance(self.step, Expr)
else self.step
)
finally:
var_finding.reset(token_finding)

if not start:
start = 0
Expand Down Expand Up @@ -644,11 +668,8 @@ def _get_partial_expression(self) -> str:
def find(self, element: Any) -> List[Any]:
# set var_finding False to
# start new finding process for the nested expr: self.expr
token = var_finding.set(False)
try:
with temporary_set(var_finding, False):
return [self._expr.find(element)]
finally:
var_finding.reset(token)


def _recursive_find(expr: Expr, element: Any, rv: List[Any]) -> None:
Expand All @@ -660,12 +681,14 @@ def _recursive_find(expr: Expr, element: Any, rv: List[Any]) -> None:
rv.extend(find_rv)
except JSONPathFindError:
pass
if isinstance(element, list):
for item in element:
_recursive_find(expr, item, rv)
elif isinstance(element, dict):
for item in element.values():
_recursive_find(expr, item, rv)

with temporary_set(var_parent, element):
if isinstance(element, list):
for item in element:
_recursive_find(expr, item, rv)
elif isinstance(element, dict):
for item in element.values():
_recursive_find(expr, item, rv)


class Search(Expr):
Expand Down Expand Up @@ -761,10 +784,9 @@ def _get_target_expression(self) -> str:

def get_target_value(self) -> Any:
if isinstance(self.target, Expr):
try:
# set var_finding False to
# start new finding process for the nested expr: self.target
token = var_finding.set(False)
# set var_finding False to
# start new finding process for the nested expr: self.target
with temporary_set(var_finding, False):
# multiple exprs begins on self-value in filtering find,
# except the self.target expr starts with root-value.
_, value = var_self.get()
Expand All @@ -773,9 +795,6 @@ def get_target_value(self) -> Any:
raise JSONPathFindError

return rv[0]
finally:
var_finding.reset(token)

else:
return self.target

Expand Down Expand Up @@ -956,11 +975,8 @@ def find(self, element: Any) -> List[bool]:
if isinstance(target_arg, Expr):
# set var_finding False to
# start new finding process for the nested expr: target_arg
token = var_finding.set(False)
try:
with temporary_set(var_finding, False):
rv = self._target.find(element)
finally:
var_finding.reset(token)

if not rv:
return []
Expand Down Expand Up @@ -994,11 +1010,8 @@ def _get_partial_expression(self) -> str:
def find(self, element: Any) -> List[bool]:
# set var_finding False to
# start new finding process for the nested expr: target
token = var_finding.set(False)
try:
with temporary_set(var_finding, False):
rv = self._expr.find(element)
finally:
var_finding.reset(token)

return [not v for v in rv]

Expand Down
54 changes: 54 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Self,
Slice,
Value,
var_parent,
)


Expand Down Expand Up @@ -426,3 +427,56 @@ def test_get_expression(expr, expect):
],
ids=reprlib.repr,
)(test_get_expression)


def test_get_parent_object():
root = {"a": 1}

class TestName1(Name):
def find(self, element):
with pytest.raises(LookupError):
var_parent.get()

assert element == root
return super().find(element)

assert TestName1("a").find(root) == [1]

root = {"a": {"b": 1}}

class TestName2(Name):
def find(self, element):
assert var_parent.get() == root
assert element == {"b": 1}
return super().find(element)

assert Name("a").chain(TestName2("b")).find(root) == [1]


def test_get_parent_array():
root = [{"a": 1}, {"a": 2}]

class TestName(Name):
def find(self, element):
assert var_parent.get() == root
assert element in root
return super().find(element)

assert Array().chain(TestName("a")).find(root) == [1, 2]


def test_get_parent_while_searching():
root = {"a": {"b": {"c": 1}}}

parents = []
history = []

class TestName(Name):
def find(self, element):
parents.append(var_parent.get())
history.append(element)
return super().find(element)

assert Root().Search(TestName("c")).find(root) == [1]
assert parents == [root, root, root["a"], root["a"]["b"]]
assert history == [root, root["a"], root["a"]["b"], 1]

0 comments on commit 6bcbc60

Please sign in to comment.