From 57c6e2a140c69f40d568d4325d3327beb948bf7c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 27 Jun 2020 00:32:18 +0200 Subject: [PATCH] Add enum type for visitor return values (#96) --- docs/modules/language.rst | 6 +- src/graphql/__init__.py | 2 + src/graphql/language/__init__.py | 2 + src/graphql/language/visitor.py | 40 +++++++--- .../rules/executable_definitions.py | 6 +- .../rules/fields_on_correct_type.py | 2 +- .../rules/fragments_on_composite_types.py | 10 ++- .../validation/rules/known_argument_names.py | 12 ++- .../validation/rules/known_directives.py | 4 +- .../validation/rules/known_fragment_names.py | 2 +- .../validation/rules/known_type_names.py | 2 +- .../rules/lone_anonymous_operation.py | 4 +- .../rules/lone_schema_definition.py | 2 +- .../validation/rules/no_fragment_cycles.py | 12 +-- .../rules/no_undefined_variables.py | 8 +- .../validation/rules/no_unused_fragments.py | 21 ++++-- .../validation/rules/no_unused_variables.py | 10 ++- .../rules/overlapping_fields_can_be_merged.py | 2 +- .../rules/possible_fragment_spreads.py | 4 +- .../rules/possible_type_extensions.py | 2 +- .../rules/provided_required_arguments.py | 10 ++- src/graphql/validation/rules/scalar_leafs.py | 2 +- .../rules/single_field_subscriptions.py | 2 +- .../validation/rules/unique_argument_names.py | 10 +-- .../rules/unique_directive_names.py | 10 ++- .../rules/unique_directives_per_location.py | 2 +- .../rules/unique_enum_value_names.py | 8 +- .../rules/unique_field_definition_names.py | 8 +- .../validation/rules/unique_fragment_names.py | 12 +-- .../rules/unique_input_field_names.py | 6 +- .../rules/unique_operation_names.py | 12 +-- .../rules/unique_operation_types.py | 6 +- .../validation/rules/unique_type_names.py | 8 +- .../validation/rules/unique_variable_names.py | 4 +- .../rules/values_of_correct_type.py | 28 ++++--- .../rules/variables_are_input_types.py | 2 +- .../rules/variables_in_allowed_position.py | 8 +- tests/language/test_visitor.py | 74 +++++++++++-------- 38 files changed, 231 insertions(+), 134 deletions(-) diff --git a/docs/modules/language.rst b/docs/modules/language.rst index c0c376e0..ef649d61 100644 --- a/docs/modules/language.rst +++ b/docs/modules/language.rst @@ -103,17 +103,17 @@ The module also exports the following special symbols which can be used as return values in the :class:`Visitor` methods to signal particular actions: .. data:: BREAK - :annotation: = True + :annotation: (same as ``True``) This return value signals that no further nodes shall be visited. .. data:: SKIP - :annotation: = False + :annotation: (same as ``False``) This return value signals that the current node shall be skipped. .. data:: REMOVE - :annotation: = Ellipsis + :annotation: (same as``Ellipsis``) This return value signals that the current node shall be deleted. diff --git a/src/graphql/__init__.py b/src/graphql/__init__.py index dd8443b5..30b1e1db 100644 --- a/src/graphql/__init__.py +++ b/src/graphql/__init__.py @@ -181,6 +181,7 @@ visit, ParallelVisitor, Visitor, + VisitorAction, BREAK, SKIP, REMOVE, @@ -532,6 +533,7 @@ "ParallelVisitor", "TypeInfoVisitor", "Visitor", + "VisitorAction", "BREAK", "SKIP", "REMOVE", diff --git a/src/graphql/language/__init__.py b/src/graphql/language/__init__.py index a4a38b69..9e5904a4 100644 --- a/src/graphql/language/__init__.py +++ b/src/graphql/language/__init__.py @@ -22,6 +22,7 @@ visit, Visitor, ParallelVisitor, + VisitorAction, BREAK, SKIP, REMOVE, @@ -115,6 +116,7 @@ "visit", "Visitor", "ParallelVisitor", + "VisitorAction", "BREAK", "SKIP", "REMOVE", diff --git a/src/graphql/language/visitor.py b/src/graphql/language/visitor.py index 4d9d4991..d60aea9f 100644 --- a/src/graphql/language/visitor.py +++ b/src/graphql/language/visitor.py @@ -1,4 +1,5 @@ from copy import copy +from enum import Enum from typing import ( Any, Callable, @@ -19,6 +20,7 @@ __all__ = [ "Visitor", "ParallelVisitor", + "VisitorAction", "visit", "BREAK", "SKIP", @@ -28,10 +30,26 @@ ] -# Special return values for the visitor methods: +class VisitorActionEnum(Enum): + """Special return values for the visitor methods. + + You can also use the values of this enum directly. + """ + + BREAK = True + SKIP = False + REMOVE = Ellipsis + + +VisitorAction = Optional[VisitorActionEnum] + # Note that in GraphQL.js these are defined differently: # BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined -BREAK, SKIP, REMOVE, IDLE = True, False, Ellipsis, None + +BREAK = VisitorActionEnum.BREAK +SKIP = VisitorActionEnum.SKIP +REMOVE = VisitorActionEnum.REMOVE +IDLE = None # Default map from visitor kinds to their traversable node attributes: QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = { @@ -253,7 +271,7 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any: for edit_key, edit_value in edits: if in_array: edit_key -= edit_offset - if in_array and edit_value is REMOVE: + if in_array and (edit_value is REMOVE or edit_value is Ellipsis): node.pop(edit_key) edit_offset += 1 else: @@ -292,10 +310,10 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any: if visit_fn: result = visit_fn(visitor, node, key, parent, path, ancestors) - if result is BREAK: + if result is BREAK or result is True: break - if result is SKIP: + if result is SKIP or result is False: if not is_leaving: path_pop() continue @@ -356,9 +374,9 @@ def enter(self, node, *args): fn = visitor.get_visit_fn(node.kind) if fn: result = fn(visitor, node, *args) - if result is SKIP: + if result is SKIP or result is False: skipping[i] = node - elif result == BREAK: + elif result is BREAK or result is True: skipping[i] = BREAK elif result is not None: return result @@ -370,9 +388,13 @@ def leave(self, node, *args): fn = visitor.get_visit_fn(node.kind, is_leaving=True) if fn: result = fn(visitor, node, *args) - if result == BREAK: + if result is BREAK or result is True: skipping[i] = BREAK - elif result is not None and result is not SKIP: + elif ( + result is not None + and result is not SKIP + and result is not False + ): return result elif skipping[i] is node: skipping[i] = None diff --git a/src/graphql/validation/rules/executable_definitions.py b/src/graphql/validation/rules/executable_definitions.py index 4cdc84d2..a8386266 100644 --- a/src/graphql/validation/rules/executable_definitions.py +++ b/src/graphql/validation/rules/executable_definitions.py @@ -8,6 +8,8 @@ SchemaDefinitionNode, SchemaExtensionNode, TypeDefinitionNode, + VisitorAction, + SKIP, ) from . import ASTValidationRule @@ -21,7 +23,7 @@ class ExecutableDefinitionsRule(ASTValidationRule): operation or fragment definitions. """ - def enter_document(self, node: DocumentNode, *_args): + def enter_document(self, node: DocumentNode, *_args) -> VisitorAction: for definition in node.definitions: if not isinstance(definition, ExecutableDefinitionNode): def_name = ( @@ -41,4 +43,4 @@ def enter_document(self, node: DocumentNode, *_args): f"The {def_name} definition is not executable.", definition, ) ) - return self.SKIP + return SKIP diff --git a/src/graphql/validation/rules/fields_on_correct_type.py b/src/graphql/validation/rules/fields_on_correct_type.py index c784eae6..e15ac08b 100644 --- a/src/graphql/validation/rules/fields_on_correct_type.py +++ b/src/graphql/validation/rules/fields_on_correct_type.py @@ -27,7 +27,7 @@ class FieldsOnCorrectTypeRule(ValidationRule): type, or are an allowed meta field such as ``__typename``. """ - def enter_field(self, node: FieldNode, *_args): + def enter_field(self, node: FieldNode, *_args) -> None: type_ = self.context.get_parent_type() if not type_: return diff --git a/src/graphql/validation/rules/fragments_on_composite_types.py b/src/graphql/validation/rules/fragments_on_composite_types.py index d3a77594..e0a12249 100644 --- a/src/graphql/validation/rules/fragments_on_composite_types.py +++ b/src/graphql/validation/rules/fragments_on_composite_types.py @@ -1,5 +1,9 @@ from ...error import GraphQLError -from ...language import FragmentDefinitionNode, InlineFragmentNode, print_ast +from ...language import ( + FragmentDefinitionNode, + InlineFragmentNode, + print_ast, +) from ...type import is_composite_type from ...utilities import type_from_ast from . import ValidationRule @@ -15,7 +19,7 @@ class FragmentsOnCompositeTypesRule(ValidationRule): must also be a composite type. """ - def enter_inline_fragment(self, node: InlineFragmentNode, *_args): + def enter_inline_fragment(self, node: InlineFragmentNode, *_args) -> None: type_condition = node.type_condition if type_condition: type_ = type_from_ast(self.context.schema, type_condition) @@ -29,7 +33,7 @@ def enter_inline_fragment(self, node: InlineFragmentNode, *_args): ) ) - def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args): + def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args) -> None: type_condition = node.type_condition type_ = type_from_ast(self.context.schema, type_condition) if type_ and not is_composite_type(type_): diff --git a/src/graphql/validation/rules/known_argument_names.py b/src/graphql/validation/rules/known_argument_names.py index e094b10a..706aa38b 100644 --- a/src/graphql/validation/rules/known_argument_names.py +++ b/src/graphql/validation/rules/known_argument_names.py @@ -1,7 +1,13 @@ from typing import cast, Dict, List, Union from ...error import GraphQLError -from ...language import ArgumentNode, DirectiveDefinitionNode, DirectiveNode, SKIP +from ...language import ( + ArgumentNode, + DirectiveDefinitionNode, + DirectiveNode, + SKIP, + VisitorAction, +) from ...pyutils import did_you_mean, suggestion_list from ...type import specified_directives from . import ASTValidationRule, SDLValidationContext, ValidationContext @@ -37,7 +43,7 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]): self.directive_args = directive_args - def enter_directive(self, directive_node: DirectiveNode, *_args): + def enter_directive(self, directive_node: DirectiveNode, *_args) -> VisitorAction: directive_name = directive_node.name.value known_args = self.directive_args.get(directive_name) if directive_node.arguments and known_args is not None: @@ -67,7 +73,7 @@ class KnownArgumentNamesRule(KnownArgumentNamesOnDirectivesRule): def __init__(self, context: ValidationContext): super().__init__(context) - def enter_argument(self, arg_node: ArgumentNode, *args): + def enter_argument(self, arg_node: ArgumentNode, *args) -> None: context = self.context arg_def = context.get_argument() field_def = context.get_field_def() diff --git a/src/graphql/validation/rules/known_directives.py b/src/graphql/validation/rules/known_directives.py index 3998847d..e6824f2f 100644 --- a/src/graphql/validation/rules/known_directives.py +++ b/src/graphql/validation/rules/known_directives.py @@ -41,7 +41,9 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]): ] self.locations_map = locations_map - def enter_directive(self, node: DirectiveNode, _key, _parent, _path, ancestors): + def enter_directive( + self, node: DirectiveNode, _key, _parent, _path, ancestors + ) -> None: name = node.name.value locations = self.locations_map.get(name) if locations: diff --git a/src/graphql/validation/rules/known_fragment_names.py b/src/graphql/validation/rules/known_fragment_names.py index 49929106..66c030fd 100644 --- a/src/graphql/validation/rules/known_fragment_names.py +++ b/src/graphql/validation/rules/known_fragment_names.py @@ -12,7 +12,7 @@ class KnownFragmentNamesRule(ValidationRule): fragments defined in the same document. """ - def enter_fragment_spread(self, node: FragmentSpreadNode, *_args): + def enter_fragment_spread(self, node: FragmentSpreadNode, *_args) -> None: fragment_name = node.name.value fragment = self.context.get_fragment(fragment_name) if not fragment: diff --git a/src/graphql/validation/rules/known_type_names.py b/src/graphql/validation/rules/known_type_names.py index ae93712e..bcf3165e 100644 --- a/src/graphql/validation/rules/known_type_names.py +++ b/src/graphql/validation/rules/known_type_names.py @@ -39,7 +39,7 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]): def enter_named_type( self, node: NamedTypeNode, _key, parent: Node, _path, ancestors: List[Node] - ): + ) -> None: type_name = node.name.value if ( type_name not in self.existing_types_map diff --git a/src/graphql/validation/rules/lone_anonymous_operation.py b/src/graphql/validation/rules/lone_anonymous_operation.py index 426ccac2..15eb830c 100644 --- a/src/graphql/validation/rules/lone_anonymous_operation.py +++ b/src/graphql/validation/rules/lone_anonymous_operation.py @@ -16,14 +16,14 @@ def __init__(self, context: ASTValidationContext): super().__init__(context) self.operation_count = 0 - def enter_document(self, node: DocumentNode, *_args): + def enter_document(self, node: DocumentNode, *_args) -> None: self.operation_count = sum( 1 for definition in node.definitions if isinstance(definition, OperationDefinitionNode) ) - def enter_operation_definition(self, node: OperationDefinitionNode, *_args): + def enter_operation_definition(self, node: OperationDefinitionNode, *_args) -> None: if not node.name and self.operation_count > 1: self.report_error( GraphQLError( diff --git a/src/graphql/validation/rules/lone_schema_definition.py b/src/graphql/validation/rules/lone_schema_definition.py index fd242914..e070983b 100644 --- a/src/graphql/validation/rules/lone_schema_definition.py +++ b/src/graphql/validation/rules/lone_schema_definition.py @@ -22,7 +22,7 @@ def __init__(self, context: SDLValidationContext): ) self.schema_definitions_count = 0 - def enter_schema_definition(self, node: SchemaDefinitionNode, *_args): + def enter_schema_definition(self, node: SchemaDefinitionNode, *_args) -> None: if self.already_defined: self.report_error( GraphQLError( diff --git a/src/graphql/validation/rules/no_fragment_cycles.py b/src/graphql/validation/rules/no_fragment_cycles.py index 5b4d63bf..fd4c5418 100644 --- a/src/graphql/validation/rules/no_fragment_cycles.py +++ b/src/graphql/validation/rules/no_fragment_cycles.py @@ -1,7 +1,7 @@ from typing import Dict, List, Set from ...error import GraphQLError -from ...language import FragmentDefinitionNode, FragmentSpreadNode +from ...language import FragmentDefinitionNode, FragmentSpreadNode, VisitorAction, SKIP from . import ASTValidationContext, ASTValidationRule __all__ = ["NoFragmentCyclesRule"] @@ -20,12 +20,14 @@ def __init__(self, context: ASTValidationContext): # Position in the spread path self.spread_path_index_by_name: Dict[str, int] = {} - def enter_operation_definition(self, *_args): - return self.SKIP + def enter_operation_definition(self, *_args) -> VisitorAction: + return SKIP - def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args): + def enter_fragment_definition( + self, node: FragmentDefinitionNode, *_args + ) -> VisitorAction: self.detect_cycle_recursive(node) - return self.SKIP + return SKIP def detect_cycle_recursive(self, fragment: FragmentDefinitionNode): # This does a straight-forward DFS to find cycles. diff --git a/src/graphql/validation/rules/no_undefined_variables.py b/src/graphql/validation/rules/no_undefined_variables.py index 9a5bc852..33959fc6 100644 --- a/src/graphql/validation/rules/no_undefined_variables.py +++ b/src/graphql/validation/rules/no_undefined_variables.py @@ -18,10 +18,12 @@ def __init__(self, context: ValidationContext): super().__init__(context) self.defined_variable_names: Set[str] = set() - def enter_operation_definition(self, *_args): + def enter_operation_definition(self, *_args) -> None: self.defined_variable_names.clear() - def leave_operation_definition(self, operation: OperationDefinitionNode, *_args): + def leave_operation_definition( + self, operation: OperationDefinitionNode, *_args + ) -> None: usages = self.context.get_recursive_variable_usages(operation) defined_variables = self.defined_variable_names for usage in usages: @@ -38,5 +40,5 @@ def leave_operation_definition(self, operation: OperationDefinitionNode, *_args) ) ) - def enter_variable_definition(self, node: VariableDefinitionNode, *_args): + def enter_variable_definition(self, node: VariableDefinitionNode, *_args) -> None: self.defined_variable_names.add(node.variable.name.value) diff --git a/src/graphql/validation/rules/no_unused_fragments.py b/src/graphql/validation/rules/no_unused_fragments.py index 710e1399..42c5180a 100644 --- a/src/graphql/validation/rules/no_unused_fragments.py +++ b/src/graphql/validation/rules/no_unused_fragments.py @@ -1,7 +1,12 @@ from typing import List from ...error import GraphQLError -from ...language import FragmentDefinitionNode, OperationDefinitionNode +from ...language import ( + FragmentDefinitionNode, + OperationDefinitionNode, + VisitorAction, + SKIP, +) from . import ASTValidationContext, ASTValidationRule __all__ = ["NoUnusedFragmentsRule"] @@ -19,15 +24,19 @@ def __init__(self, context: ASTValidationContext): self.operation_defs: List[OperationDefinitionNode] = [] self.fragment_defs: List[FragmentDefinitionNode] = [] - def enter_operation_definition(self, node: OperationDefinitionNode, *_args): + def enter_operation_definition( + self, node: OperationDefinitionNode, *_args + ) -> VisitorAction: self.operation_defs.append(node) - return False + return SKIP - def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args): + def enter_fragment_definition( + self, node: FragmentDefinitionNode, *_args + ) -> VisitorAction: self.fragment_defs.append(node) - return False + return SKIP - def leave_document(self, *_args): + def leave_document(self, *_args) -> None: fragment_names_used = set() get_fragments = self.context.get_recursively_referenced_fragments for operation in self.operation_defs: diff --git a/src/graphql/validation/rules/no_unused_variables.py b/src/graphql/validation/rules/no_unused_variables.py index eb655372..4b69d5df 100644 --- a/src/graphql/validation/rules/no_unused_variables.py +++ b/src/graphql/validation/rules/no_unused_variables.py @@ -18,10 +18,12 @@ def __init__(self, context: ValidationContext): super().__init__(context) self.variable_defs: List[VariableDefinitionNode] = [] - def enter_operation_definition(self, *_args): + def enter_operation_definition(self, *_args) -> None: self.variable_defs.clear() - def leave_operation_definition(self, operation: OperationDefinitionNode, *_args): + def leave_operation_definition( + self, operation: OperationDefinitionNode, *_args + ) -> None: variable_name_used: Set[str] = set() usages = self.context.get_recursive_variable_usages(operation) @@ -41,5 +43,7 @@ def leave_operation_definition(self, operation: OperationDefinitionNode, *_args) ) ) - def enter_variable_definition(self, definition: VariableDefinitionNode, *_args): + def enter_variable_definition( + self, definition: VariableDefinitionNode, *_args + ) -> None: self.variable_defs.append(definition) diff --git a/src/graphql/validation/rules/overlapping_fields_can_be_merged.py b/src/graphql/validation/rules/overlapping_fields_can_be_merged.py index b5a5da54..ee1d63e3 100644 --- a/src/graphql/validation/rules/overlapping_fields_can_be_merged.py +++ b/src/graphql/validation/rules/overlapping_fields_can_be_merged.py @@ -63,7 +63,7 @@ def __init__(self, context: ValidationContext): # times, so this improves the performance of this validator. self.cached_fields_and_fragment_names: Dict = {} - def enter_selection_set(self, selection_set: SelectionSetNode, *_args): + def enter_selection_set(self, selection_set: SelectionSetNode, *_args) -> None: conflicts = find_conflicts_within_selection_set( self.context, self.cached_fields_and_fragment_names, diff --git a/src/graphql/validation/rules/possible_fragment_spreads.py b/src/graphql/validation/rules/possible_fragment_spreads.py index ffd09631..b043904b 100644 --- a/src/graphql/validation/rules/possible_fragment_spreads.py +++ b/src/graphql/validation/rules/possible_fragment_spreads.py @@ -15,7 +15,7 @@ class PossibleFragmentSpreadsRule(ValidationRule): types which pass the type condition. """ - def enter_inline_fragment(self, node: InlineFragmentNode, *_args): + def enter_inline_fragment(self, node: InlineFragmentNode, *_args) -> None: context = self.context frag_type = context.get_type() parent_type = context.get_parent_type() @@ -32,7 +32,7 @@ def enter_inline_fragment(self, node: InlineFragmentNode, *_args): ) ) - def enter_fragment_spread(self, node: FragmentSpreadNode, *_args): + def enter_fragment_spread(self, node: FragmentSpreadNode, *_args) -> None: context = self.context frag_name = node.name.value frag_type = self.get_fragment_type(frag_name) diff --git a/src/graphql/validation/rules/possible_type_extensions.py b/src/graphql/validation/rules/possible_type_extensions.py index cf412410..adbf79d8 100644 --- a/src/graphql/validation/rules/possible_type_extensions.py +++ b/src/graphql/validation/rules/possible_type_extensions.py @@ -33,7 +33,7 @@ def __init__(self, context: SDLValidationContext): if isinstance(def_, TypeDefinitionNode) } - def check_extension(self, node: TypeExtensionNode, *_args): + def check_extension(self, node: TypeExtensionNode, *_args) -> None: schema = self.schema type_name = node.name.value def_node = self.defined_types.get(type_name) diff --git a/src/graphql/validation/rules/provided_required_arguments.py b/src/graphql/validation/rules/provided_required_arguments.py index 7e7e38b0..8bc721d3 100644 --- a/src/graphql/validation/rules/provided_required_arguments.py +++ b/src/graphql/validation/rules/provided_required_arguments.py @@ -8,6 +8,8 @@ InputValueDefinitionNode, NonNullTypeNode, TypeNode, + VisitorAction, + SKIP, print_ast, ) from ...pyutils import FrozenList @@ -53,7 +55,7 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]): self.required_args_map = required_args_map - def leave_directive(self, directive_node: DirectiveNode, *_args): + def leave_directive(self, directive_node: DirectiveNode, *_args) -> None: # Validate on leave to allow for deeper errors to appear first. directive_name = directive_node.name.value required_args = self.required_args_map.get(directive_name) @@ -91,11 +93,11 @@ class ProvidedRequiredArgumentsRule(ProvidedRequiredArgumentsOnDirectivesRule): def __init__(self, context: ValidationContext): super().__init__(context) - def leave_field(self, field_node: FieldNode, *_args): + def leave_field(self, field_node: FieldNode, *_args) -> VisitorAction: # Validate on leave to allow for deeper errors to appear first. field_def = self.context.get_field_def() if not field_def: - return self.SKIP + return SKIP arg_nodes = field_node.arguments or FrozenList() arg_node_map = {arg.name.value: arg for arg in arg_nodes} @@ -111,6 +113,8 @@ def leave_field(self, field_node: FieldNode, *_args): ) ) + return None + def is_required_argument_node(arg: InputValueDefinitionNode) -> bool: return isinstance(arg.type, NonNullTypeNode) and arg.default_value is None diff --git a/src/graphql/validation/rules/scalar_leafs.py b/src/graphql/validation/rules/scalar_leafs.py index 66a3cf14..7364a4f8 100644 --- a/src/graphql/validation/rules/scalar_leafs.py +++ b/src/graphql/validation/rules/scalar_leafs.py @@ -13,7 +13,7 @@ class ScalarLeafsRule(ValidationRule): are of scalar or enum types. """ - def enter_field(self, node: FieldNode, *_args): + def enter_field(self, node: FieldNode, *_args) -> None: type_ = self.context.get_type() if type_: selection_set = node.selection_set diff --git a/src/graphql/validation/rules/single_field_subscriptions.py b/src/graphql/validation/rules/single_field_subscriptions.py index 6b948d4d..0a468b43 100644 --- a/src/graphql/validation/rules/single_field_subscriptions.py +++ b/src/graphql/validation/rules/single_field_subscriptions.py @@ -11,7 +11,7 @@ class SingleFieldSubscriptionsRule(ASTValidationRule): A GraphQL subscription is valid only if it contains a single root. """ - def enter_operation_definition(self, node: OperationDefinitionNode, *_args): + def enter_operation_definition(self, node: OperationDefinitionNode, *_args) -> None: if node.operation == OperationType.SUBSCRIPTION: if len(node.selection_set.selections) != 1: self.report_error( diff --git a/src/graphql/validation/rules/unique_argument_names.py b/src/graphql/validation/rules/unique_argument_names.py index 77f406e1..a1e12912 100644 --- a/src/graphql/validation/rules/unique_argument_names.py +++ b/src/graphql/validation/rules/unique_argument_names.py @@ -1,7 +1,7 @@ from typing import Dict from ...error import GraphQLError -from ...language import NameNode, ArgumentNode +from ...language import ArgumentNode, NameNode, VisitorAction, SKIP from . import ASTValidationContext, ASTValidationRule __all__ = ["UniqueArgumentNamesRule"] @@ -18,13 +18,13 @@ def __init__(self, context: ASTValidationContext): super().__init__(context) self.known_arg_names: Dict[str, NameNode] = {} - def enter_field(self, *_args): + def enter_field(self, *_args) -> None: self.known_arg_names.clear() - def enter_directive(self, *_args): + def enter_directive(self, *_args) -> None: self.known_arg_names.clear() - def enter_argument(self, node: ArgumentNode, *_args): + def enter_argument(self, node: ArgumentNode, *_args) -> VisitorAction: known_arg_names = self.known_arg_names arg_name = node.name.value if arg_name in known_arg_names: @@ -36,4 +36,4 @@ def enter_argument(self, node: ArgumentNode, *_args): ) else: known_arg_names[arg_name] = node.name - return self.SKIP + return SKIP diff --git a/src/graphql/validation/rules/unique_directive_names.py b/src/graphql/validation/rules/unique_directive_names.py index 656c9a0f..66dcb27e 100644 --- a/src/graphql/validation/rules/unique_directive_names.py +++ b/src/graphql/validation/rules/unique_directive_names.py @@ -1,7 +1,7 @@ from typing import Dict from ...error import GraphQLError -from ...language import NameNode, DirectiveDefinitionNode +from ...language import DirectiveDefinitionNode, NameNode, VisitorAction, SKIP from . import SDLValidationContext, SDLValidationRule __all__ = ["UniqueDirectiveNamesRule"] @@ -18,7 +18,9 @@ def __init__(self, context: SDLValidationContext): self.known_directive_names: Dict[str, NameNode] = {} self.schema = context.schema - def enter_directive_definition(self, node: DirectiveDefinitionNode, *_args): + def enter_directive_definition( + self, node: DirectiveDefinitionNode, *_args + ) -> VisitorAction: directive_name = node.name.value if self.schema and self.schema.get_directive(directive_name): @@ -39,4 +41,6 @@ def enter_directive_definition(self, node: DirectiveDefinitionNode, *_args): ) else: self.known_directive_names[directive_name] = node.name - return self.SKIP + return SKIP + + return None diff --git a/src/graphql/validation/rules/unique_directives_per_location.py b/src/graphql/validation/rules/unique_directives_per_location.py index 2b84a69b..83ce483b 100644 --- a/src/graphql/validation/rules/unique_directives_per_location.py +++ b/src/graphql/validation/rules/unique_directives_per_location.py @@ -52,7 +52,7 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]): # Many different AST nodes may contain directives. Rather than listing them all, # just listen for entering any node, and check to see if it defines any directives. - def enter(self, node: Node, *_args): + def enter(self, node: Node, *_args) -> None: directives: List[DirectiveNode] = getattr(node, "directives", None) if not directives: return diff --git a/src/graphql/validation/rules/unique_enum_value_names.py b/src/graphql/validation/rules/unique_enum_value_names.py index 2c12bfc9..90071301 100644 --- a/src/graphql/validation/rules/unique_enum_value_names.py +++ b/src/graphql/validation/rules/unique_enum_value_names.py @@ -2,7 +2,7 @@ from typing import cast, Dict from ...error import GraphQLError -from ...language import NameNode, EnumTypeDefinitionNode +from ...language import NameNode, EnumTypeDefinitionNode, VisitorAction, SKIP from ...type import is_enum_type, GraphQLEnumType from . import SDLValidationContext, SDLValidationRule @@ -21,7 +21,9 @@ def __init__(self, context: SDLValidationContext): self.existing_type_map = schema.type_map if schema else {} self.known_value_names: Dict[str, Dict[str, NameNode]] = defaultdict(dict) - def check_value_uniqueness(self, node: EnumTypeDefinitionNode, *_args): + def check_value_uniqueness( + self, node: EnumTypeDefinitionNode, *_args + ) -> VisitorAction: existing_type_map = self.existing_type_map type_name = node.name.value value_names = self.known_value_names[type_name] @@ -53,7 +55,7 @@ def check_value_uniqueness(self, node: EnumTypeDefinitionNode, *_args): else: value_names[value_name] = value_def.name - return self.SKIP + return SKIP enter_enum_type_definition = check_value_uniqueness enter_enum_type_extension = check_value_uniqueness diff --git a/src/graphql/validation/rules/unique_field_definition_names.py b/src/graphql/validation/rules/unique_field_definition_names.py index 1c212e8a..374a9889 100644 --- a/src/graphql/validation/rules/unique_field_definition_names.py +++ b/src/graphql/validation/rules/unique_field_definition_names.py @@ -2,7 +2,7 @@ from typing import Any, Dict from ...error import GraphQLError -from ...language import NameNode, ObjectTypeDefinitionNode +from ...language import NameNode, ObjectTypeDefinitionNode, VisitorAction, SKIP from ...type import is_object_type, is_interface_type, is_input_object_type from . import SDLValidationContext, SDLValidationRule @@ -21,7 +21,9 @@ def __init__(self, context: SDLValidationContext): self.existing_type_map = schema.type_map if schema else {} self.known_field_names: Dict[str, Dict[str, NameNode]] = defaultdict(dict) - def check_field_uniqueness(self, node: ObjectTypeDefinitionNode, *_args): + def check_field_uniqueness( + self, node: ObjectTypeDefinitionNode, *_args + ) -> VisitorAction: existing_type_map = self.existing_type_map type_name = node.name.value field_names = self.known_field_names[type_name] @@ -49,7 +51,7 @@ def check_field_uniqueness(self, node: ObjectTypeDefinitionNode, *_args): else: field_names[field_name] = field_def.name - return self.SKIP + return SKIP enter_input_object_type_definition = check_field_uniqueness enter_input_object_type_extension = check_field_uniqueness diff --git a/src/graphql/validation/rules/unique_fragment_names.py b/src/graphql/validation/rules/unique_fragment_names.py index 744b231b..0aaf40df 100644 --- a/src/graphql/validation/rules/unique_fragment_names.py +++ b/src/graphql/validation/rules/unique_fragment_names.py @@ -1,7 +1,7 @@ from typing import Dict from ...error import GraphQLError -from ...language import NameNode, FragmentDefinitionNode +from ...language import NameNode, FragmentDefinitionNode, VisitorAction, SKIP from . import ASTValidationContext, ASTValidationRule __all__ = ["UniqueFragmentNamesRule"] @@ -17,10 +17,12 @@ def __init__(self, context: ASTValidationContext): super().__init__(context) self.known_fragment_names: Dict[str, NameNode] = {} - def enter_operation_definition(self, *_args): - return self.SKIP + def enter_operation_definition(self, *_args) -> VisitorAction: + return SKIP - def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args): + def enter_fragment_definition( + self, node: FragmentDefinitionNode, *_args + ) -> VisitorAction: known_fragment_names = self.known_fragment_names fragment_name = node.name.value if fragment_name in known_fragment_names: @@ -32,4 +34,4 @@ def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args): ) else: known_fragment_names[fragment_name] = node.name - return self.SKIP + return SKIP diff --git a/src/graphql/validation/rules/unique_input_field_names.py b/src/graphql/validation/rules/unique_input_field_names.py index 1241e663..2f297680 100644 --- a/src/graphql/validation/rules/unique_input_field_names.py +++ b/src/graphql/validation/rules/unique_input_field_names.py @@ -19,14 +19,14 @@ def __init__(self, context: ASTValidationContext): self.known_names_stack: List[Dict[str, NameNode]] = [] self.known_names: Dict[str, NameNode] = {} - def enter_object_value(self, *_args): + def enter_object_value(self, *_args) -> None: self.known_names_stack.append(self.known_names) self.known_names = {} - def leave_object_value(self, *_args): + def leave_object_value(self, *_args) -> None: self.known_names = self.known_names_stack.pop() - def enter_object_field(self, node: ObjectFieldNode, *_args): + def enter_object_field(self, node: ObjectFieldNode, *_args) -> None: known_names = self.known_names field_name = node.name.value if field_name in known_names: diff --git a/src/graphql/validation/rules/unique_operation_names.py b/src/graphql/validation/rules/unique_operation_names.py index defcce05..e9e13444 100644 --- a/src/graphql/validation/rules/unique_operation_names.py +++ b/src/graphql/validation/rules/unique_operation_names.py @@ -1,7 +1,7 @@ from typing import Dict from ...error import GraphQLError -from ...language import NameNode, OperationDefinitionNode +from ...language import NameNode, OperationDefinitionNode, VisitorAction, SKIP from . import ASTValidationContext, ASTValidationRule __all__ = ["UniqueOperationNamesRule"] @@ -17,7 +17,9 @@ def __init__(self, context: ASTValidationContext): super().__init__(context) self.known_operation_names: Dict[str, NameNode] = {} - def enter_operation_definition(self, node: OperationDefinitionNode, *_args): + def enter_operation_definition( + self, node: OperationDefinitionNode, *_args + ) -> VisitorAction: operation_name = node.name if operation_name: known_operation_names = self.known_operation_names @@ -31,7 +33,7 @@ def enter_operation_definition(self, node: OperationDefinitionNode, *_args): ) else: known_operation_names[operation_name.value] = operation_name - return self.SKIP + return SKIP - def enter_fragment_definition(self, *_args): - return self.SKIP + def enter_fragment_definition(self, *_args) -> VisitorAction: + return SKIP diff --git a/src/graphql/validation/rules/unique_operation_types.py b/src/graphql/validation/rules/unique_operation_types.py index f7015b5c..d3179548 100644 --- a/src/graphql/validation/rules/unique_operation_types.py +++ b/src/graphql/validation/rules/unique_operation_types.py @@ -6,6 +6,8 @@ OperationType, SchemaDefinitionNode, SchemaExtensionNode, + VisitorAction, + SKIP, ) from ...type import GraphQLObjectType from . import SDLValidationContext, SDLValidationRule @@ -40,7 +42,7 @@ def __init__(self, context: SDLValidationContext): def check_operation_types( self, node: Union[SchemaDefinitionNode, SchemaExtensionNode], *_args - ): + ) -> VisitorAction: for operation_type in node.operation_types or []: operation = operation_type.operation already_defined_operation_type = self.defined_operation_types.get(operation) @@ -62,6 +64,6 @@ def check_operation_types( ) else: self.defined_operation_types[operation] = operation_type - return self.SKIP + return SKIP enter_schema_definition = enter_schema_extension = check_operation_types diff --git a/src/graphql/validation/rules/unique_type_names.py b/src/graphql/validation/rules/unique_type_names.py index bdba66ba..275deff3 100644 --- a/src/graphql/validation/rules/unique_type_names.py +++ b/src/graphql/validation/rules/unique_type_names.py @@ -1,7 +1,7 @@ from typing import Dict from ...error import GraphQLError -from ...language import NameNode, TypeDefinitionNode +from ...language import NameNode, TypeDefinitionNode, VisitorAction, SKIP from . import SDLValidationContext, SDLValidationRule __all__ = ["UniqueTypeNamesRule"] @@ -18,7 +18,7 @@ def __init__(self, context: SDLValidationContext): self.known_type_names: Dict[str, NameNode] = {} self.schema = context.schema - def check_type_name(self, node: TypeDefinitionNode, *_args): + def check_type_name(self, node: TypeDefinitionNode, *_args) -> VisitorAction: type_name = node.name.value if self.schema and self.schema.get_type(type_name): @@ -39,7 +39,9 @@ def check_type_name(self, node: TypeDefinitionNode, *_args): ) else: self.known_type_names[type_name] = node.name - return self.SKIP + return SKIP + + return None enter_scalar_type_definition = enter_object_type_definition = check_type_name enter_interface_type_definition = enter_union_type_definition = check_type_name diff --git a/src/graphql/validation/rules/unique_variable_names.py b/src/graphql/validation/rules/unique_variable_names.py index 3576d90f..a472f660 100644 --- a/src/graphql/validation/rules/unique_variable_names.py +++ b/src/graphql/validation/rules/unique_variable_names.py @@ -17,10 +17,10 @@ def __init__(self, context: ASTValidationContext): super().__init__(context) self.known_variable_names: Dict[str, NameNode] = {} - def enter_operation_definition(self, *_args): + def enter_operation_definition(self, *_args) -> None: self.known_variable_names.clear() - def enter_variable_definition(self, node: VariableDefinitionNode, *_args): + def enter_variable_definition(self, node: VariableDefinitionNode, *_args) -> None: known_variable_names = self.known_variable_names variable_name = node.variable.name.value if variable_name in known_variable_names: diff --git a/src/graphql/validation/rules/values_of_correct_type.py b/src/graphql/validation/rules/values_of_correct_type.py index c73479c6..1cbd1470 100644 --- a/src/graphql/validation/rules/values_of_correct_type.py +++ b/src/graphql/validation/rules/values_of_correct_type.py @@ -12,6 +12,8 @@ ObjectValueNode, StringValueNode, ValueNode, + VisitorAction, + SKIP, print_ast, ) from ...pyutils import did_you_mean, suggestion_list, Undefined @@ -37,19 +39,20 @@ class ValuesOfCorrectTypeRule(ValidationRule): their position. """ - def enter_list_value(self, node: ListValueNode, *_args): + def enter_list_value(self, node: ListValueNode, *_args) -> VisitorAction: # Note: TypeInfo will traverse into a list's item type, so look to the parent # input type to check if it is a list. type_ = get_nullable_type(self.context.get_parent_input_type()) if not is_list_type(type_): self.is_valid_value_node(node) - return self.SKIP # Don't traverse further. + return SKIP # Don't traverse further. + return None - def enter_object_value(self, node: ObjectValueNode, *_args): + def enter_object_value(self, node: ObjectValueNode, *_args) -> VisitorAction: type_ = get_named_type(self.context.get_input_type()) if not is_input_object_type(type_): self.is_valid_value_node(node) - return self.SKIP # Don't traverse further. + return SKIP # Don't traverse further. # Ensure every required field exists. field_node_map = {field.name.value: field for field in node.fields} for field_name, field_def in type_.fields.items(): @@ -63,8 +66,9 @@ def enter_object_value(self, node: ObjectValueNode, *_args): node, ) ) + return None - def enter_object_field(self, node: ObjectFieldNode, *_args): + def enter_object_field(self, node: ObjectFieldNode, *_args) -> None: parent_type = get_named_type(self.context.get_parent_input_type()) field_type = self.context.get_input_type() if not field_type and is_input_object_type(parent_type): @@ -78,7 +82,7 @@ def enter_object_field(self, node: ObjectFieldNode, *_args): ) ) - def enter_null_value(self, node: NullValueNode, *_args): + def enter_null_value(self, node: NullValueNode, *_args) -> None: type_ = self.context.get_input_type() if is_non_null_type(type_): self.report_error( @@ -87,19 +91,19 @@ def enter_null_value(self, node: NullValueNode, *_args): ) ) - def enter_enum_value(self, node: EnumValueNode, *_args): + def enter_enum_value(self, node: EnumValueNode, *_args) -> None: self.is_valid_value_node(node) - def enter_int_value(self, node: IntValueNode, *_args): + def enter_int_value(self, node: IntValueNode, *_args) -> None: self.is_valid_value_node(node) - def enter_float_value(self, node: FloatValueNode, *_args): + def enter_float_value(self, node: FloatValueNode, *_args) -> None: self.is_valid_value_node(node) - def enter_string_value(self, node: StringValueNode, *_args): + def enter_string_value(self, node: StringValueNode, *_args) -> None: self.is_valid_value_node(node) - def enter_boolean_value(self, node: BooleanValueNode, *_args): + def enter_boolean_value(self, node: BooleanValueNode, *_args) -> None: self.is_valid_value_node(node) def is_valid_value_node(self, node: ValueNode) -> None: @@ -150,3 +154,5 @@ def is_valid_value_node(self, node: ValueNode) -> None: original_error=error, ) ) + + return diff --git a/src/graphql/validation/rules/variables_are_input_types.py b/src/graphql/validation/rules/variables_are_input_types.py index 04b0e5c7..f5931646 100644 --- a/src/graphql/validation/rules/variables_are_input_types.py +++ b/src/graphql/validation/rules/variables_are_input_types.py @@ -14,7 +14,7 @@ class VariablesAreInputTypesRule(ValidationRule): (scalar, enum, or input object). """ - def enter_variable_definition(self, node: VariableDefinitionNode, *_args): + def enter_variable_definition(self, node: VariableDefinitionNode, *_args) -> None: type_ = type_from_ast(self.context.schema, node.type) # If the variable type is not an input type, return an error. diff --git a/src/graphql/validation/rules/variables_in_allowed_position.py b/src/graphql/validation/rules/variables_in_allowed_position.py index ae0f9d4a..9f79b675 100644 --- a/src/graphql/validation/rules/variables_in_allowed_position.py +++ b/src/graphql/validation/rules/variables_in_allowed_position.py @@ -22,10 +22,12 @@ def __init__(self, context: ValidationContext): super().__init__(context) self.var_def_map: Dict[str, Any] = {} - def enter_operation_definition(self, *_args): + def enter_operation_definition(self, *_args) -> None: self.var_def_map.clear() - def leave_operation_definition(self, operation: OperationDefinitionNode, *_args): + def leave_operation_definition( + self, operation: OperationDefinitionNode, *_args + ) -> None: var_def_map = self.var_def_map usages = self.context.get_recursive_variable_usages(operation) @@ -53,7 +55,7 @@ def leave_operation_definition(self, operation: OperationDefinitionNode, *_args) ) ) - def enter_variable_definition(self, node: VariableDefinitionNode, *_args): + def enter_variable_definition(self, node: VariableDefinitionNode, *_args) -> None: self.var_def_map[node.variable.name.value] = node diff --git a/tests/language/test_visitor.py b/tests/language/test_visitor.py index 2d6696e4..2bb3fca3 100644 --- a/tests/language/test_visitor.py +++ b/tests/language/test_visitor.py @@ -2,7 +2,7 @@ from functools import partial from typing import cast, Dict, List, Optional, Tuple -from pytest import raises # type: ignore +from pytest import mark, raises # type: ignore from graphql.language import ( Node, @@ -199,7 +199,8 @@ def leave_operation_definition(self, *args): assert edited_ast == ast assert visited == ["enter", "leave"] - def allows_for_editing_on_enter(): + @mark.parametrize("remove_action", (REMOVE, Ellipsis), ids=("REMOVE", "Ellipsis")) + def allows_for_editing_on_enter(remove_action): ast = parse("{ a, b, c { a, b, c } }", no_location=True) # noinspection PyMethodMayBeStatic @@ -208,13 +209,14 @@ def enter(self, *args): check_visitor_fn_args(ast, *args) node = args[0] if isinstance(node, FieldNode) and node.name.value == "b": - return REMOVE + return remove_action edited_ast = visit(ast, TestVisitor()) assert ast == parse("{ a, b, c { a, b, c } }", no_location=True) assert edited_ast == parse("{ a, c { a, c } }", no_location=True) - def allows_for_editing_on_leave(): + @mark.parametrize("remove_action", (REMOVE, Ellipsis), ids=("REMOVE", "Ellipsis")) + def allows_for_editing_on_leave(remove_action): ast = parse("{ a, b, c { a, b, c } }", no_location=True) # noinspection PyMethodMayBeStatic @@ -223,21 +225,20 @@ def leave(self, *args): check_visitor_fn_args_edited(ast, *args) node = args[0] if isinstance(node, FieldNode) and node.name.value == "b": - return REMOVE + return remove_action edited_ast = visit(ast, TestVisitor()) assert ast == parse("{ a, b, c { a, b, c } }", no_location=True) assert edited_ast == parse("{ a, c { a, c } }", no_location=True) - def ignores_false_returned_on_leave(): + @mark.parametrize("skip_action", (SKIP, False), ids=("SKIP", "False")) + def ignores_false_returned_on_leave(skip_action): ast = parse("{ a, b, c { a, b, c } }", no_location=True) - assert SKIP is False - # noinspection PyMethodMayBeStatic class TestVisitor(Visitor): def leave(self, *args): - return SKIP + return skip_action returned_ast = visit(ast, TestVisitor()) assert returned_ast == parse("{ a, b, c { a, b, c } }", no_location=True) @@ -266,7 +267,8 @@ def enter(self, *args): visit(ast, visitor) assert visitor.did_visit_added_field - def allows_skipping_a_sub_tree(): + @mark.parametrize("skip_action", (SKIP, False), ids=("SKIP", "False")) + def allows_skipping_a_sub_tree(skip_action): ast = parse("{ a, b { x }, c }", no_location=True) visited = [] @@ -278,7 +280,7 @@ def enter(self, *args): kind, value = node.kind, get_value(node) visited.append(["enter", kind, value]) if kind == "field" and node.name.value == "b": - return SKIP + return skip_action def leave(self, *args): check_visitor_fn_args(ast, *args) @@ -305,7 +307,8 @@ def leave(self, *args): ["leave", "document", None], ] - def allows_early_exit_while_visiting(): + @mark.parametrize("break_action", (BREAK, True), ids=("BREAK", "True")) + def allows_early_exit_while_visiting(break_action): ast = parse("{ a, b { x }, c }", no_location=True) visited = [] @@ -317,7 +320,7 @@ def enter(self, *args): kind, value = node.kind, get_value(node) visited.append(["enter", kind, value]) if kind == "name" and node.value == "x": - return BREAK + return break_action def leave(self, *args): check_visitor_fn_args(ast, *args) @@ -342,7 +345,8 @@ def leave(self, *args): ["enter", "name", "x"], ] - def allows_early_exit_while_leaving(): + @mark.parametrize("break_action", (BREAK, True), ids=("BREAK", "True")) + def allows_early_exit_while_leaving(break_action): ast = parse("{ a, b { x }, c }", no_location=True) visited = [] @@ -360,7 +364,7 @@ def leave(self, *args): kind, value = node.kind, get_value(node) visited.append(["leave", kind, value]) if kind == "name" and node.value == "x": - return BREAK + return break_action visit(ast, TestVisitor()) assert visited == [ @@ -953,7 +957,8 @@ def enter_type_system_extension(self, *_args): def describe_visit_in_parallel(): - def allows_skipping_a_sub_tree(): + @mark.parametrize("skip_action", (SKIP, False), ids=("SKIP", "False")) + def allows_skipping_a_sub_tree(skip_action): # Note: nearly identical to the above test but using ParallelVisitor ast = parse("{ a, b { x }, c }") visited = [] @@ -966,7 +971,7 @@ def enter(self, *args): kind, value = node.kind, get_value(node) visited.append(["enter", kind, value]) if kind == "field" and node.name.value == "b": - return SKIP + return skip_action def leave(self, *args): check_visitor_fn_args(ast, *args) @@ -993,7 +998,8 @@ def leave(self, *args): ["leave", "document", None], ] - def allows_skipping_different_sub_trees(): + @mark.parametrize("skip_action", (SKIP, False), ids=("SKIP", "False")) + def allows_skipping_different_sub_trees(skip_action): ast = parse("{ a { x }, b { y} }") visited = [] @@ -1008,7 +1014,7 @@ def enter(self, *args): name = self.name visited.append([f"no-{name}", "enter", kind, value]) if kind == "field" and node.name.value == name: - return SKIP + return skip_action def leave(self, *args): check_visitor_fn_args(ast, *args) @@ -1055,7 +1061,8 @@ def leave(self, *args): ["no-b", "leave", "document", None], ] - def allows_early_exit_while_visiting(): + @mark.parametrize("break_action", (BREAK, True), ids=("BREAK", "True")) + def allows_early_exit_while_visiting(break_action): # Note: nearly identical to the above test but using ParallelVisitor. ast = parse("{ a, b { x }, c }") visited = [] @@ -1068,7 +1075,7 @@ def enter(self, *args): kind, value = node.kind, get_value(node) visited.append(["enter", kind, value]) if kind == "name" and node.value == "x": - return BREAK + return break_action def leave(self, *args): check_visitor_fn_args(ast, *args) @@ -1093,7 +1100,8 @@ def leave(self, *args): ["enter", "name", "x"], ] - def allows_early_exit_from_different_points(): + @mark.parametrize("break_action", (BREAK, True), ids=("BREAK", "True")) + def allows_early_exit_from_different_points(break_action): ast = parse("{ a { y }, b { x } }") visited = [] @@ -1108,7 +1116,7 @@ def enter(self, *args): name = self.name visited.append([f"break-{name}", "enter", kind, value]) if kind == "name" and node.value == name: - return BREAK + return break_action def leave(self, *args): assert self.name == "b" @@ -1142,7 +1150,8 @@ def leave(self, *args): ["break-b", "enter", "name", "b"], ] - def allows_early_exit_while_leaving(): + @mark.parametrize("break_action", (BREAK, True), ids=("BREAK", "True")) + def allows_early_exit_while_leaving(break_action): # Note: nearly identical to the above test but using ParallelVisitor. ast = parse("{ a, b { x }, c }") visited = [] @@ -1161,7 +1170,7 @@ def leave(self, *args): kind, value = node.kind, get_value(node) visited.append(["leave", kind, value]) if kind == "name" and node.value == "x": - return BREAK + return break_action visit(ast, ParallelVisitor([TestVisitor()])) assert visited == [ @@ -1181,7 +1190,8 @@ def leave(self, *args): ["leave", "name", "x"], ] - def allows_early_exit_from_leaving_different_points(): + @mark.parametrize("break_action", (BREAK, True), ids=("BREAK", "True")) + def allows_early_exit_from_leaving_different_points(break_action): ast = parse("{ a { y }, b { x } }") visited = [] @@ -1203,7 +1213,7 @@ def leave(self, *args): name = self.name visited.append([f"break-{name}", "leave", kind, value]) if kind == "field" and node.name.value == name: - return BREAK + return break_action visit(ast, ParallelVisitor([TestVisitor("a"), TestVisitor("b")])) assert visited == [ @@ -1245,7 +1255,8 @@ def leave(self, *args): ["break-b", "leave", "field", None], ] - def allows_for_editing_on_enter(): + @mark.parametrize("remove_action", (REMOVE, Ellipsis), ids=("REMOVE", "Ellipsis")) + def allows_for_editing_on_enter(remove_action): ast = parse("{ a, b, c { a, b, c } }", no_location=True) visited = [] @@ -1255,7 +1266,7 @@ def enter(self, *args): check_visitor_fn_args(ast, *args) node = args[0] if node.kind == "field" and node.name.value == "b": - return REMOVE + return remove_action # noinspection PyMethodMayBeStatic class TestVisitor2(Visitor): @@ -1301,7 +1312,8 @@ def leave(self, *args): ["leave", "document", None], ] - def allows_for_editing_on_leave(): + @mark.parametrize("remove_action", (REMOVE, Ellipsis), ids=("REMOVE", "Ellipsis")) + def allows_for_editing_on_leave(remove_action): ast = parse("{ a, b, c { a, b, c } }", no_location=True) visited = [] @@ -1311,7 +1323,7 @@ def leave(self, *args): check_visitor_fn_args_edited(ast, *args) node = args[0] if node.kind == "field" and node.name.value == "b": - return REMOVE + return remove_action # noinspection PyMethodMayBeStatic class TestVisitor2(Visitor):