diff --git a/tournaments/frontend/forms.py b/tournaments/frontend/forms.py index 89e51dc..683f22a 100644 --- a/tournaments/frontend/forms.py +++ b/tournaments/frontend/forms.py @@ -54,13 +54,10 @@ def clean_definition(self): return definition def create_tournament(self, request): - tournament = models.Tournament.load( + return models.Tournament.load( definition = self.cleaned_data['definition'], name = self.cleaned_data['name'], creator = request.user) - tournament.definition = self.data['definition'] - tournament.save() - return tournament class UpdateTournamentForm(CreateTournamentForm): diff --git a/tournaments/frontend/tests.py b/tournaments/frontend/tests.py index b4d0c4e..a24c9e9 100644 --- a/tournaments/frontend/tests.py +++ b/tournaments/frontend/tests.py @@ -531,3 +531,59 @@ def test(self): self.assertEqual(response.status_code, 200) self.assertIs(response.resolver_match.func.view_class, views.UpdateTournamentView) self.assertFalse(self.user1 in self.user1_tournament.participants) + + +class CloneTournamentViewTests(TestCase): + + def setUp(self): + self.user1 = models.User.objects.create(username = 'test1') + self.user2 = models.User.objects.create(username = 'test2') + self.client.force_login(self.user1) + + self.user1_tournament = models.Tournament.load(definition = test_tournament1_yml, name = 'Test1', creator = self.user1, published = True) + self.user2_tournament = models.Tournament.load(definition = test_tournament1_yml, name = 'Test2', creator = self.user2, published = True) + + def test_unauthenticated(self): + self.client.logout() + + response = self.client.get(reverse('clone-tournament', kwargs = dict(pk = self.user1_tournament.id)), follow = True) + self.assertEqual(response.status_code, 200) + self.assertIs(response.resolver_match.func.view_class, LoginView) + + def test_not_found(self): + response = self.client.get(reverse('clone-tournament', kwargs = dict(pk = 0))) + self.assertEqual(response.status_code, 404) + + def test_foreign(self): + response = self.client.get(reverse('clone-tournament', kwargs = dict(pk = self.user2_tournament.id)), follow = True) + clone = models.Tournament.objects.get(name = 'Test2 (Copy)') + self.assertEqual(response.status_code, 200) + self.assertIs(response.resolver_match.func.view_class, views.UpdateTournamentView) + self.assertEqual(clone.definition, self.user1_tournament.definition) + self.assertEqual(clone.creator.id, self.user1.id) + + def test_drafted(self): + self.user1_tournament.published = False + self.user1_tournament.save() + response = self.client.get(reverse('clone-tournament', kwargs = dict(pk = self.user1_tournament.id)), follow = True) + clone = models.Tournament.objects.get(name = 'Test1 (Copy)') + self.assertEqual(response.status_code, 200) + self.assertIs(response.resolver_match.func.view_class, views.UpdateTournamentView) + self.assertEqual(clone.definition, self.user1_tournament.definition) + self.assertEqual(clone.creator.id, self.user1.id) + + def test_active(self): + start_tournament(self.user1_tournament) + response = self.client.get(reverse('clone-tournament', kwargs = dict(pk = self.user1_tournament.id)), follow = True) + clone = models.Tournament.objects.get(name = 'Test1 (Copy)') + self.assertEqual(response.status_code, 200) + self.assertIs(response.resolver_match.func.view_class, views.UpdateTournamentView) + self.assertEqual(clone.definition, self.user1_tournament.definition) + self.assertEqual(clone.creator.id, self.user1.id) + + def test(self): + response = self.client.get(reverse('clone-tournament', kwargs = dict(pk = self.user1_tournament.id)), follow = True) + clone = models.Tournament.objects.get(name = 'Test1 (Copy)') + self.assertEqual(response.status_code, 200) + self.assertIs(response.resolver_match.func.view_class, views.UpdateTournamentView) + self.assertEqual(clone.definition, self.user1_tournament.definition) diff --git a/tournaments/frontend/views.py b/tournaments/frontend/views.py index cc01230..404b526 100644 --- a/tournaments/frontend/views.py +++ b/tournaments/frontend/views.py @@ -385,7 +385,5 @@ def get(self, request, *args, **kwargs): definition = self.object.definition, name = self.object.name + ' (Copy)', creator = request.user) - tournament.definition = self.object.definition - tournament.save() request.session['alert'] = dict(status = 'success', text = f'A copy of the tournament "{ self.object.name }" has been created (see below).') return redirect('update-tournament', pk = tournament.id) diff --git a/tournaments/tournaments/models.py b/tournaments/tournaments/models.py index 5685b7e..8f328dd 100644 --- a/tournaments/tournaments/models.py +++ b/tournaments/tournaments/models.py @@ -23,6 +23,7 @@ class Tournament(models.Model): @staticmethod def load(definition, name, **kwargs): + definition_str = definition if isinstance(definition, str): import yaml definition = yaml.safe_load(definition) @@ -31,7 +32,7 @@ def load(definition, name, **kwargs): if len(definition['podium']) == 0: raise ValidationError('No podium definition given.') - tournament = Tournament.objects.create(name = name, podium_spec = definition['podium'], **kwargs) + tournament = Tournament.objects.create(name = name, podium_spec = definition['podium'], definition = definition_str, **kwargs) for stage in definition['stages']: stage = {key.replace('-', '_'): value for key, value in stage.items()}