diff --git a/gcloud/apigw/views/copy_template_across_project.py b/gcloud/apigw/views/copy_template_across_project.py index 40366ae9f..ad8f411e9 100644 --- a/gcloud/apigw/views/copy_template_across_project.py +++ b/gcloud/apigw/views/copy_template_across_project.py @@ -10,8 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ - -import ujson as json +import logging +import json from apigw_manager.apigw.decorators import apigw_require from blueapps.account.decorators import login_exempt from django.views.decorators.csrf import csrf_exempt @@ -25,6 +25,7 @@ ) from gcloud.apigw.views.utils import logger +from gcloud.contrib.template_market.models import TemplateSharedRecord from gcloud.tasktmpl3.models import TaskTemplate from gcloud.template_base.utils import format_import_result_to_response_data from gcloud.utils.decorators import request_validate @@ -51,6 +52,15 @@ def copy_template_across_project(request, project_id): new_project_id = params_data["new_project_id"] template_id = params_data["template_id"] + record = TemplateSharedRecord.objects.filter(project_id=request.project.id, template_id=template_id).first() + if record is None: + logging.warning("The specified template could not be found") + return { + "result": False, + "message": "The specified template could not be found", + "code": err_code.REQUEST_FORBIDDEN_INVALID.code, + } + try: export_data = TaskTemplate.objects.export_templates([template_id], is_full=False, project_id=request.project.id) import_result = TaskTemplate.objects.import_templates( diff --git a/gcloud/contrib/template_market/admin.py b/gcloud/contrib/template_market/admin.py index 6b5429b3d..8c415c540 100644 --- a/gcloud/contrib/template_market/admin.py +++ b/gcloud/contrib/template_market/admin.py @@ -18,6 +18,6 @@ @admin.register(models.TemplateSharedRecord) class TemplateSharedRecordAdmin(admin.ModelAdmin): - list_display = ["market_record_id", "project_id", "templates", "creator", "create_at", "update_at", "extra_info"] - list_filter = ["project_id", "creator", "create_at", "update_at"] - search_fields = ["market_record_id", "project_id", "creator"] + list_display = ["project_id", "template_id", "creator", "create_at", "update_at", "extra_info"] + list_filter = ["project_id", "template_id", "creator", "create_at", "update_at"] + search_fields = ["project_id", "creator"] diff --git a/gcloud/contrib/template_market/migrations/0001_initial.py b/gcloud/contrib/template_market/migrations/0001_initial.py index b4b5c07fd..1a2ab4fc8 100644 --- a/gcloud/contrib/template_market/migrations/0001_initial.py +++ b/gcloud/contrib/template_market/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 3.2.15 on 2024-12-12 09:15 +# Generated by Django 3.2.15 on 2024-12-12 12:23 from django.db import migrations, models @@ -14,9 +14,8 @@ class Migration(migrations.Migration): name="TemplateSharedRecord", fields=[ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ("market_record_id", models.CharField(db_index=True, max_length=32, verbose_name="模板市场记录 ID")), - ("project_id", models.IntegerField(help_text="项目 ID", verbose_name="项目 ID")), - ("templates", models.JSONField(help_text="模板 ID 列表", verbose_name="模板 ID 列表")), + ("project_id", models.IntegerField(default=-1, help_text="项目 ID", verbose_name="项目 ID")), + ("template_id", models.JSONField(db_index=True, help_text="模板 ID 列表", verbose_name="模板 ID 列表")), ("creator", models.CharField(default="", max_length=32, verbose_name="创建者")), ("create_at", models.DateTimeField(auto_now_add=True, verbose_name="创建时间")), ("update_at", models.DateTimeField(auto_now=True, verbose_name="更新时间")), diff --git a/gcloud/contrib/template_market/models.py b/gcloud/contrib/template_market/models.py index 31b4e6f84..1adc82ce1 100644 --- a/gcloud/contrib/template_market/models.py +++ b/gcloud/contrib/template_market/models.py @@ -16,9 +16,8 @@ class TemplateSharedRecord(models.Model): - market_record_id = models.CharField(_("共享实例 ID"), max_length=32, help_text="共享实例 ID", db_index=True) project_id = models.IntegerField(_("项目 ID"), default=-1, help_text="项目 ID") - templates = models.JSONField(_("模板 ID 列表"), help_text="模板 ID 列表") + template_id = models.JSONField(_("模板 ID 列表"), help_text="模板 ID 列表", db_index=True) creator = models.CharField(_("创建者"), max_length=32, default="") create_at = models.DateTimeField(_("创建时间"), auto_now_add=True) update_at = models.DateTimeField(verbose_name=_("更新时间"), auto_now=True) diff --git a/gcloud/contrib/template_market/permission.py b/gcloud/contrib/template_market/permission.py index e14739196..8ac657d34 100644 --- a/gcloud/contrib/template_market/permission.py +++ b/gcloud/contrib/template_market/permission.py @@ -29,7 +29,7 @@ def has_permission(self, request, view): template_id = int(serializer.validated_data["template_id"]) project_id = int(serializer.validated_data["project_id"]) - record = TemplateSharedRecord.objects.filter(project_id=project_id, templates__contains=[template_id]).first() + record = TemplateSharedRecord.objects.filter(project_id=project_id, template_id=template_id).first() if record is None: logging.warning("The specified template could not be found") return False @@ -37,7 +37,7 @@ def has_permission(self, request, view): return True -class SharedProcessTemplatePermission(permissions.BasePermission): +class SharedTemplateRecordPermission(permissions.BasePermission): def has_permission(self, request, view): if not settings.ENABLE_TEMPLATE_MARKET: return False @@ -47,13 +47,13 @@ def has_permission(self, request, view): serializer = view.serializer_class(data=request.data) serializer.is_valid(raise_exception=True) - template_id_list = serializer.validated_data["templates"] + template_id_list = serializer.validated_data["template_ids"] try: iam_multi_resource_auth_or_raise( username, IAMMeta.FLOW_EDIT_ACTION, template_id_list, "resources_list_for_flows" ) except MultiAuthFailedException: - logging.exception("You do not have permission to perform this operation") + logging.exception("Template permission verification failed") return False return True diff --git a/gcloud/contrib/template_market/serializers.py b/gcloud/contrib/template_market/serializers.py index 82248119f..ebcab5571 100644 --- a/gcloud/contrib/template_market/serializers.py +++ b/gcloud/contrib/template_market/serializers.py @@ -21,6 +21,7 @@ class TemplatePreviewSerializer(serializers.Serializer): pipeline_tree = serializers.SerializerMethodField(read_only=True, help_text="pipeline_tree") def get_pipeline_tree(self, obj): + # todo 节点信息防护 return json.dumps(obj.pipeline_tree) @@ -29,9 +30,9 @@ class TemplateProjectBaseSerializer(serializers.Serializer): project_id = serializers.CharField(required=True, help_text="项目id") -class TemplateSharedRecordSerializer(serializers.ModelSerializer): +class TemplateSharedRecordSerializer(serializers.Serializer): project_id = serializers.CharField(required=True, max_length=32, help_text="项目id") - templates = serializers.ListField(required=True, help_text="关联的模板列表") + template_ids = serializers.ListField(required=True, help_text="关联的模板列表") creator = serializers.CharField(required=True, max_length=32, help_text="创建者") extra_info = serializers.JSONField(required=False, allow_null=True, help_text="额外信息") name = serializers.CharField(required=True, help_text="共享名称") @@ -42,30 +43,36 @@ class TemplateSharedRecordSerializer(serializers.ModelSerializer): labels = serializers.ListField(child=serializers.IntegerField(), required=True, help_text="共享标签列表") usage_content = serializers.JSONField(required=True, help_text="使用说明") - class Meta: - model = TemplateSharedRecord - fields = [ - "project_id", - "templates", - "creator", - "extra_info", - "name", - "code", - "category", - "risk_level", - "usage_id", - "labels", - "usage_content", - ] + def create_shared_record(self, project_id, market_record_id, template_ids, creator): + for template_id in template_ids: + existing_record, created = TemplateSharedRecord.objects.get_or_create( + project_id=project_id, + template_id=template_id, + defaults={"creator": creator, "extra_info": {"market_record_ids": [market_record_id]}}, + ) + if not created: + market_ids = existing_record.extra_info.setdefault("market_record_ids", []) + if market_record_id not in market_ids: + market_ids.append(market_record_id) + existing_record.save() - def create(self, validated_data): - fields_to_remove = ["name", "code", "category", "risk_level", "usage_id", "labels", "usage_content"] - for field in fields_to_remove: - validated_data.pop(field, None) - return super().create(validated_data) + def update_shared_record(self, new_template_ids, market_record_id, project_id, creator): + market_record_id = int(market_record_id) - def update(self, instance, validated_data): - fields_to_remove = ["name", "code", "category", "risk_level", "usage_id", "labels", "usage_content"] - for field in fields_to_remove: - validated_data.pop(field, None) - return super().update(instance, validated_data) + existing_records = TemplateSharedRecord.objects.filter( + project_id=project_id, extra_info__market_record_ids__contains=[market_record_id] + ) + existing_template_ids = set(existing_records.values_list("template_id", flat=True)) + templates_to_remove = existing_template_ids - set(new_template_ids) + + for template_id in templates_to_remove: + current_template_record = existing_records.get(template_id=template_id) + current_market_ids = current_template_record.extra_info.get("market_record_ids", []) + if market_record_id in current_market_ids: + current_market_ids.remove(market_record_id) + current_template_record.extra_info["market_record_ids"] = current_market_ids + current_template_record.save() + + templates_to_add = set(new_template_ids) - existing_template_ids + if templates_to_add: + self.create_shared_record(project_id, market_record_id, list(templates_to_add), creator) diff --git a/gcloud/contrib/template_market/urls.py b/gcloud/contrib/template_market/urls.py index 4b9ddd09b..250be933f 100644 --- a/gcloud/contrib/template_market/urls.py +++ b/gcloud/contrib/template_market/urls.py @@ -13,12 +13,12 @@ from django.conf.urls import include, url from rest_framework.routers import DefaultRouter -from gcloud.contrib.template_market.viewsets import TemplatePreviewViewSet, SharedProcessTemplateViewSet +from gcloud.contrib.template_market.viewsets import TemplatePreviewViewSet, SharedTemplateRecordsViewSet template_market_router = DefaultRouter() template_market_router.register(r"template_preview", TemplatePreviewViewSet) -template_market_router.register(r"shared_process_templates", SharedProcessTemplateViewSet) +template_market_router.register(r"shared_templates_records", SharedTemplateRecordsViewSet) urlpatterns = [ url(r"^api/", include(template_market_router.urls)), diff --git a/gcloud/contrib/template_market/viewsets.py b/gcloud/contrib/template_market/viewsets.py index ba670d2cc..4bbf52132 100644 --- a/gcloud/contrib/template_market/viewsets.py +++ b/gcloud/contrib/template_market/viewsets.py @@ -18,6 +18,7 @@ from rest_framework import permissions from gcloud import err_code +from gcloud.conf import settings from drf_yasg.utils import swagger_auto_schema from gcloud.contrib.template_market.serializers import ( TemplateSharedRecordSerializer, @@ -27,7 +28,7 @@ from gcloud.contrib.template_market.models import TemplateSharedRecord from gcloud.taskflow3.models import TaskTemplate from gcloud.contrib.template_market.clients import MarketAPIClient -from gcloud.contrib.template_market.permission import TemplatePreviewPermission, SharedProcessTemplatePermission +from gcloud.contrib.template_market.permission import TemplatePreviewPermission, SharedTemplateRecordPermission class TemplatePreviewViewSet(viewsets.ViewSet): @@ -48,18 +49,16 @@ def retrieve(self, request, *args, **kwargs): return Response({"result": True, "data": serializer.data, "code": err_code.SUCCESS.code}) -class SharedProcessTemplateViewSet(viewsets.ViewSet): +class SharedTemplateRecordsViewSet(viewsets.ViewSet): queryset = TemplateSharedRecord.objects.all() serializer_class = TemplateSharedRecordSerializer - permission_classes = [permissions.IsAuthenticated, SharedProcessTemplatePermission] + permission_classes = [permissions.IsAuthenticated, SharedTemplateRecordPermission] - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.market_client = MarketAPIClient() + market_client = MarketAPIClient() def _build_template_data(self, serializer, **kwargs): - templates = TaskTemplate.objects.filter(id__in=serializer.validated_data["templates"], is_deleted=False) - template_id_list = [{"id": template.id, "name": template.name} for template in templates] + templates = TaskTemplate.objects.filter(id__in=serializer.validated_data["template_ids"], is_deleted=False) + template_info = [{"id": template.id, "name": template.name} for template in templates] data = { "name": serializer.validated_data["name"], "code": serializer.validated_data["code"], @@ -67,14 +66,14 @@ def _build_template_data(self, serializer, **kwargs): "risk_level": serializer.validated_data["risk_level"], "usage_id": serializer.validated_data["usage_id"], "labels": serializer.validated_data["labels"], - "source_system": "bk_sops", + "source_system": settings.APP_CODE, "project_code": serializer.validated_data["project_id"], - "templates": json.dumps(template_id_list), + "templates": json.dumps(template_info), "usage_content": serializer.validated_data["usage_content"], } - scene_shared_id = kwargs.get("scene_shared_id") - if scene_shared_id: - data["id"] = scene_shared_id + market_record_id = kwargs.get("market_record_id") + if market_record_id: + data["id"] = market_record_id return data def list(self, request, *args, **kwargs): @@ -106,14 +105,17 @@ def create(self, request, *args, **kwargs): "code": err_code.OPERATION_FAIL.code, } ) - serializer.validated_data["market_record_id"] = response_data["data"]["id"] - serializer.create(serializer.validated_data) + serializer.create_shared_record( + project_id=int(serializer.validated_data["project_id"]), + template_ids=serializer.validated_data["template_ids"], + market_record_id=response_data["data"]["id"], + creator=serializer.validated_data["creator"], + ) return Response({"result": True, "data": response_data, "code": err_code.SUCCESS.code}) @swagger_auto_schema(request_body=TemplateSharedRecordSerializer) def partial_update(self, request, *args, **kwargs): market_record_id = kwargs["pk"] - instance = self.queryset.get(market_record_id=market_record_id) serializer = self.serializer_class(data=request.data, partial=True) serializer.is_valid(raise_exception=True) @@ -127,5 +129,10 @@ def partial_update(self, request, *args, **kwargs): "code": err_code.OPERATION_FAIL.code, } ) - serializer.update(instance=instance, validated_data=serializer.validated_data) + serializer.update_shared_record( + project_id=int(serializer.validated_data["project_id"]), + new_template_ids=serializer.validated_data["template_ids"], + market_record_id=market_record_id, + creator=serializer.validated_data["creator"], + ) return Response({"result": True, "data": response_data, "code": err_code.SUCCESS.code})