From 1f0cd1bf5b35dca096bd4b08e6969cf6ba252c28 Mon Sep 17 00:00:00 2001
From: Sergei Kliuikov <onegreyonewhite@mail.ru>
Date: Tue, 19 Nov 2024 17:00:32 +1000
Subject: [PATCH] Release 5.11.12

---
 doc/locale/ru/LC_MESSAGES/backend.po | 46 ++++++++++++++++--------
 pyproject.toml                       |  5 ++-
 setup.py                             |  1 +
 test_src/test_proj/tests.py          |  8 ++---
 vstutils/__init__.py                 |  2 +-
 vstutils/api/auth.py                 | 11 +++---
 vstutils/api/endpoint.py             | 16 ++++-----
 vstutils/api/permissions.py          |  2 +-
 vstutils/oauth2/authentication.py    |  8 ++++-
 vstutils/settings.py                 |  1 +
 vstutils/tests.py                    | 52 +++++++++++++++++++++++++++-
 11 files changed, 116 insertions(+), 36 deletions(-)

diff --git a/doc/locale/ru/LC_MESSAGES/backend.po b/doc/locale/ru/LC_MESSAGES/backend.po
index a2161ffa..2795f975 100644
--- a/doc/locale/ru/LC_MESSAGES/backend.po
+++ b/doc/locale/ru/LC_MESSAGES/backend.po
@@ -7,7 +7,7 @@ msgid ""
 msgstr ""
 "Project-Id-Version: VST Utils 5.0.4\n"
 "Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2024-11-05 03:35+0000\n"
+"POT-Creation-Date: 2024-11-19 06:27+0000\n"
 "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
 "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
 "Language-Team: LANGUAGE <LL@li.org>\n"
@@ -2589,32 +2589,32 @@ msgid ""
 " the link points to an existing resource in the interface to avoid 404 "
 "errors."
 msgstr ""
-"**link** *(необязательно)*: URL для другой страницы. Если "
-"указан, будет отображаться текст как ссылка. Если не указан, будет "
-"отображаться просто текст. Значение должно быть совместимым с "
-"`параметром метода push Vue Router <https://router.vuejs.org/api/interfaces/Router.html#push>`_. "
+"**link** *(необязательно)*: URL для другой страницы. Если указан, будет "
+"отображаться текст как ссылка. Если не указан, будет отображаться просто "
+"текст. Значение должно быть совместимым с `параметром метода push Vue "
+"Router <https://router.vuejs.org/api/interfaces/Router.html#push>`_. "
 "Убедитесь, что ссылка указывает на существующий ресурс в интерфейсе для "
 "избежания ошибок 404."
 
-#: of vstutils.api.fields.RouterLinkField:11
+#: of vstutils.api.fields.RouterLinkField:14
 msgid ""
 "**label**: The text to display. This is required whether or not a link is"
 " provided."
 msgstr "**label**: Текст для отображения. Обязательное поле."
 
-#: of vstutils.api.fields.RouterLinkField:14
+#: of vstutils.api.fields.RouterLinkField:17
 msgid "For simpler use cases, you might consider using :class:`.FkField`."
 msgstr "Для простых случаев использования см. :class:`.FkField`."
 
-#: of vstutils.api.fields.RouterLinkField:16
+#: of vstutils.api.fields.RouterLinkField:19
 msgid "**Examples:**"
 msgstr "**Примеры:**"
 
-#: of vstutils.api.fields.RouterLinkField:18
+#: of vstutils.api.fields.RouterLinkField:21
 msgid "*Using a model class method:*"
 msgstr "*Использование метода класса модели:*"
 
-#: of vstutils.api.fields.RouterLinkField:42
+#: of vstutils.api.fields.RouterLinkField:45
 msgid ""
 "In this example, the ``get_link`` method in the ``Author`` model returns "
 "a dictionary containing the ``link`` and ``label``. The "
@@ -2626,11 +2626,11 @@ msgstr ""
 "метод для отображения имени автора как ссылку на страницу с "
 "подробностями."
 
-#: of vstutils.api.fields.RouterLinkField:45
+#: of vstutils.api.fields.RouterLinkField:50
 msgid "*Using a custom field class:*"
 msgstr "*Использование пользовательского класса поля:*"
 
-#: of vstutils.api.fields.RouterLinkField:70
+#: of vstutils.api.fields.RouterLinkField:75
 msgid ""
 "In this example, we create a custom field ``AuthorLinkField`` by "
 "subclassing ``RouterLinkField``. We override the ``to_representation`` "
@@ -2645,7 +2645,7 @@ msgstr ""
 "используется в вьюсете для отображения имени автора как кликабельной "
 "ссылки."
 
-#: of vstutils.api.fields.RouterLinkField:75
+#: of vstutils.api.fields.RouterLinkField:81
 msgid ""
 "The field is read-only and is intended to display dynamic links based on "
 "the instance data."
@@ -2653,7 +2653,7 @@ msgstr ""
 "Поле является только для чтения и предназначено для отображения "
 "динамических ссылок на основе данных экземпляра."
 
-#: of vstutils.api.fields.RouterLinkField:76
+#: of vstutils.api.fields.RouterLinkField:82
 msgid ""
 "If the ``link`` key is omitted or ``None``, the field will display the "
 "``label`` as plain text instead of a link."
@@ -2661,7 +2661,7 @@ msgstr ""
 "Если ключ ``link`` отсутствует или имеет значение ``None``, поле "
 "отображает текст как обычный текст вместо ссылки."
 
-#: of vstutils.api.fields.RouterLinkField:79
+#: of vstutils.api.fields.RouterLinkField:86
 msgid ""
 "Always ensure that the ``link`` provided points to a valid route within "
 "your application to prevent users from encountering 404 errors."
@@ -5252,6 +5252,18 @@ msgstr ""
 "Делает транзакционный bulk-запрос и проверяет код состояния (200 по "
 "умолчанию)"
 
+#: ../../docstring of vstutils.tests.BaseTestCase.client_token_app_id:1
+msgid "oAuth2 client id"
+msgstr "ID клиента для тестов в oAuth2"
+
+#: ../../docstring of vstutils.tests.BaseTestCase.client_token_grant_type:1
+msgid "oAuth2 grant type"
+msgstr "Тип авторизации в oAuth2"
+
+#: ../../docstring of vstutils.tests.BaseTestCase.client_token_scopes:1
+msgid "oAuth2 scopes"
+msgstr ""
+
 #: of vstutils.tests.BaseTestCase.details_test:1
 msgid ""
 "Test for get details of model. If you setup additional named arguments, "
@@ -5493,6 +5505,10 @@ msgstr ""
 msgid "Simple function which returns uuid1 string."
 msgstr "Простая функция, возвращающая строку uuid1."
 
+#: of vstutils.oauth2.authorization_server.AuthorizationServer:1
+msgid "oAuth2 server class"
+msgstr "Класс авторизации для oAuth2"
+
 #: ../../docstring of vstutils.tests.BaseTestCase.std_codes:1
 msgid ""
 "Default http status codes for different http methods. Uses in "
diff --git a/pyproject.toml b/pyproject.toml
index ebb4f3f9..e854036f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -121,6 +121,9 @@ exclude_lines = [
 ]
 
 [tool.bandit]
+exclude_dirs = [
+    'vstutils/tests.py',
+]
 skips = [
     "B403",
     "B404",
@@ -132,7 +135,7 @@ skips = [
 ]
 
 [tool.mypy]
-python_version = 3.8
+python_version = "3.10"
 #strict = true
 allow_redefinition = true
 check_untyped_defs = true
diff --git a/setup.py b/setup.py
index eb397bbc..225c286f 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@
         'vstutils.api.endpoint',
         'vstutils.api.validators',
         'vstutils.api.actions',
+        'vstutils.oauth2.authentication',
         'vstutils.models.base',
         'vstutils.models.queryset',
         'vstutils.models.cent_notify',
diff --git a/test_src/test_proj/tests.py b/test_src/test_proj/tests.py
index 808ffc03..b75923af 100644
--- a/test_src/test_proj/tests.py
+++ b/test_src/test_proj/tests.py
@@ -936,7 +936,7 @@ def test_users_api(self):
             password2='12345'
         )
         user_get_request = {"method": "get", "path": ['user', 'profile']}
-        self.client.force_login(self.user)
+        self._login()
         results = self.bulk([
             {"method": "post", "path": ['user', 'profile', 'change_password'], "data": i}
             for i in (invalid_old_password, not_identical_passwords, update_password)
@@ -2525,7 +2525,6 @@ def test_simple_queries(self):
         self.assertEqual('get', response[0]['method'])
         self.assertEqual('/api/v1/user/1/', response[0]['path'])
         self.assertEqual(200, response[0]['status'])
-        # self.assertEqual('v1', response[0]['version'])
 
         expected_user = user_from_db_to_user_from_api_detail(user1)
         actual_user = response[0]['data']
@@ -3861,12 +3860,13 @@ class FieldChoices(BaseEnum):
 
     @override_settings(SESSION_ENGINE='django.contrib.sessions.backends.db')
     def test_hierarchy(self):
+        # self.client_oauth_session = False
         Host.objects.all().delete()
         HostGroup.objects.all().delete()
         bulk_data = list(self.objects_bulk_data)
         results = self.bulk(bulk_data)
-        for result in results:
-            self.assertEqual(result['status'], 201, result)
+        for num, result in enumerate(results):
+            self.assertEqual(result['status'], 201, f'Attempt: {num}, {result}')
             del result
         self._check_subhost(results[0]['data']['id'], name='a')
         self._check_subhost(
diff --git a/vstutils/__init__.py b/vstutils/__init__.py
index a35b8372..4d1e00eb 100644
--- a/vstutils/__init__.py
+++ b/vstutils/__init__.py
@@ -1,2 +1,2 @@
 # pylint: disable=django-not-available
-__version__: str = '5.11.11'
+__version__: str = '5.11.12'
diff --git a/vstutils/api/auth.py b/vstutils/api/auth.py
index 8869cced..d8391bd7 100644
--- a/vstutils/api/auth.py
+++ b/vstutils/api/auth.py
@@ -2,7 +2,7 @@
 from copy import deepcopy
 
 import pyotp
-from django.contrib.auth import get_user_model, update_session_auth_hash
+from django.contrib.auth import get_user_model, HASH_SESSION_KEY
 from django.contrib.auth.hashers import make_password
 from django.contrib.auth.password_validation import validate_password
 from django.contrib.auth.models import AbstractUser
@@ -131,6 +131,7 @@ class ChangePasswordSerializer(BaseSerializer):
     password = fields.PasswordField(required=True, label='New password')
     password2 = fields.PasswordField(required=True, label='Confirm new password')
 
+    @transaction.atomic
     def update(self, instance, validated_data):
         if not instance.check_password(validated_data['old_password']):
             raise exceptions.AuthenticationFailed()
@@ -138,9 +139,9 @@ def update(self, instance, validated_data):
             raise exceptions.ValidationError(
                 translate("New passwords values are not equal.")
             )
-        validate_password(validated_data['password'])
+        validate_password(validated_data['password'], user=instance)
         instance.set_password(validated_data['password'])
-        instance.save()
+        instance.save(update_fields=['password'])
         return instance
 
     def to_representation(self, instance):
@@ -307,7 +308,9 @@ def change_password(self, request: drf_request.Request, *args, **kwargs):
         serializer = self.get_serializer(user, data=request.data)
         serializer.is_valid(raise_exception=True)
         serializer.save()
-        update_session_auth_hash(request, user)
+        if hasattr(user, "get_session_auth_hash") and request.user == user:
+            request.session[HASH_SESSION_KEY] = user.get_session_auth_hash()
+            request.session.save()
         return responses.HTTP_201_CREATED(serializer.data)
 
     @deco.action(['get', 'put'], detail=True, permission_classes=(ChangePasswordPermission,))
diff --git a/vstutils/api/endpoint.py b/vstutils/api/endpoint.py
index e7dd4a30..2bc605a9 100644
--- a/vstutils/api/endpoint.py
+++ b/vstutils/api/endpoint.py
@@ -374,16 +374,16 @@ def original_environ_data(self, request: BulkRequestType, *args) -> _t.Dict:
             value = get_environ(env_var, None)
             if value:
                 kwargs[env_var] = str(value)
+
         if request.user.is_authenticated:
-            if isinstance(request.successful_authenticator, SessionAuthentication):
-                kwargs['HTTP_COOKIE'] = str(request.META.get('HTTP_COOKIE'))
-            elif isinstance(request.successful_authenticator, (
-                BasicAuthentication,
-                TokenAuthentication,
-                JWTBearerTokenAuthentication,
-            )):
-                kwargs['HTTP_AUTHORIZATION'] = str(request.META.get('HTTP_AUTHORIZATION'))
             kwargs['user'] = request.user
+
+        if cookies := get_environ('HTTP_COOKIE'):
+            kwargs['HTTP_COOKIE'] = str(cookies)
+
+        if auth_header := get_environ('HTTP_AUTHORIZATION'):
+            kwargs['HTTP_AUTHORIZATION'] = str(auth_header)
+
         kwargs['language'] = getattr(request, 'language', None)
         kwargs['session'] = getattr(request, 'session', None)
         kwargs['notificator'] = getattr(request, 'notificator', None)
diff --git a/vstutils/api/permissions.py b/vstutils/api/permissions.py
index 33ae15a1..1062abe7 100644
--- a/vstutils/api/permissions.py
+++ b/vstutils/api/permissions.py
@@ -26,7 +26,7 @@ def has_permission(self, request, view):
 class SuperUserPermission(IsAuthenticatedOpenApiRequest):
 
     def has_permission(self, request, view):
-        if request.user.is_staff or request.method in permissions.SAFE_METHODS:
+        if request.user.is_staff or request.user.is_superuser or request.method in permissions.SAFE_METHODS:
             # pylint: disable=bad-super-call
             return super(IsAuthenticatedOpenApiRequest, self).has_permission(request, view)
         with raise_context():
diff --git a/vstutils/oauth2/authentication.py b/vstutils/oauth2/authentication.py
index 43ac5b2b..c1469585 100644
--- a/vstutils/oauth2/authentication.py
+++ b/vstutils/oauth2/authentication.py
@@ -45,8 +45,14 @@ def _get_request_token(request: "Request"):
         raise AuthenticationFailed() from exc
 
 
+def _get_session_store():  # nocv
+    # We have to mock this method in tests
+    # because performance preferred
+    return SESSION_STORE
+
+
 def get_session(session_key):
-    session = SESSION_STORE(session_key)
+    session = _get_session_store()(session_key)
     session._from_jwt = True  # pylint: disable=protected-access
     return session
 
diff --git a/vstutils/settings.py b/vstutils/settings.py
index 96696a8e..72df570f 100644
--- a/vstutils/settings.py
+++ b/vstutils/settings.py
@@ -1656,6 +1656,7 @@ class OauthServerClientConfig(_t.TypedDict):
     for storage_name in filter('staticfiles'.__ne__, STORAGES):
         STORAGES[storage_name] = {"BACKEND": 'django.core.files.storage.InMemoryStorage'}
     CENTRIFUGO_CLIENT_KWARGS = {}
+    OAUTH_SERVER_TOKEN_EXPIRES_IN = 60 * 60
     try:
         __import__('pysqlite3')
         sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')  # nocv
diff --git a/vstutils/tests.py b/vstutils/tests.py
index ad102640..41fef8d3 100644
--- a/vstutils/tests.py
+++ b/vstutils/tests.py
@@ -5,14 +5,18 @@
 import os  # noqa: F401
 import uuid
 import warnings
+from time import time
+from importlib import import_module
 from unittest.mock import patch, Mock
 import json  # noqa: F401
 
 import ormsgpack
+from authlib.jose import jwt
 from django.apps import apps
 from django.http import StreamingHttpResponse
 from django.db import transaction, models as django_models
 from django.core.exceptions import BadRequest
+from django.contrib.sessions.backends.base import SessionBase
 from django.conf import settings
 from django.test import TestCase, override_settings  # noqa: F401
 from django.contrib.auth import get_user_model
@@ -21,12 +25,16 @@
 
 from .utils import raise_context_decorator_with_default
 from .api.renderers import ORJSONRenderer
+from .oauth2.jwk import jwk_set
 
 User = get_user_model()
 
 BulkDataType = _t.Union[_t.List[_t.Dict[_t.Text, _t.Any]], str, bytes, bytearray]
 ApiResultType = _t.Union[BulkDataType, _t.Dict, _t.Sequence[BulkDataType]]
 
+patched_get_session = patch("vstutils.oauth2.authentication._get_session_store").start()
+patched_get_session.side_effect = lambda: import_module(settings.SESSION_ENGINE).SessionStore
+
 
 class BaseTestCase(TestCase):
     """
@@ -34,6 +42,20 @@ class BaseTestCase(TestCase):
     """
     server_name = 'vstutilstestserver'
 
+    #: oAuth2 server class
+    server_class = import_string(settings.OAUTH_SERVER_CLASS)
+
+    #: oAuth2 client id
+    client_token_app_id = 'simple-client-id'
+
+    #: oAuth2 grant type
+    client_token_grant_type = 'password'
+
+    #: oAuth2 scopes
+    client_token_scopes = 'openid read write'
+
+    client_oauth_session = True
+
     #: Attribute with default project models module.
     models = None
 
@@ -93,15 +115,43 @@ def _create_user(self, is_super_user=True, **kwargs):
         user.data = {'username': username, 'password': password}
         return user
 
+    def get_oauth2_server(self):
+        return self.server_class()
+
+    def generate_token_for_session(self, session: SessionBase):
+        oauth_server = self.get_oauth2_server()
+        client = oauth_server.query_client(self.client_token_app_id)
+        payload = {
+            'iss': settings.OAUTH_SERVER_ISSUER,
+            'aud': client.get_client_id(),
+            'client_id': client.get_client_id(),
+            'jti': str(session.session_key),
+            'sub': str(session.get('user_id', None)),
+            'scope': self.client_token_scopes,
+            'exp': int(time()) + 3600,
+            'iat': int(time()),
+        }
+        header = {
+            'alg': settings.OAUTH_SERVER_JWT_ALG,
+            'typ': 'at+jwt'
+        }
+        return jwt.encode(header, payload, key=jwk_set).decode('utf-8')
+
     def _login(self):
         client = self.client
         client.force_login(self.user)
-        # TODO: Make OAuth2 auth
+        # Make OAuth2 auth over session auth
+        if self.client_oauth_session:
+            client.session.save()
+            client.defaults.setdefault('Sec-Fetch-Site', 'same-origin')
+            client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {self.generate_token_for_session(client.session)}'
         return client
 
     def _logout(self, client):
         saved_cookies = client.cookies
         client.logout()
+        client.defaults.pop('Sec-Fetch-Site', None)
+        client.defaults.pop('HTTP_AUTHORIZATION', None)
         client.cookies = saved_cookies
 
     def _check_update(self, url, data, **fields):