Skip to content

Commit

Permalink
Type-hint rename_schema.py (#847)
Browse files Browse the repository at this point in the history
* Type-hint existing function signatures

* Remove types from function comments

* Fix type-hint problem requiring name attribute

See
#834 (comment)

* lint

* rename ast to schema_ast

* arg/return value definitions to lowercase

* Remove leading articles for parameter descriptions

* Dynamically confirm type-hint and rename_types contain the same information

* rename type_hint_rename_types to RenameTypes

* lint

* compute rename_types dynamically

* lint

* Decide on duplication as cleanest solution

* Replace Dict with Mapping for renaming type-hint

* Tighten return type bound

* Bump linter version to prevent false positive error

* Revert "Bump linter version to prevent false positive error"

This reverts commit 7579db6.

* Make RenameTypes a module-level attribute instead of a class attribute

* Add Generic type

* lowercase first word of return description

Co-authored-by: Predrag Gruevski <[email protected]>

* Switch Set to AbstractSet type hint

* remove newline

* Add newline after multiple de-indent

Co-authored-by: Predrag Gruevski <[email protected]>

Co-authored-by: Predrag Gruevski <[email protected]>
  • Loading branch information
LWprogramming and obi1kenobi authored Jun 22, 2020
1 parent 1692b30 commit 3cbc9d2
Showing 1 changed file with 130 additions and 60 deletions.
190 changes: 130 additions & 60 deletions graphql_compiler/schema_transformation/rename_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
# Copyright 2019-present Kensho Technologies, LLC.
from collections import namedtuple

from graphql import build_ast_schema
from typing import AbstractSet, Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast

from graphql import (
DocumentNode,
EnumTypeDefinitionNode,
FieldDefinitionNode,
InterfaceTypeDefinitionNode,
NamedTypeNode,
Node,
ObjectTypeDefinitionNode,
UnionTypeDefinitionNode,
build_ast_schema,
)
from graphql.language.visitor import Visitor, visit
import six

Expand All @@ -26,7 +37,27 @@
)


def rename_schema(ast, renamings):
# Union of classes of nodes to be renamed by an instance of RenameSchemaTypesVisitor. Note that
# RenameSchemaTypesVisitor also has a class attribute rename_types which parallels the classes here.
# This duplication is necessary due to language and linter constraints-- see the comment in the
# RenameSchemaTypesVisitor class for more information.
# Unfortunately, RenameTypes itself has to be a module attribute instead of a class attribute
# because a bug in flake8 produces a linting error if RenameTypes is a class attribute and we type
# hint the return value of the RenameSchemaTypesVisitor's _rename_name_and_add_to_record() method as
# RenameTypes. More on this here: https://github.com/PyCQA/pyflakes/issues/441
RenameTypes = Union[
EnumTypeDefinitionNode,
InterfaceTypeDefinitionNode,
NamedTypeNode,
ObjectTypeDefinitionNode,
UnionTypeDefinitionNode,
]
RenameTypesT = TypeVar("RenameTypesT", bound=RenameTypes)


def rename_schema(
schema_ast: DocumentNode, renamings: Mapping[str, str]
) -> RenamedSchemaDescriptor:
"""Create a RenamedSchemaDescriptor; types and query type fields are renamed using renamings.
Any type, interface, enum, or fields of the root type/query type whose name
Expand All @@ -35,12 +66,12 @@ def rename_schema(ast, renamings):
belonging to the root/query type will never be renamed.
Args:
ast: Document, representing a valid schema that does not contain extensions, input
object definitions, mutations, or subscriptions, whose fields of the query type share
the same name as the types they query. Not modified by this function
renamings: Dict[str, str], mapping original type/field names to renamed type/field names.
Type or query type field names that do not appear in the dict will be unchanged.
Any dict-like object that implements get(key, [default]) may also be used
schema_ast: represents a valid schema that does not contain extensions, input object
definitions, mutations, or subscriptions, whose fields of the query type share
the same name as the types they query. Not modified by this function
renamings: maps original type/field names to renamed type/field names. Type or query type
field names that do not appear in the dict will be unchanged. Any dict-like
object that implements get(key, [default]) may also be used
Returns:
RenamedSchemaDescriptor, a namedtuple that contains the AST of the renamed schema, and the
Expand All @@ -59,77 +90,82 @@ def rename_schema(ast, renamings):
- SchemaNameConflictError if there are conflicts between the renamed types or fields
"""
# Check input schema satisfies various structural requirements
check_ast_schema_is_valid(ast)
check_ast_schema_is_valid(schema_ast)

schema = build_ast_schema(ast)
schema = build_ast_schema(schema_ast)
query_type = get_query_type_name(schema)
scalars = get_scalar_names(schema)

# Rename types, interfaces, enums
ast, reverse_name_map = _rename_types(ast, renamings, query_type, scalars)
schema_ast, reverse_name_map = _rename_types(schema_ast, renamings, query_type, scalars)
reverse_name_map_changed_names_only = {
renamed_name: original_name
for renamed_name, original_name in six.iteritems(reverse_name_map)
if renamed_name != original_name
}

# Rename query type fields
ast = _rename_query_type_fields(ast, renamings, query_type)
schema_ast = _rename_query_type_fields(schema_ast, renamings, query_type)
return RenamedSchemaDescriptor(
schema_ast=ast,
schema=build_ast_schema(ast),
schema_ast=schema_ast,
schema=build_ast_schema(schema_ast),
reverse_name_map=reverse_name_map_changed_names_only,
)


def _rename_types(ast, renamings, query_type, scalars):
def _rename_types(
schema_ast: DocumentNode,
renamings: Mapping[str, str],
query_type: str,
scalars: AbstractSet[str],
) -> Tuple[DocumentNode, Dict[str, str]]:
"""Rename types, enums, interfaces using renamings.
The query type will not be renamed. Scalar types, field names, enum values will not be renamed.
The input AST will not be modified.
The input schema AST will not be modified.
Args:
ast: Document, the schema that we're returning a modified version of
renamings: Dict[str, str], mapping original type/interface/enum name to renamed name. If
a name does not appear in the dict, it will be unchanged
query_type: str, name of the query type, e.g. 'RootSchemaQuery'
scalars: Set[str], the set of all scalars used in the schema, including user defined
scalars and and used builtin scalars, excluding unused builtins
schema_ast: schema that we're returning a modified version of
renamings: maps original type/interface/enum name to renamed name. Any name not in the dict
will be unchanged
query_type: name of the query type, e.g. 'RootSchemaQuery'
scalars: set of all scalars used in the schema, including user defined scalars and used
builtin scalars, excluding unused builtins
Returns:
Tuple[Document, Dict[str, str]], containing the modified version of the AST, and
the renamed type name to original type name map. Map contains all types, including
those that were not renamed.
Tuple containing the modified version of the schema AST, and the renamed type name to
original type name map. Map contains all types, including those that were not renamed.
Raises:
- InvalidTypeNameError if the schema contains an invalid type name, or if the user attempts
to rename a type to an invalid name
- SchemaNameConflictError if the rename causes name conflicts
"""
visitor = RenameSchemaTypesVisitor(renamings, query_type, scalars)
renamed_ast = visit(ast, visitor)
renamed_schema_ast = visit(schema_ast, visitor)
return renamed_schema_ast, visitor.reverse_name_map

return renamed_ast, visitor.reverse_name_map


def _rename_query_type_fields(ast, renamings, query_type):
def _rename_query_type_fields(
schema_ast: DocumentNode, renamings: Mapping[str, str], query_type: str
) -> DocumentNode:
"""Rename all fields of the query type.
The input AST will not be modified.
The input schema AST will not be modified.
Args:
ast: DocumentNode, the schema that we're returning a modified version of
renamings: Dict[str, str], mapping original field name to renamed name. If a name
does not appear in the dict, it will be unchanged
query_type: string, name of the query type, e.g. 'RootSchemaQuery'
schema_ast: schema that we're returning a modified version of
renamings: maps original query type field name to renamed name. Any name not in the dict
will be unchanged
query_type: name of the query type, e.g. 'RootSchemaQuery'
Returns:
DocumentNode, representing the modified version of the input schema AST
modified version of the input schema AST
"""
visitor = RenameQueryTypeFieldsVisitor(renamings, query_type)
renamed_ast = visit(ast, visitor)
return renamed_ast
renamed_schema_ast = visit(schema_ast, visitor)
return renamed_schema_ast


class RenameSchemaTypesVisitor(Visitor):
Expand Down Expand Up @@ -176,6 +212,16 @@ class RenameSchemaTypesVisitor(Visitor):
"ScalarTypeExtensionNode",
}
)
# rename_types must be a set of strings corresponding to the names of the classes in
# RenameTypes. The duplication exists because introspection for Unions via typing.get_args()
# doesn't exist until Python 3.8. In Python 3.8, this would be a valid way to define
# rename_types:
# rename_types = frozenset(cls.__name__ for cls in get_args(RenameTypes)) # type: ignore
# Note: even with Python 3.8, the mypy version at the time of writing (version 0.770) doesn't
# allow for introspection for Unions. mypy's maintainers recently merged a PR
# (https://github.com/python/mypy/pull/8779) that permits this line of code, but did so after
# the mypy 0.770 release. If we do end up removing the duplication at a later point but not
# update the mypy version, we'd need to ignore it (as shown in the in-line comment).
rename_types = frozenset(
{
"EnumTypeDefinitionNode",
Expand All @@ -186,36 +232,35 @@ class RenameSchemaTypesVisitor(Visitor):
}
)

def __init__(self, renamings, query_type, scalar_types):
def __init__(
self, renamings: Mapping[str, str], query_type: str, scalar_types: AbstractSet[str]
) -> None:
"""Create a visitor for renaming types in a schema AST.
Args:
renamings: Dict[str, str], mapping from original type name to renamed type name.
Any name not in the dict will be unchanged
query_type: str, name of the query type (e.g. RootSchemaQuery), which will not
be renamed
scalar_types: Set[str], set of names of all scalars used in the schema, including
all user defined scalars and any builtin scalars that were used
renamings: maps original type name to renamed name. Any name not in the dict will be
unchanged
query_type: name of the query type (e.g. RootSchemaQuery), which will not be renamed
scalar_types: set of all scalars used in the schema, including all user defined scalars
and any builtin scalars that were used
"""
self.renamings = renamings
self.reverse_name_map = {} # Dict[str, str], from renamed type name to original type name
self.reverse_name_map: Dict[str, str] = {} # from renamed type name to original type name
# reverse_name_map contains all types, including those that were unchanged
self.query_type = query_type
self.scalar_types = frozenset(scalar_types)
self.builtin_types = frozenset({"String", "Int", "Float", "Boolean", "ID"})

def _rename_name_and_add_to_record(self, node):
def _rename_name_and_add_to_record(self, node: RenameTypesT) -> RenameTypesT:
"""Change the name of the input node if necessary, add the name pair to reverse_name_map.
Don't rename if the type is the query type, a scalar type, or a builtin type.
The input node will not be modified. reverse_name_map may be modified.
Args:
node: EnumTypeDefinitionNode, InterfaceTypeDefinitionNode, NamedTypeNode,
ObjectTypeDefinitionNode, or UnionTypeDefinitionNode. An object representing an
AST component, containing a .name attribute corresponding to an AST node of type
NameNode.
node: object representing an AST component, containing a .name attribute
corresponding to an AST node of type NameNode.
Returns:
Node object, identical to the input node, except with possibly a new name. If the
Expand Down Expand Up @@ -257,15 +302,17 @@ def _rename_name_and_add_to_record(self, node):
node_with_new_name = get_copy_of_node_with_new_name(node, new_name_string)
return node_with_new_name

def enter(self, node, key, parent, path, ancestors):
def enter(
self, node: Node, key: Any, parent: Any, path: List[Any], ancestors: List[Any],
) -> Optional[Node]:
"""Upon entering a node, operate depending on node type."""
node_type = type(node).__name__
if node_type in self.noop_types:
# Do nothing, continue traversal
return None
elif node_type in self.rename_types:
# Rename node, put name pair into record
renamed_node = self._rename_name_and_add_to_record(node)
renamed_node = self._rename_name_and_add_to_record(cast(RenameTypes, node))
if renamed_node is node: # Name unchanged, continue traversal
return None
else: # Name changed, return new node, `visit` will make shallow copies along path
Expand All @@ -276,13 +323,13 @@ def enter(self, node, key, parent, path, ancestors):


class RenameQueryTypeFieldsVisitor(Visitor):
def __init__(self, renamings, query_type):
def __init__(self, renamings: Mapping[str, str], query_type: str) -> None:
"""Create a visitor for renaming fields of the query type in a schema AST.
Args:
renamings: Dict[str, str], from original field name to renamed field name. Any
name not in the dict will be unchanged
query_type: str, name of the query type (e.g. RootSchemaQuery)
renamings: maps original field name to renamed field name. Any name not in the dict will
be unchanged
query_type: name of the query type (e.g. RootSchemaQuery)
"""
# Note that as field names and type names have been confirmed to match up, any renamed
# field already has a corresponding renamed type. If no errors, due to either invalid
Expand All @@ -292,17 +339,38 @@ def __init__(self, renamings, query_type):
self.renamings = renamings
self.query_type = query_type

def enter_object_type_definition(self, node, *args):
def enter_object_type_definition(
self,
node: ObjectTypeDefinitionNode,
key: Any,
parent: Any,
path: List[Any],
ancestors: List[Any],
) -> None:
"""If the node's name matches the query type, record that we entered the query type."""
if node.name.value == self.query_type:
self.in_query_type = True

def leave_object_type_definition(self, node, key, parent, path, ancestors):
def leave_object_type_definition(
self,
node: ObjectTypeDefinitionNode,
key: Any,
parent: Any,
path: List[Any],
ancestors: List[Any],
) -> None:
"""If the node's name matches the query type, record that we left the query type."""
if node.name.value == self.query_type:
self.in_query_type = False

def enter_field_definition(self, node, *args):
def enter_field_definition(
self,
node: FieldDefinitionNode,
key: Any,
parent: Any,
path: List[Any],
ancestors: List[Any],
) -> Optional[Node]:
"""If inside the query type, rename field and add the name pair to reverse_field_map."""
if self.in_query_type:
field_name = node.name.value
Expand All @@ -312,3 +380,5 @@ def enter_field_definition(self, node, *args):
else: # Make copy of node with the changed name, return the copy
field_node_with_new_name = get_copy_of_node_with_new_name(node, new_field_name)
return field_node_with_new_name

return None

0 comments on commit 3cbc9d2

Please sign in to comment.