diff --git a/apps/generic.py b/apps/generic.py index d8e3181d27..a612d3b3ef 100644 --- a/apps/generic.py +++ b/apps/generic.py @@ -173,6 +173,10 @@ def get_serializer_class(self, *args, **kwargs): return type(self.serializer_class.__name__, (self.serializer_class,), {"Meta": self.serializer_meta}) +class ApiMixinModelViewSet(ApiMixin, _ModelViewSet): + pagination_class = DataPageNumberPagination + + def custom_exception_handler(exc, context): """ 自定义错误处理方式 diff --git a/apps/node_man/handlers/meta.py b/apps/node_man/handlers/meta.py index 24900d5196..42fe4875f5 100644 --- a/apps/node_man/handlers/meta.py +++ b/apps/node_man/handlers/meta.py @@ -469,6 +469,32 @@ def fetch_os_type_children(os_types: Tuple = constants.OsType): os_type_children.append({"id": os_type, "name": constants.OS_CHN.get(os_type, os_type)}) return os_type_children + @staticmethod + def fetch_agent_pkg_manager_children(): + mock_version = [ + {"name": "2.1.8", "id": "2.1.8"}, + {"name": "2.1.7", "id": "2.1.7"}, + ] + mock_tags = [ + {"name": "稳定版本", "id": "stable"}, + {"name": "最新版本", "id": "latest"}, + ] + mock_creator = [ + {"name": "user1", "id": "user1"}, + {"name": "user2", "id": "user2"}, + ] + mock_is_ready = [ + {"name": "启用", "id": True}, + {"name": "停用", "id": False}, + ] + + return [ + {"name": _("版本号"), "id": "version", "children": mock_version}, + {"name": _("标签信息"), "id": "tags", "children": mock_tags}, + {"name": _("上传用户"), "id": "creator", "children": mock_creator}, + {"name": _("状态"), "id": "is_ready", "children": mock_is_ready}, + ] + def filter_condition(self, category): """ 获取过滤条件 @@ -495,6 +521,8 @@ def filter_condition(self, category): elif category == "os_type": ret = self.fetch_os_type_children() return ret + elif category == "agent_pkg_manage": + return self.fetch_agent_pkg_manager_children() @staticmethod def install_default_values_formatter(install_default_values: Dict[str, Dict[str, Any]]): diff --git a/apps/node_man/models.py b/apps/node_man/models.py index e5b69cb70b..9312d85c40 100644 --- a/apps/node_man/models.py +++ b/apps/node_man/models.py @@ -2487,3 +2487,55 @@ class Meta: index_together = [ ["bk_biz_id", "enable"], ] + + +class AgentPackages(models.Model): + pkg_name = models.CharField(_("压缩包名"), max_length=128) + version = models.CharField(_("版本号"), max_length=128) + module = models.CharField(_("所属服务"), max_length=32) + project = models.CharField(_("工程名"), max_length=32, db_index=True) + pkg_size = models.IntegerField(_("包大小")) + pkg_path = models.CharField(_("包路径"), max_length=128) + md5 = models.CharField(_("md5值"), max_length=32) + pkg_mtime = models.CharField(_("包更新时间"), max_length=48) + pkg_ctime = models.CharField(_("包创建时间"), max_length=48) + location = models.CharField(_("安装包链接"), max_length=512) + os = models.CharField( + _("系统类型"), + max_length=32, + choices=constants.PLUGIN_OS_CHOICES, + default=constants.PluginOsType.linux, + db_index=True, + ) + cpu_arch = models.CharField( + _("CPU类型"), max_length=32, choices=constants.CPU_CHOICES, default=constants.CpuType.x86_64, db_index=True + ) + creator = models.CharField(_("操作人"), max_length=45, default="admin") + + is_release_version = models.BooleanField(_("是否已经发布版本"), default=True, db_index=True) + # 由于创建记录时,文件可能仍然在传输过程中,因此需要标志位判断是否已经可用 + is_ready = models.BooleanField(_("插件是否可用"), default=True) + + version_log = models.TextField(_("版本日志"), null=True, blank=True) + version_log_en = models.TextField(_("英文版本日志"), null=True, blank=True) + + class Meta: + verbose_name = _("Agent包(AgentPackages)") + verbose_name_plural = _("Agent包(AgentPackages)") + + +class AgentPackageDesc(models.Model): + """ + Agent包信息表 + """ + + # 安装包名需要全局唯一,防止冲突 + name = models.CharField(_("安装包名"), max_length=32, unique=True, db_index=True) + description = models.TextField(_("安装包描述")) + module = models.CharField(_("所属服务"), max_length=32) + description_en = models.TextField(_("英文插件描述"), null=True, blank=True) + category = models.CharField(_("所属范围"), max_length=32, choices=constants.CATEGORY_CHOICES) + + class Meta: + verbose_name = _("Agent信息(AgentPackageDesc)") + verbose_name_plural = _("Agent信息(AgentPackageDesc)") diff --git a/apps/node_man/serializers/meta.py b/apps/node_man/serializers/meta.py index 8bbcee1074..802f46b072 100644 --- a/apps/node_man/serializers/meta.py +++ b/apps/node_man/serializers/meta.py @@ -25,3 +25,7 @@ class JobSettingSerializer(serializers.Serializer): install_download_limit_speed = serializers.IntegerField(label=_("安装下载限速"), max_value=JOB_MAX_VALUE, min_value=0) parallel_install_number = serializers.IntegerField(label=_("并行安装数"), max_value=JOB_MAX_VALUE, min_value=0) node_man_log_level = serializers.ChoiceField(label=_("节点管理日志级别"), choices=list(NODE_MAN_LOG_LEVEL)) + + +class FilterConditionSerializer(serializers.Serializer): + category = serializers.ChoiceField(label=_("分类"), choices=["agent_pkg_manage", "agent_pkg_quick_search"]) diff --git a/apps/node_man/serializers/packager_manage.py b/apps/node_man/serializers/packager_manage.py new file mode 100644 index 0000000000..40247351b2 --- /dev/null +++ b/apps/node_man/serializers/packager_manage.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +""" +TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-节点管理(BlueKing-BK-NODEMAN) available. +Copyright (C) 2017-2022 THL A29 Limited, a Tencent company. All rights reserved. +Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. +You may obtain a copy of the License at https://opensource.org/licenses/MIT +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +""" +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers + +from apps.exceptions import ValidationError +from apps.node_man.models import AgentPackages + + +class AgentPackageSerializer(serializers.ModelSerializer): + class meta: + model = AgentPackages + fields = "__all__" + + +class TagsSerializer(serializers.Serializer): + id = serializers.CharField() + name = serializers.CharField() + + +class ConditionsSerializer(serializers.Serializer): + key = serializers.ChoiceField(choices=["version", "os_cpu_arch", "tags", "is_ready"]) + values = serializers.ListField() + + +class SearchSerializer(serializers.Serializer): + os_cpu_arch = serializers.CharField(required=False) + tags = serializers.ListField(required=False) + + +class PackageDescSearchSerializer(serializers.Serializer): + os_cpu_arch = serializers.CharField(required=False) + + +class PackageSerializer(serializers.Serializer): + id = serializers.IntegerField() + pkg_name = serializers.CharField() + version = serializers.CharField() + os = serializers.CharField() + cpu_arch = serializers.CharField() + tags = TagsSerializer(many=True) + creator = serializers.CharField() + pkg_ctime = serializers.DateTimeField() + host_count = serializers.IntegerField() + is_ready = serializers.BooleanField() + + +class PackageDescSerializer(serializers.Serializer): + id = serializers.IntegerField() + version = serializers.CharField() + tags = TagsSerializer(many=True) + packages = PackageSerializer(many=True) + is_ready = serializers.BooleanField() + + +class SearchResponseSerializer(serializers.Serializer): + total = serializers.IntegerField() + list = PackageSerializer(many=True) + + +class PackageDescResponseSerialiaer(serializers.Serializer): + total = serializers.IntegerField() + list = PackageDescSerializer(many=True) + + +class OperateSerializer(serializers.Serializer): + is_ready = serializers.BooleanField() + + +# TODO 与plugin相同可抽取公共Serializer +class UploadSerializer(serializers.Serializer): + class PkgFileField(serializers.FileField): + def to_internal_value(self, data): + data = super().to_internal_value(data) + file_name = data.name + if not (file_name.endswith(".tgz") or file_name.endswith(".tar.gz")): + raise ValidationError(_("仅支持'tgz', 'tar.gz'拓展名的文件")) + return data + + module = serializers.ChoiceField(choices=["gse_agent", "gse_proxy"], required=False, default="gse_agent") + package_file = PkgFileField() + + +class UploadResponseSerializer(serializers.Serializer): + id = serializers.IntegerField() + name = serializers.CharField() + pkg_size = serializers.IntegerField() + + +class ParseSerializer(serializers.Serializer): + file_name = serializers.CharField() + + +class ParseResponseSerializer(serializers.Serializer): + class ParsePackageSerializer(serializers.Serializer): + module = serializers.ChoiceField(choices=["agent", "proxy"]) + pkg_name = serializers.CharField() + pkg_abs_path = serializers.CharField() + version = serializers.CharField() + os = serializers.CharField() + cpu_arch = serializers.CharField() + config_templates = serializers.ListField() + + description = serializers.CharField() + packages = ParsePackageSerializer(many=True) + + +class AgentRegisterSerializer(serializers.Serializer): + class RegisterPackageSerializer(serializers.Serializer): + pkg_abs_path = serializers.CharField() + tags = serializers.ListField() + + is_release = serializers.BooleanField() + packages = RegisterPackageSerializer(many=True) + + +class AgentRegisterTaskSerializer(serializers.Serializer): + job_id = serializers.IntegerField() + + +class AgentRegisterTaskResponseSerializer(serializers.Serializer): + is_finish = serializers.BooleanField() + status = serializers.ChoiceField(choices=["SUCCESS", "FAILED", "RUNNING"]) + message = serializers.CharField() diff --git a/apps/node_man/urls.py b/apps/node_man/urls.py index 42144e2220..5605ef5a66 100644 --- a/apps/node_man/urls.py +++ b/apps/node_man/urls.py @@ -40,6 +40,10 @@ ) from apps.node_man.views.healthz import HealthzViewSet from apps.node_man.views.host_v2 import HostV2ViewSet +from apps.node_man.views.package_manage import ( + AgentPackageDescViewSet, + PackageManageViewSet, +) from apps.node_man.views.plugin import GsePluginViewSet from apps.node_man.views.plugin_v2 import PluginV2ViewSet from apps.node_man.views.sync_task import SyncTaskViewSet @@ -67,6 +71,8 @@ router.register(r"v2/plugin", PluginV2ViewSet, basename="plugin_v2") router.register(r"healthz", HealthzViewSet, basename="healthz") router.register(r"sync_task", SyncTaskViewSet, basename="sync_task") +router.register(r"agent/package", PackageManageViewSet, basename="package_manage") +router.register(r"agent/package_desc", AgentPackageDescViewSet, basename="package_desc") biz_dispatcher = DjangoBasicResourceApiDispatcher(iam, settings.BK_IAM_SYSTEM_ID) biz_dispatcher.register("biz", BusinessResourceProvider()) diff --git a/apps/node_man/views/meta.py b/apps/node_man/views/meta.py index af0db561e7..1c444ab4e0 100644 --- a/apps/node_man/views/meta.py +++ b/apps/node_man/views/meta.py @@ -17,7 +17,10 @@ from apps.node_man.exceptions import NotSuperUserError from apps.node_man.handlers.iam import IamHandler from apps.node_man.handlers.meta import MetaHandler -from apps.node_man.serializers.meta import JobSettingSerializer +from apps.node_man.serializers.meta import ( + FilterConditionSerializer, + JobSettingSerializer, +) from apps.utils.local import get_request_username META_VIEW_TAGS = ["meta"] @@ -26,6 +29,7 @@ class MetaViews(APIViewSet): @swagger_auto_schema( operation_summary="获取过滤条件", + query_serializer=FilterConditionSerializer, tags=META_VIEW_TAGS, ) @action(detail=False) diff --git a/apps/node_man/views/package_manage.py b/apps/node_man/views/package_manage.py new file mode 100644 index 0000000000..1278233a1d --- /dev/null +++ b/apps/node_man/views/package_manage.py @@ -0,0 +1,270 @@ +# -*- coding: utf-8 -*- +""" +TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-节点管理(BlueKing-BK-NODEMAN) available. +Copyright (C) 2017-2022 THL A29 Limited, a Tencent company. All rights reserved. +Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. +You may obtain a copy of the License at https://opensource.org/licenses/MIT +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +""" +from django_filters.rest_framework import DjangoFilterBackend +from rest_framework import filters +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework.status import HTTP_200_OK + +from apps.generic import ApiMixinModelViewSet as ModelViewSet +from apps.node_man import models +from apps.node_man.serializers import packager_manage as pkg_manage +from common.utils.drf_utils import swagger_auto_schema + +PACKAGE_MANAGE_VIEW_TAGS = ["PKG_Manager"] +PACKAGE_DES_VIEW_TAGS = ["PKG_Desc"] + + +class PackageManageViewSet(ModelViewSet): + # queryset = models.Packages.objects.filter(module__in=["gse_proxy", "gse_agent"]) + queryset = models.AgentPackages.objects.all() + # model = models.Packages + # http_method_names = ["get", "post"] + # ordering_fields = ("module",) + serializer_class = pkg_manage.PackageSerializer + filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) + + filter_fields = ("module", "creator", "is_ready", "version") + + @swagger_auto_schema( + query_serializer=pkg_manage.SearchSerializer, + responses={200: pkg_manage.SearchResponseSerializer}, + operation_summary="安装包列表", + tags=PACKAGE_MANAGE_VIEW_TAGS, + ) + def list(self, request, *args, **kwargs): + mock_data = { + "total": 2, + "list": [ + { + "id": 1, + "pkg_name": "pkg_name", + "version": "1.1.1", + "os": "Linux", + "cpu_arch": "x86_64", + "tags": [{"id": "stable", "name": "稳定版本"}], + "creator": "string", + "pkg_ctime": "2019-08-24 14:15:22", + "host_count": 100, + "is_ready": True, + }, + { + "id": 2, + "pkg_name": "pkg_name", + "version": "1.1.2", + "os": "Linux", + "os_cpu_arch": "x86_64", + "tags": [{"id": "stable", "name": "稳定版本"}], + "creator": "string", + "pkg_ctime": "2019-08-24 14:15:22", + "host_count": 100, + "is_ready": True, + }, + ], + } + return Response(mock_data) + # return super().list(request, *args, **kwargs) + + @swagger_auto_schema( + operation_summary="操作类动作:启用/停用", + body_in=pkg_manage.OperateSerializer, + responses={200: pkg_manage.SearchResponseSerializer}, + tags=PACKAGE_MANAGE_VIEW_TAGS, + ) + def update(self, request, validated_data, *args, **kwargs): + mock_data = { + "id": 1, + "pkg_name": "pkg_name", + "version": "1.1.1", + "os": "Linux", + "cpu_arch": "x86_64", + "tags": [{"id": "stable", "name": "稳定版本"}], + "creator": "string", + "pkg_ctime": "2019-08-24 14:15:22", + "host_count": 100, + "is_ready": True, + } + + return Response(mock_data) + + @swagger_auto_schema( + operation_summary="删除安装包", + tags=PACKAGE_MANAGE_VIEW_TAGS, + ) + def destroy(self, request, *args, **kwargs): + + return Response() + + @swagger_auto_schema( + operation_summary="获取快速筛选信息", + tags=PACKAGE_MANAGE_VIEW_TAGS, + ) + @action(detail=False, methods=["GET"]) + def quick_search_condition(self, request, *args, **kwargs): + mock_version = [ + {"name": "2.1.8", "id": "2.1.8", "host_count": 10}, + {"name": "2.1.7", "id": "2.1.7", "host_count": 10}, + {"name": "ALL", "id": "all", "host_count": 20}, + ] + + mock_os_cpu_arch = [ + {"name": "Linux_x86_64", "id": "linux_x86_64", "host_count": 10}, + {"name": "Linux_x86", "id": "linux_x86", "host_count": 10}, + {"name": "ALL", "id": "all", "host_count": 20}, + ] + + mock_data = [ + {"name": "操作系统/架构", "id": "os_cpu_arch", "children": mock_os_cpu_arch}, + {"name": "版本号", "id": "version", "children": mock_version}, + ] + + return Response(mock_data) + + @swagger_auto_schema( + operation_summary="Agent包上传", + tags=PACKAGE_MANAGE_VIEW_TAGS, + responses={HTTP_200_OK: pkg_manage.UploadResponseSerializer}, + ) + @action(detail=False, methods=["POST"], serializer_class=pkg_manage.UploadSerializer) + def upload(self, request): + # data = self.validated_data + mock_data = { + "id": 1, + "name": "gse_agent.tgz", + "pkg_size": 100, + } + return Response(mock_data) + + @swagger_auto_schema( + operation_summary="解析Agent包", + tags=PACKAGE_MANAGE_VIEW_TAGS, + responses={HTTP_200_OK: pkg_manage.ParseResponseSerializer}, + ) + @action(detail=False, methods=["POST"], serializer_class=pkg_manage.ParseSerializer) + def parse(self, request): + mock_data = { + "description": "test", + "packages": [ + { + "pkg_abs_path": "xxx/xxxxx", + "pkg_name": "gseagent_2.1.7_linux_x86_64.tgz", + "module": "agent", + "version": "2.1.7", + "config_templates": [], + "os": "x86_64", + }, + { + "pkg_abs_path": "xxx/xxxxx", + "pkg_name": "gseagent_2.1.7_linux_x86.tgz", + "module": "agent", + "version": "2.1.7", + "config_templates": [], + "os": "x86", + }, + ], + } + return Response(mock_data) + + @swagger_auto_schema( + operation_summary="创建Agent包注册任务", + tags=PACKAGE_MANAGE_VIEW_TAGS, + responses={HTTP_200_OK: pkg_manage.AgentRegisterTaskSerializer}, + ) + @action(detail=False, methods=["POST"], serializer_class=pkg_manage.AgentRegisterSerializer) + def create_register_task(self, request): + mock_data = {"job_id": 1} + return Response(mock_data) + + @swagger_auto_schema( + operation_summary="查询Agent包注册任务", + tags=PACKAGE_MANAGE_VIEW_TAGS, + query_in=pkg_manage.AgentRegisterTaskSerializer, + responses={HTTP_200_OK: pkg_manage.AgentRegisterTaskResponseSerializer}, + ) + @action(detail=False, methods=["GET"]) + def query_register_task(self, request, validated_data): + + mock_data = { + "is_finish": True, + "status": "SUCCESS", + "message": "", + } + return Response(mock_data) + + @swagger_auto_schema( + operation_summary="获取Agent包标签", + tags=PACKAGE_MANAGE_VIEW_TAGS, + responses={HTTP_200_OK: pkg_manage.TagsSerializer(many=True)}, + ) + @action(detail=False, methods=["GET"]) + def tags(self, request): + # 由tag handler 实现 + mock_data = [ + { + "id": "builtin", + "name": "内置标签", + "children": [ + {"id": "stable", "name": "稳定版本", "children": []}, + {"id": "latest", "name": "最新版本", "children": []}, + ], + }, + {"id": "custom", "name": "自定义标签", "children": [{"id": "custom", "name": "自定义版本", "children": []}]}, + ] + return Response(mock_data) + + @swagger_auto_schema( + operation_summary="获取Agent包版本", + tags=PACKAGE_MANAGE_VIEW_TAGS, + responses={HTTP_200_OK: pkg_manage.TagsSerializer(many=True)}, + ) + @action(detail=False, methods=["GET"]) + def version(self, request): + pass + + +class AgentPackageDescViewSet(ModelViewSet): + queryset = models.AgentPackageDesc.objects.all() + # model = models.Packages + # http_method_names = ["get", "post"] + # ordering_fields = ("module",) + # serializer_class = pkg_manage.PackageSerializer + # filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) + + # filter_fields = ("module", "creator", "is_ready", "version") + + @swagger_auto_schema( + query_in=pkg_manage.PackageDescSearchSerializer, + responses={200: pkg_manage.PackageDescResponseSerialiaer}, + operation_summary="Agent版本列表", + tags=PACKAGE_DES_VIEW_TAGS, + ) + def list(self, request, *args, **kwargs): + + mock_data = { + "total": 10, + "list": [ + { + "id": 1, + "version": "2.1.2", + "tags": [{"id": "stable", "name": "稳定版本"}], + "is_ready": True, + "description": "我是描述", + "packages": [ + { + "" "pkg_name": "gseagent-2.1.2.tgz", + "tags": [{"id": "stable", "name": "稳定版本"}, {"id": "latest", "name": "最新版本"}], + } + ], + } + ], + } + return Response(mock_data) + # return super().list(request, *args, **kwargs) diff --git a/common/utils/drf_utils.py b/common/utils/drf_utils.py new file mode 100644 index 0000000000..c2d4a540c0 --- /dev/null +++ b/common/utils/drf_utils.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +""" + * TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-蓝鲸 PaaS 平台(BlueKing-PaaS) available. + * Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. + * Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. +""" +import copy +import functools +from collections import namedtuple +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union + +from django.conf import settings +from django.http.response import HttpResponseBase +from django.utils.module_loading import import_string +from rest_framework import status +from rest_framework.exceptions import ValidationError +from rest_framework.fields import empty +from rest_framework.serializers import BaseSerializer +from rest_framework.settings import api_settings +from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList + +if TYPE_CHECKING: + from rest_framework.request import Request + + +def stringify_validation_error(error: ValidationError) -> List[str]: + """Transform DRF's ValidationError into a list of error strings + + >>> stringify_validation_error(ValidationError({'foo': ErrorDetail('err')})) + ['foo: err'] + """ + results: List[str] = [] + + def traverse(err_detail: Any, keys: List[str]): + """Traverse error data to collect all error messages""" + + # Dig deeper when structure is list or dict + if isinstance(err_detail, (ReturnList, list, tuple)): + for err in err_detail: + traverse(err, keys) + elif isinstance(err_detail, (ReturnDict, dict)): + for key, err in err_detail.items(): + # Make a copy of keys so the inner loop won't affect outer scope + _keys = copy.copy(keys) + if key != api_settings.NON_FIELD_ERRORS_KEY: + _keys.append(str(key)) + traverse(err, _keys) + else: + if not keys: + results.append(str(err_detail)) + else: + results.append("{}: {}".format(".".join(keys), str(err_detail))) + + traverse(error.detail, []) + return sorted(results) + + +############# +# drf crown # +############# +class WearOptions: + is_unittest = False + skip_swagger_schema = False + + +try: + from drf_yasg.utils import swagger_auto_schema as drf_swagger_auto_schema + +except ImportError: + WearOptions.skip_swagger_schema = True + + +ResponseParams = namedtuple("ResponseParams", "data,params") + + +_DEFAULT_SETTINGS_PREFIX = "DRF_CROWN_" + + +def enable_unittest(): + """Call me when you running testing""" + WearOptions.is_unittest = True + + +@dataclass +class Config: + """Config for Injector, control the process of injecting""" + + return_validated_data: bool = True + remain_request: bool = False + # sometime return raw data instead of serializer + skip_out_cls: bool = False + default_return_status: status = status.HTTP_200_OK + + +@dataclass +class ViewCrown: + """A injector for injecting serializer as dependency""" + + body_in: Optional[Union[Type[BaseSerializer], BaseSerializer]] + query_in: Optional[Union[Type[BaseSerializer], BaseSerializer]] + out: Union[Type[BaseSerializer], BaseSerializer] + config_params: Optional[dict] = field(default_factory=dict) + valid_params: dict = field(default_factory=dict) + + def __post_init__(self): + if self.query_in and self.body_in: + raise ValueError("there should be only one param between in_body & in_query") + + self.valid_params = self.valid_params or {"raise_exception": True} + + # Priority decreases + # 1. config as parameter from decorator + # 2. config from django.settings + # 3. config from Config class(above) + _config = getattr(settings, _DEFAULT_SETTINGS_PREFIX + "DEFAULT_CONFIG", {}).copy() + _config.update(self.config_params or {}) + self.config = Config(**_config) + + # remain an entrance for custom response class + try: + self.resp_cls = import_string(getattr(settings, _DEFAULT_SETTINGS_PREFIX + "RESP_CLS")) + except AttributeError: + self.resp_cls = import_string("rest_framework.response.Response") + + def get_in_serializer_instance(self, request: Optional["Request"] = None) -> "BaseSerializer": + if not self.body_in and not self.query_in: + raise ValueError("should given at least one serializer input") + + _data = empty + if self.body_in: + _in = self.body_in + + if request is not None: + _data = getattr(request, "data") + else: + _in = self.query_in + + if request is not None: + _data = getattr(request, "query_params") + + if isinstance(_in, BaseSerializer): + # 由于传入的是全局对象,会残留上一次请求的结果 + # 这里需要手动清理一下 + if hasattr(_in, "_validated_data"): + delattr(_in, "_validated_data") + + _in.initial_data = _data + slz_obj = _in + elif issubclass(_in, BaseSerializer): + slz_obj = _in(data=_data) + else: + raise ValueError("unknown serializer input") + + return slz_obj + + def get_serializer_instance_by_request(self, request: "Request") -> "BaseSerializer": + """Get in serializer instance""" + slz_obj = self.get_in_serializer_instance(request) + slz_obj.is_valid(**self.valid_params) + return slz_obj + + def get_validated_data(self, request: "Request") -> dict: + """Get validated data via in_serializer""" + return self.get_serializer_instance_by_request(request).validated_data + + def get_in_params(self, request: "Request") -> dict: + """Get extra params before view logic""" + if WearOptions.is_unittest: + return {} + + if self.config.return_validated_data: + return {"validated_data": self.get_validated_data(request)} + else: + return {"serializer_instance": self.get_serializer_instance_by_request(request)} + + def get_response(self, data, out_params: dict) -> Any: + """Get Response data""" + if WearOptions.is_unittest: + return data + + if self.config.skip_out_cls: + return data + + if isinstance(data, (self.resp_cls, HttpResponseBase)): + return data + + if isinstance(self.out, BaseSerializer): + # 由于传入的是全局对象,会残留上一次请求的结果 + # 这里需要手动清理一下 + if hasattr(self.out, "_data"): + delattr(self.out, "_data") + + self.out.instance = data + _data = self.out.data + elif issubclass(self.out, BaseSerializer): + _data = self.out(data, **out_params).data + else: + raise ValueError("unknown serializer output") + + return self.resp_cls(data=_data, status=self.config.default_return_status) + + +def generate_swagger_params(crown: ViewCrown, swagger_params: dict) -> dict: + """ + assemble params for swagger_auto_schema by crown + """ + default_params = {} + if crown.body_in: + default_params = {"request_body": crown.get_in_serializer_instance()} + elif crown.query_in: + default_params = {"query_serializer": crown.get_in_serializer_instance()} + + if crown.out: + default_params.update({"responses": {crown.config.default_return_status: crown.out}}) + + default_params.update(swagger_params or {}) + return default_params + + +def swagger_auto_schema( + body_in: Optional[Union[Type[BaseSerializer], BaseSerializer]] = None, + query_in: Optional[Union[Type[BaseSerializer], BaseSerializer]] = None, + out: Optional[Union[Type[BaseSerializer], BaseSerializer]] = None, + config: Optional[dict] = None, + **swagger_kwargs +): + """ + Sugar for simpling drf serializer specification + :param body_in: input serializer (request body) + :param query_in: input serializer (query) + :param out: output serializer + :param config: initial info of Config + :param swagger_kwargs: pass to swagger_auto_schema of drf-yasg + """ + + def decorator_serializer_inject(func): + crown = ViewCrown(body_in, query_in, out, config) + + if not WearOptions.skip_swagger_schema: + func = drf_swagger_auto_schema(**generate_swagger_params(crown, swagger_kwargs))(func) + + @functools.wraps(func) + def decorated(*args, **kwargs): + new_args = list(args) + in_content: Dict[str, Any] = {} + if body_in or query_in: + in_content.update(**crown.get_in_params(new_args[1])) + + if not crown.config.remain_request: + del new_args[1] + + original_data = func(*new_args, **kwargs, **in_content) + if not out: + return original_data + + # support runtime serializer params, like "context" + params = {} + if isinstance(original_data, ResponseParams): + params = original_data.params + original_data = original_data.data + + return crown.get_response(original_data, params) + + return decorated + + return decorator_serializer_inject diff --git a/config/default.py b/config/default.py index 6159f1382f..c1f8791c60 100644 --- a/config/default.py +++ b/config/default.py @@ -816,6 +816,8 @@ def get_standard_redis_mode(cls, config_redis_mode: str, default: Optional[str] if env.BKAPP_MONITOR_REPORTER_ENABLE: monitor_report_config() +DRF_CROWN_DEFAULT_CONFIG = {"remain_request": True} + # remove disabled apps if locals().get("DISABLED_APPS"): INSTALLED_APPS = locals().get("INSTALLED_APPS", [])