Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Pydantic model definition for valid OpenAPI schema #1308

Merged
merged 10 commits into from
Feb 12, 2025
1 change: 0 additions & 1 deletion datajunction-server/datajunction_server/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
Environment for Alembic migrations.
"""
# pylint: disable=no-member, unused-import, no-name-in-module, import-error

import os
from logging.config import fileConfig
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def get_access_token(
request_session = requests.session()
token_request = google.auth.transport.requests.Request(session=request_session)
user_data = id_token.verify_oauth2_token(
id_token=credentials._id_token, # pylint: disable=protected-access
id_token=credentials._id_token,
request=token_request,
audience=setting.google_oauth_client_id,
)
Expand Down
1 change: 0 additions & 1 deletion datajunction-server/datajunction_server/api/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ async def default_attribute_types(session: AsyncSession = Depends(get_session)):

# Update existing default attribute types
statement = select(AttributeType).filter(
# pylint: disable=no-member
AttributeType.name.in_( # type: ignore
set(default_attribute_type_names.keys()),
),
Expand Down
9 changes: 4 additions & 5 deletions datajunction-server/datajunction_server/api/cubes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# pylint: disable=too-many-arguments
"""
Cube related APIs.
"""
Expand Down Expand Up @@ -164,7 +163,7 @@ async def get_cube_dimension_sql(
include_counts: bool = False,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=redefined-outer-name
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> TranslatedSQL:
Expand All @@ -189,7 +188,7 @@ async def get_cube_dimension_sql(
"/cubes/{name}/dimensions/data",
name="Dimensions Values for Cube",
)
async def get_cube_dimension_values( # pylint: disable=too-many-locals
async def get_cube_dimension_values(
name: str,
*,
dimensions: List[str] = Query([], description="Dimensions to get values for"),
Expand All @@ -207,7 +206,7 @@ async def get_cube_dimension_values( # pylint: disable=too-many-locals
request: Request,
query_service_client: QueryServiceClient = Depends(get_query_service_client),
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=redefined-outer-name
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> DimensionValues:
Expand Down Expand Up @@ -260,7 +259,7 @@ async def get_cube_dimension_values( # pylint: disable=too-many-locals
return DimensionValues( # pragma: no cover
dimensions=[
from_amenable_name(col.name)
for col in translated_sql.columns # type: ignore # pylint: disable=not-an-iterable
for col in translated_sql.columns # type: ignore
if col.name != "count"
],
values=dimension_values,
Expand Down
19 changes: 9 additions & 10 deletions datajunction-server/datajunction_server/api/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# pylint: disable=too-many-arguments
"""
Data related APIs.
"""
Expand Down Expand Up @@ -57,7 +56,7 @@ async def add_availability_state(
*,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
notify: Callable = Depends(get_notifier),
Expand Down Expand Up @@ -164,7 +163,7 @@ async def add_availability_state(


@router.get("/data/{node_name}/", name="Get Data for a Node")
async def get_data( # pylint: disable=too-many-locals
async def get_data(
node_name: str,
*,
dimensions: List[str] = Query([], description="Dimensional attributes to group by"),
Expand All @@ -184,7 +183,7 @@ async def get_data( # pylint: disable=too-many-locals
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -243,7 +242,7 @@ async def get_data( # pylint: disable=too-many-locals


@router.get("/stream/{node_name}", response_model=QueryWithResults)
async def get_data_stream_for_node( # pylint: disable=R0914, R0913
async def get_data_stream_for_node(
node_name: str,
*,
dimensions: List[str] = Query([], description="Dimensional attributes to group by"),
Expand All @@ -259,7 +258,7 @@ async def get_data_stream_for_node( # pylint: disable=R0914, R0913
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -365,7 +364,7 @@ def get_data_for_query(


@router.get("/data/", response_model=QueryWithResults, name="Get Data For Metrics")
async def get_data_for_metrics( # pylint: disable=R0914, R0913
async def get_data_for_metrics(
metrics: List[str] = Query([]),
dimensions: List[str] = Query([]),
filters: List[str] = Query([]),
Expand All @@ -379,7 +378,7 @@ async def get_data_for_metrics( # pylint: disable=R0914, R0913
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> QueryWithResults:
Expand Down Expand Up @@ -424,7 +423,7 @@ async def get_data_for_metrics( # pylint: disable=R0914, R0913


@router.get("/stream/", response_model=QueryWithResults)
async def get_data_stream_for_metrics( # pylint: disable=R0914, R0913
async def get_data_stream_for_metrics(
metrics: List[str] = Query([]),
dimensions: List[str] = Query([]),
filters: List[str] = Query([]),
Expand All @@ -437,7 +436,7 @@ async def get_data_stream_for_metrics( # pylint: disable=R0914, R0913
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> QueryWithResults:
Expand Down
6 changes: 3 additions & 3 deletions datajunction-server/datajunction_server/api/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def list_dimensions(
*,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> List[NodeIndegreeOutput]:
Expand Down Expand Up @@ -73,7 +73,7 @@ async def find_nodes_with_dimension(
node_type: List[NodeType] = Query([]),
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> List[NodeRevisionOutput]:
Expand Down Expand Up @@ -106,7 +106,7 @@ async def find_nodes_with_common_dimensions(
*,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> List[NodeRevisionOutput]:
Expand Down
7 changes: 3 additions & 4 deletions datajunction-server/datajunction_server/api/djsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@router.get("/djsql/data", response_model=QueryWithResults)
async def get_data_for_djsql( # pylint: disable=R0914, R0913
async def get_data_for_djsql(
query: str,
async_: bool = False,
*,
Expand All @@ -37,7 +37,7 @@ async def get_data_for_djsql( # pylint: disable=R0914, R0913
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> QueryWithResults:
Expand Down Expand Up @@ -77,7 +77,6 @@ async def get_data_for_djsql( # pylint: disable=R0914, R0913
return result


# pylint: disable=R0914, R0913
@router.get("/djsql/stream/", response_model=QueryWithResults)
async def get_data_stream_for_djsql(
query: str,
Expand All @@ -88,7 +87,7 @@ async def get_data_stream_for_djsql(
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_and_update_current_user),
validate_access: access.ValidateAccessFn = Depends( # pylint: disable=W0621
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
) -> QueryWithResults: # pragma: no cover
Expand Down
75 changes: 58 additions & 17 deletions datajunction-server/datajunction_server/api/graphql/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""DJ graphql"""

import logging
from functools import wraps

import strawberry
from fastapi import Depends
from strawberry.fastapi import GraphQLRouter
from strawberry.types import Info

from datajunction_server.api.graphql.queries.catalogs import list_catalogs
from datajunction_server.api.graphql.queries.dag import common_dimensions
Expand All @@ -20,6 +24,43 @@
from datajunction_server.api.graphql.scalars.tag import Tag
from datajunction_server.utils import get_session, get_settings

logger = logging.getLogger(__name__)


def log_resolver(func):
"""
Adds generic logging to the GQL resolver.
"""

@wraps(func)
async def wrapper(*args, **kwargs):
resolver_name = func.__name__

info: Info = kwargs.get("info") if "info" in kwargs else None
user = info.context.get("user", "anonymous") if info else "unknown"
args_dict = {key: val for key, val in kwargs.items() if key != "info"}
log_tags = {
"query_name": resolver_name,
"user": user,
**args_dict,
}
log_args = " ".join(
[f"{tag}={value}" for tag, value in log_tags.items() if value],
)
try:
result = await func(*args, **kwargs)
logger.info("[GQL] %s", log_args)
return result
except Exception as exc: # pragma: no cover
logger.error( # pragma: no cover
"[GQL] status=error %s",
log_args,
exc_info=True,
)
raise exc # pragma: no cover

return wrapper


async def get_context(
session=Depends(get_session),
Expand All @@ -32,47 +73,47 @@ async def get_context(


@strawberry.type
class Query: # pylint: disable=R0903
class Query:
"""
Parent of all DJ graphql queries
"""

# Catalog and engine queries
list_catalogs: list[Catalog] = strawberry.field( # noqa: F811
resolver=list_catalogs,
list_catalogs: list[Catalog] = strawberry.field(
resolver=log_resolver(list_catalogs),
)
list_engines: list[Engine] = strawberry.field( # noqa: F811
resolver=list_engines,
list_engines: list[Engine] = strawberry.field(
resolver=log_resolver(list_engines),
)

# Node search queries
find_nodes: list[Node] = strawberry.field( # noqa: F811
resolver=find_nodes,
find_nodes: list[Node] = strawberry.field(
resolver=log_resolver(find_nodes),
description="Find nodes based on the search parameters.",
)
find_nodes_paginated: Connection[Node] = strawberry.field( # noqa: F811
resolver=find_nodes_paginated,
find_nodes_paginated: Connection[Node] = strawberry.field(
resolver=log_resolver(find_nodes_paginated),
description="Find nodes based on the search parameters with pagination",
)

# DAG queries
common_dimensions: list[DimensionAttribute] = strawberry.field( # noqa: F811
resolver=common_dimensions,
common_dimensions: list[DimensionAttribute] = strawberry.field(
resolver=log_resolver(common_dimensions),
description="Get common dimensions for one or more nodes",
)

# Generate SQL queries
measures_sql: list[GeneratedSQL] = strawberry.field( # noqa: F811
resolver=measures_sql,
measures_sql: list[GeneratedSQL] = strawberry.field(
resolver=log_resolver(measures_sql),
)

# Tags queries
list_tags: list[Tag] = strawberry.field( # noqa: F811
resolver=list_tags,
list_tags: list[Tag] = strawberry.field(
resolver=log_resolver(list_tags),
description="Find DJ node tags based on the search parameters.",
)
list_tag_types: list[str] = strawberry.field( # noqa: F811
resolver=list_tag_types,
list_tag_types: list[str] = strawberry.field(
resolver=log_resolver(list_tag_types),
description="List all DJ node tag types",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from sqlalchemy import select
from strawberry.types import Info

from datajunction_server.api.graphql.scalars.catalog_engine import (
Catalog, # pylint: disable=W0611
)
from datajunction_server.api.graphql.scalars.catalog_engine import Catalog
from datajunction_server.database.catalog import Catalog as DBCatalog


Expand All @@ -22,6 +20,6 @@ async def list_catalogs(
"""
session = info.context["session"] # type: ignore
return [
Catalog.from_pydantic(catalog) # type: ignore #pylint: disable=E1101
Catalog.from_pydantic(catalog) # type: ignore
for catalog in (await session.execute(select(DBCatalog))).scalars().all()
]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@strawberry.input
class CubeDefinition: # pylint: disable=too-few-public-methods
class CubeDefinition:
"""
The cube definition for the query
"""
Expand Down Expand Up @@ -42,7 +42,7 @@ class CubeDefinition: # pylint: disable=too-few-public-methods


@strawberry.input
class EngineSettings: # pylint: disable=too-few-public-methods
class EngineSettings:
"""
The engine settings for the query
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def find_nodes_by(
limit,
before,
after,
*options,
options=options,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ async def get_nodes_by_tag(
Retrieves all nodes with the given tag. A list of fields must be requested on the node,
or this will not return any data.
"""
from datajunction_server.api.graphql.resolvers.nodes import ( # pylint: disable=import-outside-toplevel
load_node_options,
)
from datajunction_server.api.graphql.resolvers.nodes import load_node_options

options = load_node_options(
fields["nodes"]
Expand Down
Loading