From b267fc2a375912eb777c76eb062d2183297eb546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E4=BD=A9=E7=8F=8A=5Bshiisa=5D?= Date: Thu, 7 Sep 2023 20:19:47 +0800 Subject: [PATCH] Add unit tests --- .../data_source_organization/serializers.py | 2 +- .../test_data_source_organization.py | 300 ++++++++++++++---- src/bk-user/tests/biz/__init__.py | 10 + src/bk-user/tests/biz/conftest.py | 28 ++ .../biz/test_data_source_organization.py | 225 +++++++++++++ .../test_utils/data_source_organization.py | 60 +++- 6 files changed, 537 insertions(+), 88 deletions(-) create mode 100644 src/bk-user/tests/biz/__init__.py create mode 100644 src/bk-user/tests/biz/conftest.py create mode 100644 src/bk-user/tests/biz/test_data_source_organization.py diff --git a/src/bk-user/bkuser/apis/web/data_source_organization/serializers.py b/src/bk-user/bkuser/apis/web/data_source_organization/serializers.py index 5c6e184f8..17192f434 100644 --- a/src/bk-user/bkuser/apis/web/data_source_organization/serializers.py +++ b/src/bk-user/bkuser/apis/web/data_source_organization/serializers.py @@ -152,7 +152,7 @@ def get_leaders(self, obj: DataSourceUser) -> List[Dict]: class UserUpdateInputSLZ(serializers.Serializer): full_name = serializers.CharField(help_text="姓名") - email = serializers.CharField(help_text="邮箱") + email = serializers.EmailField(help_text="邮箱") phone_country_code = serializers.CharField(help_text="手机国际区号") phone = serializers.CharField(help_text="手机号") logo = serializers.CharField(help_text="用户 Logo", allow_blank=True, required=False, default="") diff --git a/src/bk-user/tests/apis/web/data_source_organization/test_data_source_organization.py b/src/bk-user/tests/apis/web/data_source_organization/test_data_source_organization.py index ff12bb348..bfd4e89db 100644 --- a/src/bk-user/tests/apis/web/data_source_organization/test_data_source_organization.py +++ b/src/bk-user/tests/apis/web/data_source_organization/test_data_source_organization.py @@ -8,14 +8,17 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +import random +from typing import Any, Dict, List + import pytest from bkuser.apps.data_source.models import DataSourceDepartment, DataSourceUser from django.urls import reverse from rest_framework import status from tests.test_utils.data_source_organization import ( - create_data_source_department, - create_data_source_user, + create_data_source_departments, + create_data_source_users, generate_data_source_username, ) from tests.test_utils.helpers import generate_random_string @@ -24,116 +27,101 @@ @pytest.fixture() -def data_source_user(local_data_source) -> DataSourceUser: - return create_data_source_user(data_source_id=local_data_source.id) +def data_source_user_base_info() -> Dict[str, Any]: + return { + "username": generate_data_source_username(), + "full_name": generate_random_string(), + "email": "test@example.com", + "phone": "13000000000", + } + + +@pytest.fixture() +def data_source_departments(local_data_source) -> List[DataSourceDepartment]: + return create_data_source_departments(data_source=local_data_source) @pytest.fixture() -def data_source_department(local_data_source) -> DataSourceDepartment: - return create_data_source_department(data_source_id=local_data_source.id) +def data_source_users(local_data_source, data_source_departments) -> List[DataSourceUser]: + return create_data_source_users(data_source=local_data_source, departments=data_source_departments) class TestDataSourceUserCreateApi: - def test_create_local_data_source_user(self, api_client, local_data_source): + def test_create_local_data_source_user(self, api_client, local_data_source, data_source_user_base_info): resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={ - "username": generate_data_source_username(), - "full_name": generate_random_string(), - "email": "test@example.com", - "phone": "13000000000", - }, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_200_OK - def test_create_without_username(self, api_client, local_data_source): + def test_create_without_username(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info.pop("username") resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={"full_name": generate_data_source_username(), "email": "test@example.com", "phone": "13000000000"}, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "参数校验不通过: username: 该字段是必填项。" in resp.data["message"] - def test_create_without_fullname(self, api_client, local_data_source): + def test_create_without_fullname(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info.pop("full_name") resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={"username": generate_data_source_username(), "email": "test@example.com", "phone": "13000000000"}, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "参数校验不通过: full_name: 该字段是必填项。" in resp.data["message"] - def test_create_without_email(self, api_client, local_data_source): + def test_create_without_email(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info.pop("email") resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={ - "full_name": generate_random_string(), - "username": generate_data_source_username(), - "phone": "13000000000", - }, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "参数校验不通过: email: 该字段是必填项。" in resp.data["message"] - def test_create_without_phone(self, api_client, local_data_source): + def test_create_without_phone(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info.pop("phone") resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={ - "username": generate_data_source_username(), - "full_name": generate_random_string(), - "email": "test@example.com", - }, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "参数校验不通过: phone: 该字段是必填项。" in resp.data["message"] - def test_create_with_invalid_username(self, api_client, local_data_source): + def test_create_with_invalid_username(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info["username"] = ".error_username" resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={ - "username": ".error_username", - "full_name": generate_random_string(), - "email": "test@example.com", - "phone": "13000000000", - }, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "不符合 用户名 的命名规范" in resp.data["message"] - def test_create_with_invalid_email(self, api_client, local_data_source): + def test_create_with_invalid_email(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info["email"] = "test.com" resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={ - "username": generate_data_source_username(), - "full_name": generate_random_string(), - "email": "test.com", - "phone": "13000000000", - }, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "参数校验不通过: email: 请输入合法的邮件地址。" in resp.data["message"] - def test_create_with_incorrect_length_phone(self, api_client, local_data_source): + def test_create_with_incorrect_length_phone(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info["phone"] = "130" resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={ - "username": generate_data_source_username(), - "full_name": generate_random_string(), - "email": "test@example.com", - "phone": "130", - }, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "手机号码解析异常: 手机号 130 长度异常" in resp.data["message"] - def test_create_with_invalid_phone(self, api_client, local_data_source): + def test_create_with_invalid_phone(self, api_client, local_data_source, data_source_user_base_info): + data_source_user_base_info["phone"] = "aaaaaaaaaaa" resp = api_client.post( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), - data={ - "username": generate_data_source_username(), - "full_name": generate_random_string(), - "email": "test@example.com", - "phone": "aaaaaaaaaaa", - }, + data=data_source_user_base_info, ) assert resp.status_code == status.HTTP_400_BAD_REQUEST assert "手机号码解析异常: 手机号 aaaaaaaaaaa 解析异常" in resp.data["message"] @@ -143,7 +131,7 @@ class TestDataSourceUserListApi: def test_list(self, api_client, local_data_source): resp = api_client.get(reverse("data_source_user.list_create", kwargs={"id": local_data_source.id})) - assert DataSourceUser.objects.filter(data_source_id=local_data_source.id).count() == resp.data["count"] + assert DataSourceUser.objects.filter(data_source=local_data_source).count() == resp.data["count"] for item in resp.data["results"]: data_source_user = DataSourceUser.objects.filter(id=item["id"]).first() @@ -154,7 +142,8 @@ def test_list(self, api_client, local_data_source): assert data_source_user.email == item["email"] assert data_source_user.departments == item["departments"] - def test_list_with_username(self, api_client, local_data_source, data_source_user): + def test_list_with_username(self, api_client, local_data_source, data_source_users): + data_source_user = random.choice(data_source_users) resp = api_client.get( reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), data={"username": data_source_user.username}, @@ -162,7 +151,6 @@ def test_list_with_username(self, api_client, local_data_source, data_source_use assert resp.data["count"] == 1 user = resp.data["results"][0] - assert user["username"] == data_source_user.username assert user["full_name"] == data_source_user.full_name assert user["phone"] == data_source_user.phone @@ -173,7 +161,7 @@ class TestDataSourceLeadersListApi: def test_list(self, api_client, local_data_source): resp = api_client.get(reverse("data_source_leader.list", kwargs={"id": local_data_source.id})) - assert DataSourceUser.objects.filter(data_source_id=local_data_source.id).count() == resp.data["count"] + assert DataSourceUser.objects.filter(data_source=local_data_source).count() == resp.data["count"] for item in resp.data["results"]: data_source_user = DataSourceUser.objects.filter(id=item["id"]).first() @@ -181,23 +169,22 @@ def test_list(self, api_client, local_data_source): assert data_source_user.id == item["id"] assert data_source_user.username == item["username"] - def test_list_with_keyword(self, api_client, local_data_source, data_source_user): + def test_list_with_keyword(self, api_client, local_data_source, data_source_users): + data_source_user = random.choice(data_source_users) resp = api_client.get( - reverse("data_source_user.list_create", kwargs={"id": local_data_source.id}), + reverse("data_source_leader.list", kwargs={"id": local_data_source.id}), data={"keyword": data_source_user.username}, ) - assert resp.data["count"] == 1 user = resp.data["results"][0] - - assert user["username"] == data_source_user.username + assert user["username"] in [data_source_user.username, data_source_user.full_name] class TestDataSourceDepartmentsListApi: def test_list(self, api_client, local_data_source): resp = api_client.get(reverse("data_source_department.list", kwargs={"id": local_data_source.id})) - assert DataSourceDepartment.objects.filter(data_source_id=local_data_source.id).count() == resp.data["count"] + assert DataSourceDepartment.objects.filter(data_source=local_data_source).count() == resp.data["count"] for item in resp.data["results"]: data_source_department = DataSourceDepartment.objects.filter(id=item["id"]).first() @@ -205,13 +192,186 @@ def test_list(self, api_client, local_data_source): assert data_source_department.id == item["id"] assert data_source_department.name == item["name"] - def test_list_with_keyword(self, api_client, local_data_source, data_source_department): + def test_list_with_keyword(self, api_client, local_data_source, data_source_departments): + data_source_department = random.choice(data_source_departments) resp = api_client.get( reverse("data_source_department.list", kwargs={"id": local_data_source.id}), - data={"keyword": data_source_department.name}, + data={"name": data_source_department.name}, ) assert resp.data["count"] == 1 department = resp.data["results"][0] - assert department["name"] == data_source_department.name + + +class TestDataSourceUserUpdateApi: + def test_update_local_data_source_user(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": "test@update.com", + "phone_country_code": data_source_user.phone_country_code, + "phone": data_source_user.phone, + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_204_NO_CONTENT + resp = api_client.get(reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id})) + assert resp.data["email"] == "test@update.com" + + def test_update_local_data_source_user_without_full_name(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "email": data_source_user.email, + "phone_country_code": data_source_user.phone_country_code, + "phone": data_source_user.phone, + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "参数校验不通过: full_name: 该字段是必填项。" in resp.data["message"] + + def test_update_local_data_source_user_without_email(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "phone_country_code": data_source_user.phone_country_code, + "phone": data_source_user.phone, + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "参数校验不通过: email: 该字段是必填项。" in resp.data["message"] + + def test_update_local_data_source_user_without_phone_country_code(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": data_source_user.email, + "phone": data_source_user.phone, + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "参数校验不通过: phone_country_code: 该字段是必填项。" in resp.data["message"] + + def test_update_local_data_source_user_without_phone(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": data_source_user.email, + "phone_country_code": data_source_user.phone_country_code, + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "参数校验不通过: phone: 该字段是必填项。" in resp.data["message"] + + def test_update_local_data_source_user_without_leader_ids(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": data_source_user.email, + "phone_country_code": data_source_user.phone_country_code, + "phone": data_source_user.phone, + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "参数校验不通过: leader_ids: 该字段是必填项。" in resp.data["message"] + + def test_update_local_data_source_user_without_department_ids(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": data_source_user.email, + "phone_country_code": data_source_user.phone_country_code, + "phone": data_source_user.phone, + "leader_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "参数校验不通过: department_ids: 该字段是必填项。" in resp.data["message"] + + def test_update_local_data_source_user_with_invalid_email(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": "test.com", + "phone_country_code": data_source_user.phone_country_code, + "phone": data_source_user.phone, + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "参数校验不通过: email: 请输入合法的邮件地址。" in resp.data["message"] + + def test_update_local_data_source_user_with_invalid_phone(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": data_source_user.email, + "phone_country_code": data_source_user.phone_country_code, + "phone": "aaaaaaaaaaa", + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "手机号码解析异常: 手机号 aaaaaaaaaaa 解析异常" in resp.data["message"] + + def test_update_local_data_source_user_with_incorrect_length_phone(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + + resp = api_client.put( + reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id}), + data={ + "full_name": data_source_user.full_name, + "email": data_source_user.email, + "phone_country_code": data_source_user.phone_country_code, + "phone": "130", + "leader_ids": [], + "department_ids": [], + }, + ) + assert resp.status_code == status.HTTP_400_BAD_REQUEST + assert "手机号码解析异常: 手机号 130 长度异常" in resp.data["message"] + + +class TestDataSourceUserRetrieveApi: + def test_retrieve(self, api_client, data_source_users): + data_source_user = random.choice(data_source_users) + resp = api_client.get(reverse("data_source_user.retrieve_update", kwargs={"id": data_source_user.id})) + + assert resp.data["username"] == data_source_user.username + assert resp.data["full_name"] == data_source_user.full_name + assert resp.data["email"] == data_source_user.email + assert resp.data["phone_country_code"] == data_source_user.phone_country_code + assert resp.data["phone"] == data_source_user.phone + assert resp.data["logo"] == data_source_user.logo diff --git a/src/bk-user/tests/biz/__init__.py b/src/bk-user/tests/biz/__init__.py new file mode 100644 index 000000000..1060b7bf4 --- /dev/null +++ b/src/bk-user/tests/biz/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +""" +TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-用户管理(Bk-User) available. +Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. +Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. +You may obtain a copy of the License at http://opensource.org/licenses/MIT +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +""" diff --git a/src/bk-user/tests/biz/conftest.py b/src/bk-user/tests/biz/conftest.py new file mode 100644 index 000000000..a8168b31b --- /dev/null +++ b/src/bk-user/tests/biz/conftest.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-用户管理(Bk-User) available. +Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. +Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. +You may obtain a copy of the License at http://opensource.org/licenses/MIT +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +""" +import pytest +from bkuser.apps.data_source.constants import DataSourcePluginEnum +from bkuser.apps.data_source.models import DataSource + +from tests.test_utils.helpers import generate_random_string + + +@pytest.fixture() +def local_data_source(default_tenant: str) -> DataSource: + """ + 生成测试数据源 + """ + + local_data_source, _ = DataSource.objects.get_or_create( + name=generate_random_string(), + defaults={"owner_tenant_id": default_tenant, "plugin_id": DataSourcePluginEnum.LOCAL.value}, + ) + return local_data_source diff --git a/src/bk-user/tests/biz/test_data_source_organization.py b/src/bk-user/tests/biz/test_data_source_organization.py new file mode 100644 index 000000000..9d2ebeef5 --- /dev/null +++ b/src/bk-user/tests/biz/test_data_source_organization.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +""" +TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-用户管理(Bk-User) available. +Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. +Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. +You may obtain a copy of the License at http://opensource.org/licenses/MIT +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +""" +import random +from typing import List + +import pytest +from bkuser.apps.data_source.models import ( + DataSourceDepartment, + DataSourceDepartmentUserRelation, + DataSourceUser, + DataSourceUserLeaderRelation, +) +from bkuser.biz.data_source_organization import ( + DataSourceOrganizationHandler, + DataSourceUserBaseInfo, + DataSourceUserDepartmentInfo, + DataSourceUserEditableBaseInfo, + DataSourceUserLeaderInfo, + DataSourceUserRelationInfo, +) +from django.db import transaction + +from tests.test_utils.data_source_organization import ( + create_data_source_departments, + create_data_source_users, + generate_data_source_username, +) +from tests.test_utils.helpers import generate_random_string + +pytestmark = pytest.mark.django_db + + +@pytest.fixture() +def base_user_info() -> DataSourceUserBaseInfo: + return DataSourceUserBaseInfo( + username=generate_data_source_username(), + full_name=generate_random_string(), + email="test@example.com", + phone="13000000000", + phone_country_code="86", + ) + + +@pytest.fixture() +def editable_base_user_info() -> DataSourceUserEditableBaseInfo: + return DataSourceUserEditableBaseInfo( + full_name=generate_random_string(), + email="test@example.com", + phone="13000000000", + phone_country_code="86", + logo="", + ) + + +@pytest.fixture() +def relation_info() -> DataSourceUserRelationInfo: + return DataSourceUserRelationInfo(department_ids=[11, 22], leader_ids=[33, 44]) + + +@pytest.fixture() +def data_source_departments(local_data_source) -> List[DataSourceDepartment]: + return create_data_source_departments(data_source=local_data_source) + + +@pytest.fixture() +def data_source_users(local_data_source, data_source_departments) -> List[DataSourceUser]: + return create_data_source_users(data_source=local_data_source, departments=data_source_departments) + + +class TestDataSourceOrganizationHandler: + def test_create_user(self, local_data_source, base_user_info, relation_info): + # 创建用户 + with transaction.atomic(): + user_id = DataSourceOrganizationHandler.create_user(local_data_source, base_user_info, relation_info) + + # 验证用户是否创建成功 + user = DataSourceUser.objects.get(id=user_id) + assert user.data_source == local_data_source + assert user.username == base_user_info.username + + # 验证用户-部门关系是否创建成功 + department_ids = relation_info.department_ids + relations = DataSourceDepartmentUserRelation.objects.filter(user=user) + assert set(relations.values_list("department_id", flat=True)) == set(department_ids) + + # 验证用户-上级关系是否创建成功 + leader_ids = relation_info.leader_ids + relations = DataSourceUserLeaderRelation.objects.filter(user=user) + assert set(relations.values_list("leader_id", flat=True)) == set(leader_ids) + + def test_update_user_department_relations(self, data_source_users): + # 创建用户 + user = random.choice(data_source_users) + + # 创建用户-部门关系 + department_ids = [1, 2, 3] + relations = [ + DataSourceDepartmentUserRelation(department_id=department_id, user=user) + for department_id in department_ids + ] + DataSourceDepartmentUserRelation.objects.bulk_create(relations) + + # 更新用户-部门关系 + new_department_ids = [2, 3, 4] + DataSourceOrganizationHandler.update_user_department_relations(user, new_department_ids) + + # 验证用户-部门关系是否更新成功 + relations = DataSourceDepartmentUserRelation.objects.filter(user=user) + assert set(relations.values_list("department_id", flat=True)) == set(new_department_ids) + + def test_update_user_leader_relations(self, data_source_users): + user = random.choice(data_source_users) + # 创建用户-上级关系 + leader_ids = [1, 2, 3] + relations = [DataSourceUserLeaderRelation(leader_id=leader_id, user=user) for leader_id in leader_ids] + DataSourceUserLeaderRelation.objects.bulk_create(relations) + + # 更新用户-上级关系 + new_leader_ids = [2, 3, 4] + DataSourceOrganizationHandler.update_user_leader_relations(user, new_leader_ids) + + # 验证用户-上级关系是否更新成功 + relations = DataSourceUserLeaderRelation.objects.filter(user=user) + assert set(relations.values_list("leader_id", flat=True)) == set(new_leader_ids) + + def test_update_user(self, data_source_users, editable_base_user_info, relation_info): + user = random.choice(data_source_users) + + # 更新用户 + with transaction.atomic(): + DataSourceOrganizationHandler.update_user(user, editable_base_user_info, relation_info) + + # 验证用户基础信息是否更新成功 + assert user.full_name == editable_base_user_info.full_name + assert user.email == editable_base_user_info.email + assert user.phone == editable_base_user_info.phone + assert user.phone_country_code == editable_base_user_info.phone_country_code + assert user.logo == editable_base_user_info.logo + + # 验证用户-部门关系是否更新成功 + department_ids = relation_info.department_ids + relations = DataSourceDepartmentUserRelation.objects.filter(user=user) + assert set(relations.values_list("department_id", flat=True)) == set(department_ids) + + # 验证用户-上级关系是否更新成功 + leader_ids = relation_info.leader_ids + relations = DataSourceUserLeaderRelation.objects.filter(user=user) + assert set(relations.values_list("leader_id", flat=True)) == set(leader_ids) + + def test_list_department_info_by_id(self, data_source_departments): + # 获取部门信息 + department_ids = [dept.id for dept in data_source_departments] + department_infos = DataSourceOrganizationHandler.list_department_info_by_id(department_ids) + + # 验证部门信息是否正确 + assert len(department_infos) == len(department_ids) + for department_info in department_infos: + assert department_info.id in department_ids + assert department_info.name == DataSourceDepartment.objects.get(id=department_info.id).name + + def test_get_user_department_ids_map(self, data_source_users): + # 获取用户-部门id映射 + user_ids = [user.id for user in data_source_users] + user_department_ids_map = DataSourceOrganizationHandler.get_user_department_ids_map(user_ids) + + # 验证用户-部门id映射是否正确 + assert set(user_department_ids_map.keys()) == set(user_ids) + for user_id, department_ids in user_department_ids_map.items(): + assert set(department_ids) == set( + DataSourceDepartmentUserRelation.objects.filter(user_id=user_id).values_list( + "department_id", flat=True + ) + ) + + def test_get_user_departments_map_by_user_id(self, data_source_users): + # 获取用户-所有归属部门信息 + user_ids = [user.id for user in data_source_users] + user_departments_map = DataSourceOrganizationHandler.get_user_departments_map_by_user_id(user_ids) + + # 验证用户-所有归属部门信息是否正确 + assert set(user_departments_map.keys()) == set(user_ids) + for user_id, department_infos in user_departments_map.items(): + assert list(department_infos) == [ + DataSourceUserDepartmentInfo(id=dept["department_id"], name=dept["department__name"]) + for dept in DataSourceDepartmentUserRelation.objects.filter(user_id=user_id).values( + "department_id", "department__name" + ) + ] + + def test_get_user_leader_ids_map(self, data_source_users): + # 获取用户-所有上级id关系映射 + # 由于首个用户为其余用户的上级,因此需要将该用户剔除 + user_ids = [user.id for user in data_source_users][1:] + user_leader_ids_map = DataSourceOrganizationHandler.get_user_leader_ids_map(user_ids) + + # 验证用户-所有上级id关系映射是否正确 + assert set(user_leader_ids_map.keys()) == set(user_ids) + for user_id, leader_ids in user_leader_ids_map.items(): + assert set(leader_ids) == set( + DataSourceUserLeaderRelation.objects.filter(user_id=user_id).values_list("leader_id", flat=True) + ) + + def test_get_user_leaders_map_by_user_id(self, data_source_users): + # 获取用户-所有上级信息数据 + users = data_source_users[1:] + user_ids = [user.id for user in users] + user_leaders_map = DataSourceOrganizationHandler.get_user_leaders_map_by_user_id(user_ids) + # 验证用户-所有上级信息数据是否正确 + assert set(user_leaders_map.keys()) == set(user_ids) + + for user_id, leader_infos in user_leaders_map.items(): + assert list(leader_infos) == [ + DataSourceUserLeaderInfo(id=leader["leader_id"], username=leader["leader__username"]) + for leader in DataSourceUserLeaderRelation.objects.filter(user_id=user_id).values( + "leader_id", "leader__username" + ) + ] diff --git a/src/bk-user/tests/test_utils/data_source_organization.py b/src/bk-user/tests/test_utils/data_source_organization.py index 53cb11c4e..9fdc5e51d 100644 --- a/src/bk-user/tests/test_utils/data_source_organization.py +++ b/src/bk-user/tests/test_utils/data_source_organization.py @@ -11,8 +11,14 @@ import random import re import string +from typing import List -from bkuser.apps.data_source.models import DataSourceDepartment, DataSourceUser, DataSourceUserLeaderRelation +from bkuser.apps.data_source.models import ( + DataSourceDepartment, + DataSourceDepartmentUserRelation, + DataSourceUser, + DataSourceUserLeaderRelation, +) from tests.test_utils.helpers import generate_random_string @@ -26,24 +32,44 @@ def generate_data_source_username(): return username -def create_data_source_user(data_source_id) -> DataSourceUser: - return DataSourceUser.objects.create( - full_name=generate_random_string(), - username=generate_random_string(), - phone="13000000000", - data_source_id=data_source_id, - ) +def create_data_source_departments(data_source) -> List[DataSourceDepartment]: + """ + 创建数据源部门 + """ + departments = [DataSourceDepartment(data_source=data_source, name=generate_random_string()) for _ in range(10)] + DataSourceDepartment.objects.bulk_create(departments) + return list(DataSourceDepartment.objects.filter(data_source=data_source)) -def create_data_source_department(data_source_id) -> DataSourceDepartment: - return DataSourceDepartment.objects.create( - name=generate_random_string(), data_source_id=data_source_id - ) +def create_data_source_users(data_source, departments) -> List[DataSourceUser]: + """ + 创建数据源用户,关联首个用户为上级,随机关联部门 + """ + users = [ + DataSourceUser( + full_name=generate_random_string(), + username=generate_random_string(), + email=f"{generate_random_string()}@qq.com", + phone="13123456789", + data_source=data_source, + ) + for _ in range(10) + ] + DataSourceUser.objects.bulk_create(users) + data_source_users = list(DataSourceUser.objects.filter(data_source=data_source)) + # 添加用户-上级关系 + user_leader_relations = [ + DataSourceUserLeaderRelation(user=data_source_user, leader=data_source_users[0]) + for data_source_user in data_source_users[1:] + ] + DataSourceUserLeaderRelation.objects.bulk_create(user_leader_relations) + # 添加部门-用户关系 + user_department_relations = [ + DataSourceDepartmentUserRelation(user=data_source_user, department=random.choice(departments)) + for data_source_user in data_source_users + ] + DataSourceDepartmentUserRelation.objects.bulk_create(user_department_relations) - -def create_data_source_user_leader(user) -> DataSourceUser: - leader = create_data_source_user(data_source_id=user.date_source_id) - DataSourceUserLeaderRelation.objects.create(user=user, leader=leader) - return leader + return data_source_users