Skip to content

Commit

Permalink
Add enum type for visitor return values (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Jun 26, 2020
1 parent ffdf1b3 commit 57c6e2a
Show file tree
Hide file tree
Showing 38 changed files with 231 additions and 134 deletions.
6 changes: 3 additions & 3 deletions docs/modules/language.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,17 @@ The module also exports the following special symbols which can be used as
return values in the :class:`Visitor` methods to signal particular actions:

.. data:: BREAK
:annotation: = True
:annotation: (same as ``True``)

This return value signals that no further nodes shall be visited.

.. data:: SKIP
:annotation: = False
:annotation: (same as ``False``)

This return value signals that the current node shall be skipped.

.. data:: REMOVE
:annotation: = Ellipsis
:annotation: (same as``Ellipsis``)

This return value signals that the current node shall be deleted.

Expand Down
2 changes: 2 additions & 0 deletions src/graphql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
visit,
ParallelVisitor,
Visitor,
VisitorAction,
BREAK,
SKIP,
REMOVE,
Expand Down Expand Up @@ -532,6 +533,7 @@
"ParallelVisitor",
"TypeInfoVisitor",
"Visitor",
"VisitorAction",
"BREAK",
"SKIP",
"REMOVE",
Expand Down
2 changes: 2 additions & 0 deletions src/graphql/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
visit,
Visitor,
ParallelVisitor,
VisitorAction,
BREAK,
SKIP,
REMOVE,
Expand Down Expand Up @@ -115,6 +116,7 @@
"visit",
"Visitor",
"ParallelVisitor",
"VisitorAction",
"BREAK",
"SKIP",
"REMOVE",
Expand Down
40 changes: 31 additions & 9 deletions src/graphql/language/visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import copy
from enum import Enum
from typing import (
Any,
Callable,
Expand All @@ -19,6 +20,7 @@
__all__ = [
"Visitor",
"ParallelVisitor",
"VisitorAction",
"visit",
"BREAK",
"SKIP",
Expand All @@ -28,10 +30,26 @@
]


# Special return values for the visitor methods:
class VisitorActionEnum(Enum):
"""Special return values for the visitor methods.
You can also use the values of this enum directly.
"""

BREAK = True
SKIP = False
REMOVE = Ellipsis


VisitorAction = Optional[VisitorActionEnum]

# Note that in GraphQL.js these are defined differently:
# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined
BREAK, SKIP, REMOVE, IDLE = True, False, Ellipsis, None

BREAK = VisitorActionEnum.BREAK
SKIP = VisitorActionEnum.SKIP
REMOVE = VisitorActionEnum.REMOVE
IDLE = None

# Default map from visitor kinds to their traversable node attributes:
QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
Expand Down Expand Up @@ -253,7 +271,7 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any:
for edit_key, edit_value in edits:
if in_array:
edit_key -= edit_offset
if in_array and edit_value is REMOVE:
if in_array and (edit_value is REMOVE or edit_value is Ellipsis):
node.pop(edit_key)
edit_offset += 1
else:
Expand Down Expand Up @@ -292,10 +310,10 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any:
if visit_fn:
result = visit_fn(visitor, node, key, parent, path, ancestors)

if result is BREAK:
if result is BREAK or result is True:
break

if result is SKIP:
if result is SKIP or result is False:
if not is_leaving:
path_pop()
continue
Expand Down Expand Up @@ -356,9 +374,9 @@ def enter(self, node, *args):
fn = visitor.get_visit_fn(node.kind)
if fn:
result = fn(visitor, node, *args)
if result is SKIP:
if result is SKIP or result is False:
skipping[i] = node
elif result == BREAK:
elif result is BREAK or result is True:
skipping[i] = BREAK
elif result is not None:
return result
Expand All @@ -370,9 +388,13 @@ def leave(self, node, *args):
fn = visitor.get_visit_fn(node.kind, is_leaving=True)
if fn:
result = fn(visitor, node, *args)
if result == BREAK:
if result is BREAK or result is True:
skipping[i] = BREAK
elif result is not None and result is not SKIP:
elif (
result is not None
and result is not SKIP
and result is not False
):
return result
elif skipping[i] is node:
skipping[i] = None
6 changes: 4 additions & 2 deletions src/graphql/validation/rules/executable_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
SchemaDefinitionNode,
SchemaExtensionNode,
TypeDefinitionNode,
VisitorAction,
SKIP,
)
from . import ASTValidationRule

Expand All @@ -21,7 +23,7 @@ class ExecutableDefinitionsRule(ASTValidationRule):
operation or fragment definitions.
"""

def enter_document(self, node: DocumentNode, *_args):
def enter_document(self, node: DocumentNode, *_args) -> VisitorAction:
for definition in node.definitions:
if not isinstance(definition, ExecutableDefinitionNode):
def_name = (
Expand All @@ -41,4 +43,4 @@ def enter_document(self, node: DocumentNode, *_args):
f"The {def_name} definition is not executable.", definition,
)
)
return self.SKIP
return SKIP
2 changes: 1 addition & 1 deletion src/graphql/validation/rules/fields_on_correct_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FieldsOnCorrectTypeRule(ValidationRule):
type, or are an allowed meta field such as ``__typename``.
"""

def enter_field(self, node: FieldNode, *_args):
def enter_field(self, node: FieldNode, *_args) -> None:
type_ = self.context.get_parent_type()
if not type_:
return
Expand Down
10 changes: 7 additions & 3 deletions src/graphql/validation/rules/fragments_on_composite_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from ...error import GraphQLError
from ...language import FragmentDefinitionNode, InlineFragmentNode, print_ast
from ...language import (
FragmentDefinitionNode,
InlineFragmentNode,
print_ast,
)
from ...type import is_composite_type
from ...utilities import type_from_ast
from . import ValidationRule
Expand All @@ -15,7 +19,7 @@ class FragmentsOnCompositeTypesRule(ValidationRule):
must also be a composite type.
"""

def enter_inline_fragment(self, node: InlineFragmentNode, *_args):
def enter_inline_fragment(self, node: InlineFragmentNode, *_args) -> None:
type_condition = node.type_condition
if type_condition:
type_ = type_from_ast(self.context.schema, type_condition)
Expand All @@ -29,7 +33,7 @@ def enter_inline_fragment(self, node: InlineFragmentNode, *_args):
)
)

def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args):
def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args) -> None:
type_condition = node.type_condition
type_ = type_from_ast(self.context.schema, type_condition)
if type_ and not is_composite_type(type_):
Expand Down
12 changes: 9 additions & 3 deletions src/graphql/validation/rules/known_argument_names.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import cast, Dict, List, Union

from ...error import GraphQLError
from ...language import ArgumentNode, DirectiveDefinitionNode, DirectiveNode, SKIP
from ...language import (
ArgumentNode,
DirectiveDefinitionNode,
DirectiveNode,
SKIP,
VisitorAction,
)
from ...pyutils import did_you_mean, suggestion_list
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
Expand Down Expand Up @@ -37,7 +43,7 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]):

self.directive_args = directive_args

def enter_directive(self, directive_node: DirectiveNode, *_args):
def enter_directive(self, directive_node: DirectiveNode, *_args) -> VisitorAction:
directive_name = directive_node.name.value
known_args = self.directive_args.get(directive_name)
if directive_node.arguments and known_args is not None:
Expand Down Expand Up @@ -67,7 +73,7 @@ class KnownArgumentNamesRule(KnownArgumentNamesOnDirectivesRule):
def __init__(self, context: ValidationContext):
super().__init__(context)

def enter_argument(self, arg_node: ArgumentNode, *args):
def enter_argument(self, arg_node: ArgumentNode, *args) -> None:
context = self.context
arg_def = context.get_argument()
field_def = context.get_field_def()
Expand Down
4 changes: 3 additions & 1 deletion src/graphql/validation/rules/known_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
]
self.locations_map = locations_map

def enter_directive(self, node: DirectiveNode, _key, _parent, _path, ancestors):
def enter_directive(
self, node: DirectiveNode, _key, _parent, _path, ancestors
) -> None:
name = node.name.value
locations = self.locations_map.get(name)
if locations:
Expand Down
2 changes: 1 addition & 1 deletion src/graphql/validation/rules/known_fragment_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class KnownFragmentNamesRule(ValidationRule):
fragments defined in the same document.
"""

def enter_fragment_spread(self, node: FragmentSpreadNode, *_args):
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args) -> None:
fragment_name = node.name.value
fragment = self.context.get_fragment(fragment_name)
if not fragment:
Expand Down
2 changes: 1 addition & 1 deletion src/graphql/validation/rules/known_type_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]):

def enter_named_type(
self, node: NamedTypeNode, _key, parent: Node, _path, ancestors: List[Node]
):
) -> None:
type_name = node.name.value
if (
type_name not in self.existing_types_map
Expand Down
4 changes: 2 additions & 2 deletions src/graphql/validation/rules/lone_anonymous_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.operation_count = 0

def enter_document(self, node: DocumentNode, *_args):
def enter_document(self, node: DocumentNode, *_args) -> None:
self.operation_count = sum(
1
for definition in node.definitions
if isinstance(definition, OperationDefinitionNode)
)

def enter_operation_definition(self, node: OperationDefinitionNode, *_args):
def enter_operation_definition(self, node: OperationDefinitionNode, *_args) -> None:
if not node.name and self.operation_count > 1:
self.report_error(
GraphQLError(
Expand Down
2 changes: 1 addition & 1 deletion src/graphql/validation/rules/lone_schema_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, context: SDLValidationContext):
)
self.schema_definitions_count = 0

def enter_schema_definition(self, node: SchemaDefinitionNode, *_args):
def enter_schema_definition(self, node: SchemaDefinitionNode, *_args) -> None:
if self.already_defined:
self.report_error(
GraphQLError(
Expand Down
12 changes: 7 additions & 5 deletions src/graphql/validation/rules/no_fragment_cycles.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, List, Set

from ...error import GraphQLError
from ...language import FragmentDefinitionNode, FragmentSpreadNode
from ...language import FragmentDefinitionNode, FragmentSpreadNode, VisitorAction, SKIP
from . import ASTValidationContext, ASTValidationRule

__all__ = ["NoFragmentCyclesRule"]
Expand All @@ -20,12 +20,14 @@ def __init__(self, context: ASTValidationContext):
# Position in the spread path
self.spread_path_index_by_name: Dict[str, int] = {}

def enter_operation_definition(self, *_args):
return self.SKIP
def enter_operation_definition(self, *_args) -> VisitorAction:
return SKIP

def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args):
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args
) -> VisitorAction:
self.detect_cycle_recursive(node)
return self.SKIP
return SKIP

def detect_cycle_recursive(self, fragment: FragmentDefinitionNode):
# This does a straight-forward DFS to find cycles.
Expand Down
8 changes: 5 additions & 3 deletions src/graphql/validation/rules/no_undefined_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ def __init__(self, context: ValidationContext):
super().__init__(context)
self.defined_variable_names: Set[str] = set()

def enter_operation_definition(self, *_args):
def enter_operation_definition(self, *_args) -> None:
self.defined_variable_names.clear()

def leave_operation_definition(self, operation: OperationDefinitionNode, *_args):
def leave_operation_definition(
self, operation: OperationDefinitionNode, *_args
) -> None:
usages = self.context.get_recursive_variable_usages(operation)
defined_variables = self.defined_variable_names
for usage in usages:
Expand All @@ -38,5 +40,5 @@ def leave_operation_definition(self, operation: OperationDefinitionNode, *_args)
)
)

def enter_variable_definition(self, node: VariableDefinitionNode, *_args):
def enter_variable_definition(self, node: VariableDefinitionNode, *_args) -> None:
self.defined_variable_names.add(node.variable.name.value)
21 changes: 15 additions & 6 deletions src/graphql/validation/rules/no_unused_fragments.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List

from ...error import GraphQLError
from ...language import FragmentDefinitionNode, OperationDefinitionNode
from ...language import (
FragmentDefinitionNode,
OperationDefinitionNode,
VisitorAction,
SKIP,
)
from . import ASTValidationContext, ASTValidationRule

__all__ = ["NoUnusedFragmentsRule"]
Expand All @@ -19,15 +24,19 @@ def __init__(self, context: ASTValidationContext):
self.operation_defs: List[OperationDefinitionNode] = []
self.fragment_defs: List[FragmentDefinitionNode] = []

def enter_operation_definition(self, node: OperationDefinitionNode, *_args):
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args
) -> VisitorAction:
self.operation_defs.append(node)
return False
return SKIP

def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args):
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args
) -> VisitorAction:
self.fragment_defs.append(node)
return False
return SKIP

def leave_document(self, *_args):
def leave_document(self, *_args) -> None:
fragment_names_used = set()
get_fragments = self.context.get_recursively_referenced_fragments
for operation in self.operation_defs:
Expand Down
Loading

0 comments on commit 57c6e2a

Please sign in to comment.