From aecba04af66dbb2ff16c2a4f5c088cf6096872c5 Mon Sep 17 00:00:00 2001 From: Benjamin Cutler Date: Tue, 31 Dec 2024 16:07:08 -0700 Subject: [PATCH] add Prize to V2 api (#747) [#184870668] --- tests/apiv2/test_donations.py | 2 +- tests/apiv2/test_milestones.py | 2 +- tests/apiv2/test_prizes.py | 264 +++++++++++++++++++++++++++++++++ tests/randgen.py | 14 +- tests/test_event.py | 2 - tests/test_prize.py | 2 - tests/util.py | 74 +++++---- tracker/api/filters.py | 228 +++++++++++++++++++++------- tracker/api/messages.py | 8 +- tracker/api/permissions.py | 32 ++++ tracker/api/serializers.py | 72 ++++++++- tracker/api/urls.py | 9 +- tracker/api/util.py | 19 +++ tracker/api/views/__init__.py | 18 ++- tracker/api/views/prize.py | 25 ++++ tracker/models/prize.py | 71 ++++++++- 16 files changed, 727 insertions(+), 115 deletions(-) create mode 100644 tests/apiv2/test_prizes.py create mode 100644 tracker/api/util.py create mode 100644 tracker/api/views/prize.py diff --git a/tests/apiv2/test_donations.py b/tests/apiv2/test_donations.py index c60bb3677..dd645ad88 100644 --- a/tests/apiv2/test_donations.py +++ b/tests/apiv2/test_donations.py @@ -117,7 +117,7 @@ def test_unprocessed_returns_only_after_timestamp(self): self.event, count=2, state='pending', - time=date.replace(year=9999), + time=date.replace(year=9998), ) response = self.client.get( diff --git a/tests/apiv2/test_milestones.py b/tests/apiv2/test_milestones.py index 3d35cc269..db91c5ca9 100644 --- a/tests/apiv2/test_milestones.py +++ b/tests/apiv2/test_milestones.py @@ -173,7 +173,7 @@ def test_patch(self): with self.subTest('error cases'): self.patch_detail( self.public_milestone, - data={'event': self.event.id}, + data={'event': self.blank_event.id}, status_code=400, expected_error_codes=messages.EVENT_READ_ONLY_CODE, ) diff --git a/tests/apiv2/test_prizes.py b/tests/apiv2/test_prizes.py new file mode 100644 index 000000000..d69ad19e8 --- /dev/null +++ b/tests/apiv2/test_prizes.py @@ -0,0 +1,264 @@ +from datetime import timedelta + +from tests import randgen +from tests.util import APITestCase, today_noon +from tracker import models +from tracker.api import messages +from tracker.api.serializers import PrizeSerializer + + +class TestPrizes(APITestCase): + model_name = 'prize' + serializer_class = PrizeSerializer + + def setUp(self): + super().setUp() + self.runs = randgen.generate_runs( + self.rand, num_runs=3, event=self.event, ordered=True + ) + self.accepted_prize = randgen.generate_prize( + self.rand, start_run=self.runs[0], end_run=self.runs[1] + ) + self.accepted_prize.description = 'test long description' + self.accepted_prize.shortdescription = 'test short description' + self.accepted_prize.save() + self.pending_prize = randgen.generate_prize( + self.rand, event=self.event, state='PENDING' + ) + self.pending_prize.save() + self.denied_prize = randgen.generate_prize( + self.rand, event=self.event, state='DENIED' + ) + self.denied_prize.save() + self.flagged_prize = randgen.generate_prize( + self.rand, event=self.event, state='FLAGGED' + ) + self.flagged_prize.save() + self.locked_prize = randgen.generate_prize( + self.rand, event=self.locked_event, state='PENDING' + ) + self.locked_prize.save() + # TODO + # self.view_winner_user = User.objects.create(username='view_winner_user') + # self.view_winner_user.user_permissions.add(Permission.objects.get(codename='view_prizewinner')) + + def test_fetch(self): + with self.saveSnapshot(): + with self.subTest('public'): + data = self.get_list() + self.assertExactV2Models([self.accepted_prize], data) + + # TODO: more exhaustive? + data = self.get_list( + kwargs={'feed': 'current'}, data={'time': self.runs[0].starttime} + ) + self.assertExactV2Models([self.accepted_prize], data) + + data = self.get_list( + kwargs={'feed': 'current'}, data={'run': self.runs[0].pk} + ) + self.assertExactV2Models([self.accepted_prize], data) + + with self.subTest('searches'): + data = self.get_list(data={'q': self.accepted_prize.name}) + self.assertExactV2Models([self.accepted_prize], data) + + data = self.get_list(data={'q': self.accepted_prize.description}) + self.assertExactV2Models([self.accepted_prize], data) + + data = self.get_list( + data={'q': self.accepted_prize.shortdescription} + ) + self.assertExactV2Models([self.accepted_prize], data) + + data = self.get_list(data={'name': self.accepted_prize.name}) + self.assertExactV2Models([self.accepted_prize], data) + + data = self.get_list(data={'state': 'ACCEPTED'}) + self.assertExactV2Models([self.accepted_prize], data) + + data = self.get_detail(self.accepted_prize) + self.assertV2ModelPresent(self.accepted_prize, data) + + data = self.get_detail( + self.accepted_prize, kwargs={'event_pk': self.event.pk} + ) + self.assertV2ModelPresent(self.accepted_prize, data) + + with self.subTest('private'): + data = self.get_list(user=self.view_user, kwargs={'feed': 'all'}) + self.assertExactV2Models( + [ + self.accepted_prize, + self.flagged_prize, + self.denied_prize, + self.pending_prize, + self.locked_prize, + ], + data, + ) + + data = self.get_list( + user=self.view_user, + data={'state': ['PENDING', 'DENIED', 'FLAGGED']}, + kwargs={'event_pk': self.event.pk}, + ) + self.assertExactV2Models( + [self.flagged_prize, self.denied_prize, self.pending_prize], data + ) + + # TODO + # data = self.get_list(user=self.view_winner_user, data={'include_winners': ''}) + # self.assertExactV2Models([self.accepted_prize], data, serializer_kwargs={'include_winners': True}) + + with self.subTest('error cases'): + with self.subTest('private feeds'): + for feed in models.Prize.HIDDEN_FEEDS: + with self.subTest(feed): + self.get_list(user=None, kwargs={'feed': feed}, status_code=403) + + with self.subTest('wrong event detail'): + self.get_detail( + self.accepted_prize, + kwargs={'event_pk': self.blank_event.pk}, + status_code=404, + ) + + with self.subTest('private detail'): + for prize in [ + self.pending_prize, + self.denied_prize, + self.flagged_prize, + ]: + with self.subTest(prize.state): + self.get_detail(prize, user=None, status_code=404) + + with self.subTest('private states'): + for state in models.Prize.HIDDEN_STATES: + with self.subTest(state): + self.get_list(user=None, data={'state': state}, status_code=403) + + with self.subTest('combining feed and state'): + self.get_list( + data={'state': 'ACCEPTED'}, + kwargs={'feed': 'public'}, + status_code=400, + ) + + with self.subTest('combining feed and detail'): + self.get_detail( + self.accepted_prize, kwargs={'feed': 'public'}, status_code=404 + ) + + def test_create(self): + with self.saveSnapshot(), self.assertLogsChanges(4): + with self.subTest('minimal'): + data = self.post_new( + user=self.add_user, + data={'event': self.event.pk, 'name': 'Event Wide Prize'}, + ) + prize = models.Prize.objects.get(pk=data['id']) + serialized = PrizeSerializer(prize) + self.assertEqual(data, serialized.data) + self.assertEqual(prize.handler, self.add_user) + + data = self.post_new( + user=self.add_user, + data={ + 'event': self.event.pk, + 'name': 'Timed Prize', + 'starttime': today_noon, + 'endtime': today_noon + timedelta(hours=1), + }, + ) + serialized = PrizeSerializer(models.Prize.objects.get(pk=data['id'])) + self.assertEqual(data, serialized.data) + + data = self.post_new( + user=self.add_user, + data={ + 'event': self.event.pk, + 'name': 'Block Prize', + 'startrun': self.runs[0].pk, + 'endrun': self.runs[2].pk, + }, + ) + serialized = PrizeSerializer(models.Prize.objects.get(pk=data['id'])) + self.assertEqual(data, serialized.data) + + with self.subTest('full blown'): + data = self.post_new( + user=self.add_user, + kwargs={'event_pk': self.event.pk}, + data={ + 'name': 'Earthquake Pills', + 'startrun': self.runs[0].pk, + 'endrun': self.runs[1].pk, + 'description': 'Why wait? Make your own earthquakes - loads of fun!', + 'shortdescription': 'Make your own earthquakes!', + 'image': 'https://www.example.com/image.jpg', + 'altimage': 'https://www.example.com/thumbnail.jpg', + 'estimatedvalue': 10, + 'minimumbid': 25, + 'sumdonations': 1, + 'provider': 'Coyote', + 'creator': 'ACME', + 'creatorwebsite': 'https://www.acme.com/', + }, + ) + serialized = PrizeSerializer( + models.Prize.objects.get(pk=data['id']), event_pk=self.event.pk + ) + self.assertEqual(data, serialized.data) + + with self.subTest('error cases'): + self.post_new( + user=self.add_user, + data={'event': self.locked_event.pk}, + status_code=403, + expected_error_codes=messages.UNAUTHORIZED_LOCKED_EVENT_CODE, + ) + self.post_new( + user=self.view_user, + status_code=403, + expected_error_codes=messages.PERMISSION_DENIED_CODE, + ) + self.post_new( + user=None, + status_code=403, + expected_error_codes=messages.NOT_AUTHENTICATED_CODE, + ) + + def test_patch(self): + with self.saveSnapshot(), self.assertLogsChanges(1): + data = self.patch_detail( + self.pending_prize, user=self.add_user, data={'state': 'ACCEPTED'} + ) + self.assertEqual(self.pending_prize.state, 'ACCEPTED') + self.assertEqual(data, PrizeSerializer(self.pending_prize).data) + + with self.subTest('error cases'): + self.patch_detail( + self.accepted_prize, + data={'event': self.blank_event.pk}, + status_code=400, + expected_error_codes=messages.EVENT_READ_ONLY_CODE, + ) + self.patch_detail( + self.locked_prize, + user=self.add_user, + status_code=403, + expected_error_codes=messages.UNAUTHORIZED_LOCKED_EVENT_CODE, + ) + self.patch_detail( + self.accepted_prize, + user=self.view_user, + status_code=403, + expected_error_codes='permission_denied', + ) + self.patch_detail( + self.accepted_prize, + user=None, + status_code=403, + expected_error_codes=messages.NOT_AUTHENTICATED_CODE, + ) diff --git a/tests/randgen.py b/tests/randgen.py index 00c488878..6d571d0c3 100644 --- a/tests/randgen.py +++ b/tests/randgen.py @@ -211,7 +211,6 @@ def generate_prize( end_time=None, sum_donations=None, min_amount=Decimal('1.00'), - max_amount=Decimal('20.00'), random_draw=True, maxwinners=1, state='ACCEPTED', @@ -232,17 +231,8 @@ def generate_prize( prize.category = category else: prize.category = rand.choice([None] + list(PrizeCategory.objects.all())) - if true_false_or_random(rand, sum_donations): - prize.sumdonations = True - lo = random_amount(rand, min_amount=min_amount, max_amount=max_amount) - hi = random_amount(rand, min_amount=min_amount, max_amount=max_amount) - prize.minimumbid = min(lo, hi) - prize.maximumbid = max(lo, hi) - else: - prize.sumdonations = False - prize.minimumbid = prize.maximumbid = random_amount( - rand, min_amount=min_amount, max_amount=max_amount - ) + prize.sumdonations = true_false_or_random(rand, sum_donations) + prize.minimumbid = min_amount prize.randomdraw = random_draw if start_run: prize.event = start_run.event diff --git a/tests/test_event.py b/tests/test_event.py index dd2687197..c3d12d5a1 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -753,7 +753,6 @@ def test_event_prize_report(self): end_run=runs[0], sum_donations=False, min_amount=5, - max_amount=5, ) prize.save() donors = randgen.generate_donors(self.rand, 3) @@ -781,7 +780,6 @@ def test_event_prize_report(self): event=self.event, sum_donations=True, min_amount=50, - max_amount=50, ) grandPrize.save() # generate 2 for summation diff --git a/tests/test_prize.py b/tests/test_prize.py index d5c1b3da3..29b7294de 100644 --- a/tests/test_prize.py +++ b/tests/test_prize.py @@ -597,7 +597,6 @@ def test_decline_prize_single(self): sum_donations=False, random_draw=False, min_amount=amount, - max_amount=amount, maxwinners=1, ) targetPrize.save() @@ -880,7 +879,6 @@ def test_accept_deadline_offset(self): sum_donations=False, random_draw=False, min_amount=amount, - max_amount=amount, maxwinners=1, ) targetPrize.save() diff --git a/tests/util.py b/tests/util.py index 7c3a193e7..b477ab462 100644 --- a/tests/util.py +++ b/tests/util.py @@ -38,6 +38,8 @@ from tracker.api.pagination import TrackerPagination from tracker.compat import zoneinfo +_empty = object() + class PickledRandom(random.Random): # I live in hell @@ -261,6 +263,8 @@ def _get_viewname(self, model_name, action, **kwargs): viewname = f'tracker:api_v2:event-{model_name}-feed-{action}' else: viewname = f'tracker:api_v2:event-{model_name}-{action}' + elif 'feed' in kwargs: + viewname = f'tracker:api_v2:{model_name}-feed-{action}' else: viewname = f'tracker:api_v2:{model_name}-{action}' return viewname @@ -273,11 +277,11 @@ def get_detail( status_code=200, data=None, kwargs=None, - **other_kwargs, + user=_empty, ): kwargs = kwargs or {} - if 'user' in other_kwargs: - self.client.force_authenticate(user=other_kwargs['user']) + if user is not _empty: + self.client.force_authenticate(user=user) model_name = model_name or self.model_name assert model_name is not None lookup_kwargs = {**kwargs} @@ -309,11 +313,11 @@ def get_list( status_code=200, data=None, kwargs=None, - **other_kwargs, + user=_empty, ): kwargs = kwargs or {} - if 'user' in other_kwargs: - self.client.force_authenticate(user=other_kwargs['user']) + if user is not _empty: + self.client.force_authenticate(user=user) model_name = model_name or self.model_name assert model_name is not None url = reverse( @@ -343,15 +347,15 @@ def get_noun( data=None, kwargs=None, lookup_key=None, - **other_kwargs, + user=_empty, ): kwargs = kwargs or {} if lookup_key is None: lookup_key = self.lookup_key if obj is not None and lookup_key == 'pk': kwargs['pk'] = obj.pk - if 'user' in other_kwargs: - self.client.force_authenticate(user=other_kwargs['user']) + if user is not _empty: + self.client.force_authenticate(user=user) model_name = model_name or self.model_name assert model_name is not None url = reverse(self._get_viewname(model_name, noun, **kwargs), kwargs=kwargs) @@ -442,7 +446,7 @@ def post_new( data=None, kwargs=None, expected_error_codes=None, - **other_kwargs, + user=_empty, ): return self.post_noun( 'list', @@ -451,7 +455,7 @@ def post_new( data=data, kwargs=kwargs, expected_error_codes=expected_error_codes, - **other_kwargs, + user=user, ) def post_noun( @@ -463,12 +467,12 @@ def post_noun( data=None, kwargs=None, expected_error_codes=None, - **other_kwargs, + user=_empty, ): kwargs = kwargs or {} data = data or {} - if 'user' in other_kwargs: - self.client.force_authenticate(user=other_kwargs['user']) + if user is not _empty: + self.client.force_authenticate(user=user) model_name = model_name or self.model_name assert model_name is not None url = reverse(self._get_viewname(model_name, noun, **kwargs), kwargs=kwargs) @@ -493,11 +497,11 @@ def patch_detail( expected_error_codes=None, data=None, kwargs=None, - **other_kwargs, + user=_empty, ): kwargs = kwargs or {} - if 'user' in other_kwargs: - self.client.force_authenticate(user=other_kwargs['user']) + if user is not _empty: + self.client.force_authenticate(user=user) model_name = model_name or self.model_name assert model_name is not None url = reverse( @@ -847,12 +851,27 @@ def assertLogsChanges(self, number, action_flag=None): msg=f'Expected {number} change(s) logged, got {after - before}', ) + @contextlib.contextmanager + def subTest(self, msg=_empty, **params): + if msg is not _empty and msg: + num = self._snapshot_num + self._snapshot_num = 1 + self._messages.append(msg) + try: + with super().subTest(msg, **params): + yield + finally: + self._snapshot_num = num + self._messages.pop() + else: + with super().subTest(**params): + yield + @contextlib.contextmanager def saveSnapshot(self): # TODO: don't save 'empty' results by default? previous = getattr(self, '_save_snapshot', False) self._save_snapshot = True - self._last_subtest = None try: yield finally: @@ -900,15 +919,11 @@ def _snapshot(self, method, url, data): re.sub(r'^Test', '', self.__class__.__name__), re.sub(r'^test_', '', self._testMethodName).lower(), ] - subtest = self - while next_subtest := getattr(subtest, '_subtest', None): - subtest = next_subtest - if subtest._message == 'happy path': - continue - pieces.append(re.sub(r'\W', '_', subtest._message).lower()) - - if self._last_subtest is not subtest: - self._snapshot_num = 1 + pieces.extend( + re.sub(r'\W', '_', m).lower() + for m in self._messages + if m != 'happy path' + ) # obscure ids from url since they can drift depending on test order/results, remove leading tracker since it's redundant, and slugify everything else # FIXME: this doesn't quite work for Country since we don't use PK lookups in the urls @@ -922,7 +937,6 @@ def _snapshot(self, method, url, data): snapshot_name = '_'.join(p.strip('_') for p in pieces) self._snapshot_num += 1 - self._last_subtest = subtest basepath = os.path.join(os.path.dirname(__file__), 'snapshots') os.makedirs(basepath, exist_ok=True) @@ -934,6 +948,8 @@ def _snapshot(self, method, url, data): def setUp(self): super().setUp() self._save_snapshot = False + self._snapshot_num = 1 + self._messages = [] self.rand = random.Random() # depending on the environment this might not be pickleable, which makes random test failures extremely # hard to diagnose @@ -957,7 +973,7 @@ def setUp(self): name='Blank Event', ) self.event = models.Event.objects.create( - datetime=today_noon, targetamount=5, short='event', name='Test Event' + datetime=today_noon, targetamount=5, short='test', name='Test Event' ) self.anonymous_user = AnonymousUser() self.user = User.objects.create(username='test') diff --git a/tracker/api/filters.py b/tracker/api/filters.py index d04aa12b4..83f9d9628 100644 --- a/tracker/api/filters.py +++ b/tracker/api/filters.py @@ -1,5 +1,8 @@ import datetime +import itertools import logging +import operator +from functools import reduce from django.db.models import Q from django.http import Http404 @@ -8,68 +11,128 @@ from rest_framework.exceptions import NotFound, ParseError, PermissionDenied from tracker.api import messages -from tracker.models import Bid +from tracker.api.util import parse_time +from tracker.api.views.run import SpeedRunViewSet +from tracker.models import Bid, Prize logger = logging.getLogger(__name__) +empty = object() + class TrackerFilter(filters.BaseFilterBackend): - filter_params = {} + general_filter = [] + filter_lookup = [] + filter_keys = {} def filter_queryset(self, request, queryset, view): - if not view.detail: - filter_args = [] - filter_kwargs = {} - if 'id' in request.query_params: - filter_kwargs['id__in'] = request.query_params.getlist('id') - for param, filter_param in self.filter_params.items(): - if param in request.query_params: - if isinstance(filter_param, str): - if filter_param.endswith('__in'): - values = request.query_params.getlist(param) - if any( - not self.has_filter_permission(request, param, value) - for value in values - ): - raise PermissionDenied( - detail=messages.UNAUTHORIZED_FILTER_PARAM, - code=messages.UNAUTHORIZED_FILTER_PARAM_CODE, - ) - filter_kwargs[filter_param] = [ - self.normalize_value(param, value) for value in values - ] - else: - value = request.query_params[param] - if not self.has_filter_permission(request, param, value): - raise PermissionDenied( - detail=messages.UNAUTHORIZED_FILTER_PARAM, - code=messages.UNAUTHORIZED_FILTER_PARAM_CODE, - ) - filter_kwargs[filter_param] = self.normalize_value( - param, value - ) - elif isinstance(filter_param, Q): - filter_args.append(filter_param) - elif callable(filter_param): - filter_args.append(filter_param(request.query_params[param])) - try: - queryset = queryset.filter(*filter_args, **filter_kwargs) - except (ValueError, TypeError): + if view.detail or view.action == ['create']: + return queryset + + if 'q' in request.query_params: + if not self.general_filter: raise ParseError( - detail=messages.MALFORMED_SEARCH_PARAMETER, - code=messages.MALFORMED_SEARCH_PARAMETER_CODE, + detail=messages.NO_GENERAL_SEARCH, + code=messages.NO_GENERAL_SEARCH_CODE, ) + # TODO: recurse the keys like in search_filters + queryset = queryset.filter( + reduce( + operator.or_, + ( + Q(**{k + '__icontains': v}) + for (k, v) in itertools.product( + self.general_filter, request.query_params.getlist('q') + ) + ), + ) + ) + + filter_args = [] + filter_kwargs = {} + + if 'id' in request.query_params: + filter_kwargs['id__in'] = request.query_params.getlist('id') + + for param in self.filter_lookup: + if param in request.query_params: + if not self.has_filter_permission(request, param): + raise PermissionDenied( + detail=messages.UNAUTHORIZED_FIELD, + code=messages.UNAUTHORIZED_FIELD_CODE, + ) + value = request.query_params[param] + if not self.has_filter_permission(request, param, value): + raise PermissionDenied( + detail=messages.UNAUTHORIZED_FILTER_PARAM, + code=messages.UNAUTHORIZED_FILTER_PARAM_CODE, + ) + filter_kwargs[param] = self.normalize_value(param, value) + + for param, filter_param in self.filter_keys.items(): + if param in request.query_params: + if not self.has_filter_permission(request, param): + raise PermissionDenied( + detail=messages.UNAUTHORIZED_FIELD, + code=messages.UNAUTHORIZED_FIELD_CODE, + ) + if isinstance(filter_param, str): + if filter_param.endswith('__in'): + values = request.query_params.getlist(param) + if any( + not self.has_filter_permission(request, param, value) + for value in values + ): + raise PermissionDenied( + detail=messages.UNAUTHORIZED_FILTER_PARAM, + code=messages.UNAUTHORIZED_FILTER_PARAM_CODE, + ) + filter_kwargs[filter_param] = [ + self.normalize_value(param, value) for value in values + ] + else: + value = request.query_params[param] + if not self.has_filter_permission(request, param, value): + raise PermissionDenied( + detail=messages.UNAUTHORIZED_FILTER_PARAM, + code=messages.UNAUTHORIZED_FILTER_PARAM_CODE, + ) + filter_kwargs[filter_param] = self.normalize_value(param, value) + elif isinstance(filter_param, Q): + filter_args.append(filter_param) + elif callable(filter_param): + filter_args.append(filter_param(request.query_params[param])) + try: + queryset = queryset.filter(*filter_args, **filter_kwargs) + except (ValueError, TypeError): + raise ParseError( + detail=messages.MALFORMED_SEARCH_PARAMETER, + code=messages.MALFORMED_SEARCH_PARAMETER_CODE, + ) + return queryset def normalize_value(self, field, value): return value - def has_filter_permission(self, request, field, value): + def has_filter_permission(self, request, field, value=empty): return True +def check_feed(feed, view, query_params): + if feed is not None: + # feed makes no sense for detail views or when trying to explicitly filter by state, but for different reasons + if view.detail: + raise Http404 + if 'state' in query_params: + raise ParseError( + detail=_('Cannot search for state while using the feed endpoint.'), + code=messages.INVALID_SEARCH_PARAMETER_CODE, + ) + + class BidFilter(TrackerFilter): - filter_params = { + filter_keys = { 'name': 'name__icontains', 'state': 'state__in', 'run': 'speedrun__in', @@ -86,18 +149,26 @@ def normalize_value(self, field, value): return value.upper() return value - def has_filter_permission(self, request, field, value): + def has_filter_permission(self, request, field, value=empty): return ( field != 'state' + or value is empty or value in Bid.PUBLIC_STATES or request.user.has_perm('tracker.view_hidden_bid') + or request.user.has_perm('tracker.view_bid') ) def filter_queryset(self, request, queryset, view): feed = view.get_feed() query_params = request.query_params + + check_feed(feed, view, query_params) + + if view.detail or view.action == ['create']: + return queryset + if feed is None: - if not view.detail and 'state' not in query_params: + if 'state' not in query_params: queryset = queryset.public() elif feed == 'open': queryset = queryset.open() @@ -134,16 +205,6 @@ def filter_queryset(self, request, queryset, view): detail=messages.INVALID_FEED % feed, code=messages.INVALID_FEED_CODE ) - if feed is not None: - # feed makes no sense for detail views or when trying to explicitly filter by state, but for different reasons - if view.detail: - raise Http404 - if 'state' in query_params: - raise ParseError( - detail=_('Cannot search for state while using the feed endpoint.'), - code=messages.INVALID_SEARCH_PARAMETER_CODE, - ) - if view.action == 'tree': if feed == 'pending': raise NotFound( @@ -160,3 +221,58 @@ def filter_queryset(self, request, queryset, view): ) return super().filter_queryset(request, queryset, view) + + +class PrizeFilter(TrackerFilter): + general_filter = ['name', 'description', 'shortdescription'] + filter_lookup = ['event', 'category'] + filter_keys = { + 'name': 'name__icontains', + 'state': 'state__in', + } + + def normalize_value(self, field, value): + if field == 'state': + return value.upper() + return value + + def has_filter_permission(self, request, field, value=empty): + return ( + field != 'state' + or value is empty + or value in Prize.PUBLIC_STATES + or request.user.has_perm('tracker.view_prize') + ) + + def filter_queryset(self, request, queryset, view): + feed = view.get_feed() + query_params = request.query_params + + check_feed(feed, view, query_params) + + if view.detail or view.action == ['create']: + return queryset + + if feed is None or feed == 'public': + if 'state' not in query_params: + queryset = queryset.public() + elif feed == 'current': + if 'run' in request.query_params: + run = SpeedRunViewSet( + kwargs={'pk': request.query_params['run']}, request=request + ).get_object() + else: + run = None + queryset = queryset.current( + parse_time(query_params.get('time', None)), run=run + ) + elif feed == 'all': + pass # no change for 'all' + elif feed is not None: + if feed.upper() in Prize.ALL_FEEDS: + logger.warning(f'unhandled valid prize feed `{feed}`') + raise NotFound( + detail=messages.INVALID_FEED % feed, code=messages.INVALID_FEED_CODE + ) + + return super().filter_queryset(request, queryset, view) diff --git a/tracker/api/messages.py b/tracker/api/messages.py index 542407e25..68e5e4793 100644 --- a/tracker/api/messages.py +++ b/tracker/api/messages.py @@ -1,10 +1,12 @@ from django.utils.translation import gettext_lazy as _ -from rest_framework.exceptions import NotAuthenticated +from rest_framework.exceptions import NotAuthenticated, PermissionDenied GENERIC_NOT_FOUND = _( 'That resource does not exist or you do not have permission to view it.' ) +NO_GENERAL_SEARCH = _('That endpoint does not support `q` searches.') +NO_GENERAL_SEARCH_CODE = 'no_general_search' MALFORMED_SEARCH_PARAMETER = _('At least one search parameter was malformed.') MALFORMED_SEARCH_PARAMETER_SPECIFIC = _('`%s` parameter was malformed.') MALFORMED_SEARCH_PARAMETER_CODE = 'malformed_search_parameter' @@ -51,5 +53,9 @@ ANCHOR_FIELD_CODE = 'invalid_anchor_sibling' INVALID_ANCHOR = _('Specified anchor is not ordered.') INVALID_ANCHOR_CODE = 'invalid_anchor' +PERMISSION_DENIED = PermissionDenied.default_detail +PERMISSION_DENIED_CODE = PermissionDenied.default_code NOT_AUTHENTICATED = NotAuthenticated.default_detail NOT_AUTHENTICATED_CODE = NotAuthenticated.default_code +INVALID_TIMESTAMP = _('Provided timestamp could not be parsed.') +INVALID_TIMESTAMP_CODE = 'invalid_timestamp' diff --git a/tracker/api/permissions.py b/tracker/api/permissions.py index 45cc2b7f4..13466aa30 100644 --- a/tracker/api/permissions.py +++ b/tracker/api/permissions.py @@ -82,6 +82,38 @@ def has_object_permission(self, request: Request, view: t.Callable, obj: t.Any): ) +class PrizeFeedPermission(BasePermission): + PUBLIC_FEEDS = models.Prize.PUBLIC_FEEDS + message = messages.UNAUTHORIZED_FEED + code = messages.UNAUTHORIZED_FEED_CODE + + def has_permission(self, request: Request, view: t.Callable): + feed = view.get_feed() + return super().has_permission(request, view) and ( + feed is None + or feed in self.PUBLIC_FEEDS + or any( + request.user.has_perm(f'tracker.{p}') + for p in ('change_prize', 'view_prize') + ) + ) + + +class PrizeStatePermission(BasePermission): + PUBLIC_STATES = models.Prize.PUBLIC_STATES + message = messages.GENERIC_NOT_FOUND + code = messages.UNAUTHORIZED_OBJECT_CODE + + def has_object_permission(self, request: Request, view: t.Callable, obj: t.Any): + return super().has_object_permission(request, view, obj) and ( + obj.state in self.PUBLIC_STATES + or any( + request.user.has_perm(f'tracker.{p}') + for p in ('change_prize', 'view_prize') + ) + ) + + class DonationBidStatePermission(BasePermission): PUBLIC_STATES = models.Bid.PUBLIC_STATES message = messages.GENERIC_NOT_FOUND diff --git a/tracker/api/serializers.py b/tracker/api/serializers.py index 5b5cede7d..fee6024ea 100644 --- a/tracker/api/serializers.py +++ b/tracker/api/serializers.py @@ -17,6 +17,7 @@ from rest_framework.validators import UniqueTogetherValidator from tracker.api import messages +from tracker.models import Prize from tracker.models.bid import Bid, DonationBid from tracker.models.country import Country, CountryRegion from tracker.models.donation import Donation, Donor, Milestone @@ -94,6 +95,19 @@ def __init__(self, instance=None, exclude_from_clean=None, **kwargs): self.exclude_from_clean = exclude_from_clean or [] super().__init__(instance, **kwargs) + @property + def is_root(self): + return self.root is self or ( + isinstance(self.root, ListSerializer) and self.root.child is self + ) + + def get_fields(self): + fields = super().get_fields() + if not self.is_root: + for field in getattr(self.Meta, 'exclude_from_nested', []): + fields.pop(field, None) + return fields + def get_validators(self): validators = super().get_validators() # we do this ourselves and it causes weird issues elsewhere @@ -313,7 +327,7 @@ class Meta: ) def to_representation(self, instance): - if self.root == self or getattr(self.root, 'child', None) == self: + if self.is_root: return super().to_representation(instance) else: return instance.alpha3 @@ -333,7 +347,7 @@ class Meta: ) def to_representation(self, instance): - if self.root == self or getattr(self.root, 'child', None) == self: + if self.is_root: return super().to_representation(instance) else: return [instance.name, instance.country.alpha3] @@ -367,8 +381,8 @@ def get_event(self): def to_representation(self, instance): ret = super().to_representation(instance) - if self.event_pk and 'event' in ret: - del ret['event'] + if self.event_pk: + ret.pop('event', None) return ret def to_internal_value(self, data): @@ -402,6 +416,7 @@ class BidSerializer( SerializerWithPermissionsMixin, EventNestedSerializerMixin, TrackerModelSerializer ): type = ClassNameField() + event_move = True def __init__(self, *args, include_hidden=False, feed=None, tree=False, **kwargs): super().__init__(*args, **kwargs) @@ -799,6 +814,7 @@ class Meta: # TODO: almost assuredly a bug in DRF, see: https://github.com/encode/django-rest-framework/discussions/9538 'order': {'default': None, 'required': False} } + exclude_from_nested = ('event',) def __init__(self, *args, with_tech_notes=False, **kwargs): super().__init__(*args, **kwargs) @@ -996,3 +1012,51 @@ def to_representation(self, instance): if instance.visibility == 'ANON': value.pop('alias', None) return value + + +class PrizeSerializer( + SerializerWithPermissionsMixin, EventNestedSerializerMixin, TrackerModelSerializer +): + type = ClassNameField() + event = EventSerializer() + # TODO: when I figure out a better way to be selective about nested fields + # startrun = SpeedRunSerializer() + # endrun = SpeedRunSerializer() + + class Meta: + model = Prize + fields = ( + 'type', + 'id', + 'event', + 'name', + 'state', + 'startrun', + 'endrun', + 'starttime', + 'endtime', + 'start_draw_time', + 'end_draw_time', + 'description', + 'shortdescription', + 'image', + 'altimage', + 'imagefile', + 'estimatedvalue', + 'minimumbid', + 'sumdonations', + 'provider', + 'creator', + # 'creatoremail', TODO, maybe a privacy filter? how often does this get used? + 'creatorwebsite', + ) + + def validate(self, data): + # TODO: allow assigning other handlers, but figure out what those permissions need to look like first + if ( + 'request' in self.context + and 'view' in self.context + and self.context['view'].action == 'create' + ): + data['handler'] = self.context['request'].user + return super().validate(data) diff --git a/tracker/api/urls.py b/tracker/api/urls.py index 28df1c512..56a9283c8 100644 --- a/tracker/api/urls.py +++ b/tracker/api/urls.py @@ -13,6 +13,7 @@ interview, me, milestone, + prize, run, talent, ) @@ -23,16 +24,21 @@ def event_nested_route(path, viewset, *, basename=None, feed=False): if basename is None: basename = router.get_default_basename(viewset) - router.register(path, viewset, basename) if feed: router.register( r'events/(?P[^/.]+)/' + path + r'/feed_(?P\w+)', viewset, f'event-{basename}-feed', ) + router.register( + path + r'/feed_(?P\w+)', + viewset, + f'{basename}-feed', + ) router.register( r'events/(?P[^/.]+)/' + path, viewset, f'event-{basename}' ) + router.register(path, viewset, basename) # routers generate URLs based on the view sets, so that we don't need to do a bunch of stuff by hand @@ -43,6 +49,7 @@ def event_nested_route(path, viewset, *, basename=None, feed=False): event_nested_route(r'ads', ad.AdViewSet) event_nested_route(r'interviews', interview.InterviewViewSet) event_nested_route(r'milestones', milestone.MilestoneViewSet) +event_nested_route(r'prizes', prize.PrizeViewSet, feed=True) event_nested_route(r'donors', donors.DonorViewSet) router.register(r'donations', donations.DonationViewSet, basename='donations') router.register(r'me', me.MeViewSet, basename='me') diff --git a/tracker/api/util.py b/tracker/api/util.py new file mode 100644 index 000000000..ac4e72e31 --- /dev/null +++ b/tracker/api/util.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from datetime import datetime + +from rest_framework.exceptions import ParseError + +from tracker.api import messages + + +def parse_time(time: None | str | int | datetime) -> datetime: + """api helper to throw the correct exception""" + from tracker.util import parse_time + + try: + return parse_time(time) + except (TypeError, ValueError): + raise ParseError( + detail=messages.INVALID_TIMESTAMP, code=messages.INVALID_TIMESTAMP_CODE + ) diff --git a/tracker/api/views/__init__.py b/tracker/api/views/__init__.py index c892ad1c3..3a52633db 100644 --- a/tracker/api/views/__init__.py +++ b/tracker/api/views/__init__.py @@ -121,8 +121,6 @@ def get_serializer(self, *args, **kwargs): class EventNestedMixin: - allow_event_moves = True - def get_permissions(self): return super().get_permissions() + [EventLockedPermission()] @@ -141,7 +139,7 @@ def get_event_from_request(self): return EventViewSet( kwargs={'pk': event_pk, 'skip_annotations': True}, request=self.request ).get_object() - if event := self.request.data.get('event', None): + if not self.detail and (event := self.request.data.get('event', None)): with contextlib.suppress(TypeError, ValueError): return models.Event.objects.filter(pk=event).first() return None @@ -149,9 +147,19 @@ def get_event_from_request(self): def is_event_locked(self, obj=None): if self.detail and obj: event = obj.event + # happens if trying patch an object to another event in any way + if ( + other_event := self.request.data.get('event', None) + ) is not None and other_event != event.pk: + try: + other_event = models.Event.objects.get(pk=other_event) + except (TypeError, ValueError, models.Event.DoesNotExist): + pass # should be caught by validation later + else: + return event.locked or other_event.locked + return event.locked else: - event = self.get_event_from_request() - return event and event.locked + return (event := self.get_event_from_request()) is not None and event.locked def generic_404(exception_handler): diff --git a/tracker/api/views/prize.py b/tracker/api/views/prize.py new file mode 100644 index 000000000..08355cabc --- /dev/null +++ b/tracker/api/views/prize.py @@ -0,0 +1,25 @@ +from tracker.api.filters import PrizeFilter +from tracker.api.permissions import PrizeFeedPermission, PrizeStatePermission +from tracker.api.serializers import PrizeSerializer +from tracker.api.views import ( + EventNestedMixin, + TrackerFullViewSet, + WithSerializerPermissionsMixin, +) +from tracker.models import Prize + + +class PrizeViewSet( + WithSerializerPermissionsMixin, + EventNestedMixin, + TrackerFullViewSet, +): + queryset = Prize.objects.select_related( + 'event', 'startrun', 'endrun', 'prev_run', 'next_run' + ) + serializer_class = PrizeSerializer + permission_classes = [PrizeFeedPermission, PrizeStatePermission] + filter_backends = [PrizeFilter] + + def get_feed(self): + return self.kwargs.get('feed', None) diff --git a/tracker/models/prize.py b/tracker/models/prize.py index 286a4efe7..254fd85d9 100644 --- a/tracker/models/prize.py +++ b/tracker/models/prize.py @@ -27,13 +27,77 @@ USER_MODEL_NAME = getattr(settings, 'AUTH_USER_MODEL', User) +class PrizeQuerySet(models.QuerySet): + PUBLIC_FEEDS = ('public', 'current') + HIDDEN_FEEDS = ('to_draw', 'pending', 'all') + ALL_FEEDS = PUBLIC_FEEDS + HIDDEN_FEEDS + + def public(self): + return self.filter(state='ACCEPTED') + + def current(self, time=None, *, run=None): + # current implies 'public', since it should only list prizes that are + # available to donate for + if run is None: + time = util.parse_time(time) + else: + if run.order is None: + raise ValueError('provided Run is not ordered') + time = run.starttime + return self.public().filter( + Q(prizewinner__isnull=True) + & ( + Q(startrun__starttime__lte=time, endrun__endtime__gte=time) + | Q(starttime__lte=time, endtime__gte=time) + | Q( + startrun__isnull=True, + endrun__isnull=True, + starttime__isnull=True, + endtime__isnull=True, + ) + ) + ) + + def to_draw(self, time=None): + time = util.parse_time(time) + return self.filter( + ( + Q(prizewinner=None) + | ( + Q(prizewinner__pendingcount__gt=0) + & Q(prizewinner__acceptdeadline__lt=time) + ) + ) + & ( + Q(endrun__endtime__lte=time) + | Q(endtime__lte=time) + | (Q(endtime=None) & Q(endrun=None)) + ) + & ( + Q(event__prize_drawing_date=None) + | Q(event__prize_drawing_date__lte=time) + ), + state='ACCEPTED', + ) + + def pending(self): + return self.filter(state='PENDING') + + class PrizeManager(models.Manager): def get_by_natural_key(self, name, event): return self.get(name=name, event=Event.objects.get_by_natural_key(*event)) class Prize(models.Model): - objects = PrizeManager() + PUBLIC_FEEDS = PrizeQuerySet.PUBLIC_FEEDS + HIDDEN_FEEDS = PrizeQuerySet.HIDDEN_FEEDS + ALL_FEEDS = PrizeQuerySet.ALL_FEEDS + PUBLIC_STATES = ('ACCEPTED',) + HIDDEN_STATES = ('DENIED', 'PENDING', 'FLAGGED') + ALL_STATES = PUBLIC_STATES + HIDDEN_STATES + + objects = PrizeManager.from_queryset(PrizeQuerySet)() name = models.CharField(max_length=64) category = models.ForeignKey( 'PrizeCategory', on_delete=models.PROTECT, null=True, blank=True @@ -204,6 +268,10 @@ class Meta: ordering = ['event__datetime', 'startrun__starttime', 'starttime', 'name'] unique_together = ('name', 'event') + @property + def public(self): + return self.state == 'ACCEPTED' + def natural_key(self): return self.name, self.event.natural_key() @@ -421,6 +489,7 @@ def has_draw_time(self): def start_draw_time(self): if self.startrun and self.startrun.order: if self.prev_run: + # allow some slop into the previous run's setup time in case the run starts 'late' return self.prev_run.endtime - datetime.timedelta( milliseconds=self.prev_run.setup_time_ms )