Skip to content

Commit

Permalink
实现BaseModelSerializer,支持序列化字段过滤
Browse files Browse the repository at this point in the history
  • Loading branch information
hhyo committed Oct 22, 2022
1 parent 43315f5 commit a0ae8fc
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 51 deletions.
35 changes: 26 additions & 9 deletions sql_api/api_views/sql_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import django_filters
from django.contrib.auth.decorators import permission_required
from django.utils.decorators import method_decorator
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiResponse
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import views, permissions, serializers, viewsets, filters
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
Expand Down Expand Up @@ -64,7 +63,9 @@ def post(self, request):
summary="获取SQL工单列表",
description="获取SQL工单列表,支持筛选、分页、检索等",
request=SqlWorkflowSerializer,
responses={200: SqlWorkflowSerializer},
responses={
200: SqlWorkflowSerializer(exclude=["sql_content", "display_content"])
},
),
retrieve=extend_schema(
summary="获取SQL工单详情",
Expand All @@ -73,9 +74,14 @@ def post(self, request):
),
rollback_sql=extend_schema(
summary="获取SQL工单回滚语句",
responses={200: serializers.Serializer(many=True)},
responses={
200: serializers.ListSerializer(
child=serializers.ListField(default=["sql", "rollback_sql"])
)
},
description="通过工单ID获取回滚语句",
),
create=extend_schema(exclude=True),
partial_update=extend_schema(exclude=True),
alter_run_date=extend_schema(
summary="修改SQL工单可执行时间范围",
Expand All @@ -86,6 +92,7 @@ def post(self, request):
)
class SqlWorkflowView(viewsets.ModelViewSet):
permission_classes = [IsAuthenticated, SqlWorkFlowViewPermission]
serializer_class = SqlWorkflowSerializer
pagination_class = BootStrapTablePagination
filter_backends = [
filters.SearchFilter,
Expand All @@ -107,7 +114,7 @@ def get_queryset(self):
pass
# 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单
elif user.has_perm("sql.sql_review") or user.has_perm(
"sql.sql_execute_for_resource_group"
"sql.sql_execute_for_resource_group"
):
filter_dict["group_id__in"] = [
group.group_id for group in user_groups(user)
Expand All @@ -118,12 +125,22 @@ def get_queryset(self):
queryset = SqlWorkflow.objects.filter(**filter_dict).order_by("-id")
return self.get_serializer_class().setup_eager_loading(queryset)

def get_serializer_class(self):
def get_serializer(self, *args, **kwargs):
serializer_class = self.get_serializer_class()
kwargs.setdefault("context", self.get_serializer_context())
if self.action == "retrieve":
return SqlWorkflowDetailSerializer
return SqlWorkflowSerializer
return serializer_class(*args, **kwargs)
return serializer_class(
*args, **kwargs, exclude=["sql_content", "display_content"]
)

@action(methods=["get"], detail=True, pagination_class=None, filter_backends=[], search_fields=None)
@action(
methods=["get"],
detail=True,
pagination_class=None,
filter_backends=[],
search_fields=None,
)
def rollback_sql(self, request, *args, **kwargs):
obj = self.get_object()
data = self.get_serializer().rollback_sql(obj)
Expand Down
40 changes: 40 additions & 0 deletions sql_api/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: UTF-8 -*-
"""
@author: hhyo
@license: Apache Licence
@file: __init__.py
@time: 2022/10/22
"""
__author__ = "hhyo"

from rest_framework import serializers


class BaseModelSerializer(serializers.ModelSerializer):
"""BaseModelSerializer,主要是引入过滤和排除字段的方法"""

def __init__(self, *args, **kwargs):
"""
``fields`` 需要保留的字段列表
``exclude`` 需要排除的字段列表
"""
fields = kwargs.pop("fields", None)
exclude = kwargs.pop("exclude", None)
super(BaseModelSerializer, self).__init__(*args, **kwargs)

for field_name in set(self.fields.keys()):
if not any([fields, exclude]):
break
if fields and field_name in fields:
continue
if exclude and field_name not in exclude:
continue
self.fields.pop(field_name, None)

@staticmethod
def setup_eager_loading(queryset):
"""
Perform necessary eager loading of data.
https://ses4j.github.io/2015/11/23/optimizing-slow-django-rest-framework-performance/
"""
pass
60 changes: 18 additions & 42 deletions sql_api/serializers/sql_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sql.engines import ReviewSet, get_engine
from sql.engines.models import ReviewResult
from sql.models import Instance, SqlWorkflow
from sql_api.serializers import BaseModelSerializer

logger = logging.getLogger("default")

Expand Down Expand Up @@ -69,57 +70,18 @@ class ExecuteCheckResultSerializer(serializers.Serializer):
affected_rows = serializers.IntegerField(read_only=True)


class SqlWorkflowSerializer(serializers.ModelSerializer):
class SqlWorkflowSerializer(BaseModelSerializer):
"""SQL工单"""

instance_name = serializers.CharField(source="instance.instance_name")

def __init__(self, *args, **kwargs):
"""
``fields`` 需要保留的字段列表
``exclude`` 需要排除的字段列表
"""
fields = kwargs.pop("fields", None)
exclude = kwargs.pop("exclude", None)
super(SqlWorkflowSerializer, self).__init__(*args, **kwargs)

for field_name in set(self.fields.keys()):
if not any([fields, exclude]):
break
if fields and field_name in fields:
continue
if exclude and field_name not in exclude:
continue
self.fields.pop(field_name, None)
sql_content = serializers.CharField(source="sqlworkflowcontent.sql_content")
display_content = serializers.SerializerMethodField()

@staticmethod
def setup_eager_loading(queryset):
"""
Perform necessary eager loading of data.
https://ses4j.github.io/2015/11/23/optimizing-slow-django-rest-framework-performance/
"""
queryset = queryset.select_related("instance")
return queryset

@staticmethod
def rollback_sql(obj):
try:
query_engine = get_engine(instance=obj.instance)
return query_engine.get_rollback(workflow=obj)
except Exception as msg:
logger.error(traceback.format_exc())
raise serializers.ValidationError({"errors": msg})

class Meta:
model = SqlWorkflow
fields = "__all__"


class SqlWorkflowDetailSerializer(SqlWorkflowSerializer):
instance_name = serializers.CharField(source="instance.instance_name")
sql_content = serializers.CharField(source="sqlworkflowcontent.sql_content")
display_content = serializers.SerializerMethodField()

@extend_schema_field(field=serializers.ListField(child=ReviewResultSerializer()))
def get_display_content(self, obj):
"""获取工单详情用于列表展示的内容,区分不同的状态进行转换"""
Expand Down Expand Up @@ -161,6 +123,20 @@ def get_display_content(self, obj):
rows = obj.sqlworkflowcontent.review_content
return json.loads(rows)

@staticmethod
def rollback_sql(obj):
"""获取工单回滚语句"""
try:
query_engine = get_engine(instance=obj.instance)
return query_engine.get_rollback(workflow=obj)
except Exception as msg:
logger.error(traceback.format_exc())
raise serializers.ValidationError({"errors": msg})

class Meta:
model = SqlWorkflow
fields = "__all__"


class SqlWorkflowDetailSerializer(SqlWorkflowSerializer):
"""仅用做文档生成,无实际意义"""

0 comments on commit a0ae8fc

Please sign in to comment.