diff --git a/app/recipe/tests/test_recipe_api.py b/app/recipe/tests/test_recipe_api.py index 95ce357..dc47db2 100644 --- a/app/recipe/tests/test_recipe_api.py +++ b/app/recipe/tests/test_recipe_api.py @@ -38,6 +38,13 @@ def create_recipe(user, **params): recipe = Recipe.objects.create(user=user, **defaults) return recipe +# helper function to create a user + + +def create_user(**params): + """Create and return a sample user.""" + return get_user_model().objects.create_user(**params) + class PublicRecipeAPITests(TestCase): """Test unauthenticated recipe API requests.""" @@ -57,10 +64,8 @@ class PrivateRecipeApiTests(TestCase): def setUp(self): self.client = APIClient() - self.user = get_user_model().objects.create_user( - 'user@example.com', - 'password123', - ) + self.user = create_user( + email='user@example.com', password='password123') self.client.force_authenticate(self.user) def test_retrieve_recipes(self): @@ -79,10 +84,7 @@ def test_retrieve_recipes(self): def test_recipes_limited_to_user(self): """Test list of recipes is limited to authenticated user.""" - user2 = get_user_model().objects.create_user( - 'other@example.com', - 'password123', - ) + user2 = create_user(email='other@example.com', password='testpass123') create_recipe(user=user2) create_recipe(user=self.user) @@ -119,3 +121,62 @@ def test_create_recipe(self): for key, value in payload.items(): self.assertEqual(getattr(recipe, key), value) self.assertEqual(recipe.user, self.user) + + def test_partial_update(self): + """Test partial update of recipe.""" + original_link = 'https://example.com/recipe.pdf' + recipe = create_recipe( + user=self.user, + title='Sample recipe title', + link=original_link, + ) + + payload = {'title': 'New recipe title'} + url = detail_url(recipe.id) + res = self.client.patch(url, payload) + + self.assertEqual(res.status_code, status.HTTP_200_OK) + recipe.refresh_from_db() + self.assertEqual(recipe.title, payload['title']) + self.assertEqual(recipe.link, original_link) + self.assertEqual(recipe.user, self.user) + + def test_full_update(self): + """Test full update of recipe.""" + recipe = create_recipe( + user=self.user, + title='Sample recipe title', + description='Sample description', + link='https://example.com/recipe.pdf', + ) + + payload = { + 'title': 'Updated recipe title', + 'time_minutes': 13, + 'price': Decimal('17.99'), + 'description': 'Updated description', + 'link': 'https://example.com/recipe2.pdf', + } + url = detail_url(recipe.id) + res = self.client.put(url, payload) + + self.assertEqual(res.status_code, status.HTTP_200_OK) + recipe.refresh_from_db() + for key, value in payload.items(): + self.assertEqual(getattr(recipe, key), value) + self.assertEqual(recipe.user, self.user) + + def test_update_user_returns_error(self): + """Test changing the recipe user results in an error""" + new_user = create_user( + email='user2@example.com', + password='testpass123' + ) + recipe = create_recipe(user=self.user) + + payload = {'user': new_user.id} + url = detail_url(recipe.id) + self.client.patch(url, payload) + + recipe.refresh_from_db() + self.assertEqual(recipe.user, self.user) diff --git a/app/recipe/views.py b/app/recipe/views.py index 1cef0ae..08dea64 100644 --- a/app/recipe/views.py +++ b/app/recipe/views.py @@ -26,3 +26,7 @@ def get_serializer_class(self): if self.action == 'list': return serializers.RecipeSerializer return self.serializer_class + + def perform_create(self, serializer): + """Create a new recipe.""" + serializer.save(user=self.request.user)