# Source code for rest_assured.testcases

from django.db.models import Manager
from django.core.exceptions import ObjectDoesNotExist
from django.utils import six
from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from django.utils.six import text_type


class BaseRESTAPITestCase(APITestCase):
    """Base test case class for testing REST API endpoints."""

    #: *required*: Base route name of the API endpoints to test.
    base_name = None
    #: *required*: The factory class to use for creating the main object to test against.
    factory_class = None
    #: Suffix for list endpoint view names. Defaults to ``'-list'``.
    LIST_SUFFIX = '-list'
    #: Suffix for detail endpoint view names. Defaults to ``'-detail'``.
    DETAIL_SUFFIX = '-detail'
    #: The field to use for DB and route lookups. Defaults to ``'pk'``.
    lookup_field = 'pk'
    #: User factory to use in case you need user authentication for testing. Defaults to ``None``.
    user_factory = None
    #: The main test subject.
    object = None
    #: The user instance created if the ``user_factory`` is set and used. Defaults to ``None``.
    user = None

    def get_factory_class(self):
        """Return the factory class for generating the main object (or model instance) of this test case.

        By default this returns the ``factory_class`` attribute of this class.

        :returns: Factory class used for creating the mock objects.
        """
        return self.factory_class

    def get_object(self, factory):
        """Create and return the object (or model instance) of this test case.

        By default this calls the ``create()`` method of the factory class,
        assuming a Django Model or a factory_boy's Factory.

        :param factory: The factory class used for creating the main object.
        :returns: The main object of this test case.
        """
        return factory.create()

    def setUp(self):
        """Generate the main object and, if needed, a user instance.

        The user instance is created only if the ``user_factory`` attribute is
        set to a factory class. When a user is created, the test client is
        force-authenticated as that user.
        """
        # Cooperate with other classes in the MRO that also define setUp().
        super(BaseRESTAPITestCase, self).setUp()

        # create and force authenticate user
        user_factory = self.user_factory
        if user_factory:
            self.user = user_factory.create()
            self.client.force_authenticate(self.user)

        # create the object
        self.object = self.get_object(self.get_factory_class())
class ListAPITestCaseMixin(object):
    """Adds a list view test to the test case."""

    #: When using pagination set this attribute to the name of the property in the response data
    #: that holds the result set. Defaults to ``None``.
    pagination_results_field = None

    def get_list_url(self):
        """Return the list endpoint url.

        The view name is ``base_name`` followed by ``LIST_SUFFIX``.

        :returns: The url of list endpoint.
        """
        return reverse(self.base_name + self.LIST_SUFFIX)

    def get_list_response(self, **kwargs):
        """Send the list request and return the response.

        :param kwargs: Extra arguments that are passed to the client's ``get()`` call.
        :returns: The response object.
        """
        return self.client.get(self.get_list_url(), **kwargs)

    def test_list(self, **kwargs):
        """Send request to the list view endpoint, verify and return the response.

        Checks for a 200 status code and that the result set contains at least
        one item. When ``pagination_results_field`` is set, also checks that
        the field is present in ``response.data`` and takes the result set from
        it. Otherwise ``response.data`` itself is the result set. You can
        extend it for more extensive checks.

        .. admonition:: example

            .. code:: python

                class LanguageRESTAPITestCase(ListAPITestCaseMixin, BaseRESTAPITestCase):

                    pagination_results_field = 'results'

                    def test_list(self, **kwargs):
                        response = super(LanguageRESTAPITestCase, self).test_list(**kwargs)
                        results = response.data.get('results')
                        self.assertEqual(results[0].get('code'), self.object.code)

        :param kwargs: Extra arguments that are passed to the client's ``get()`` call.
        :returns: The view's response.
        """
        response = self.get_list_response(**kwargs)
        # Non-DRF responses (e.g. a plain Django error page) have no ``data``;
        # fall back to the response so the status failure itself is reported.
        self.assertEqual(response.status_code, status.HTTP_200_OK, getattr(response, 'data', response))

        results = response.data
        if self.pagination_results_field:
            self.assertIn(self.pagination_results_field, results)
            results = results[self.pagination_results_field]

        self.assertGreaterEqual(len(results), 1, 'The list endpoint returned no results.')
        return response
class DetailAPITestCaseMixin(object):
    """Adds a detail view test to the test case."""

    #: A list of attribute names to check equality between the main object and the response data.
    #: Defaults to ``['id']``.
    #: You can also use a ``(name, callable)`` tuple (or list), where the callable takes the object
    #: and returns the attribute's expected value.
    attributes_to_check = ['id']

    def get_detail_url(self):
        """Return the detail endpoint url.

        The main object's ``lookup_field`` value is passed to ``reverse()`` as a
        text argument.

        :returns: The url of detail endpoint.
        """
        object_id = getattr(self.object, self.lookup_field)
        return reverse(self.base_name + self.DETAIL_SUFFIX, args=[text_type(object_id)])

    def get_detail_response(self, **kwargs):
        """Send the detail request and return the response.

        :param kwargs: Extra arguments that are passed to the client's ``get()`` call.
        :returns: The response object.
        """
        return self.client.get(self.get_detail_url(), **kwargs)

    def test_detail(self, **kwargs):
        """Send request to the detail view endpoint, verify and return the response.

        Checks for a 200 status code and that every entry in
        ``attributes_to_check`` (by default just ``id``) is present in
        ``response.data`` and equals the main object's value, compared as text.
        You can extend it for more extensive checks.

        .. admonition:: example

            .. code:: python

                class LanguageRESTAPITestCase(DetailAPITestCaseMixin, BaseRESTAPITestCase):

                    def test_detail(self, **kwargs):
                        response = super(LanguageRESTAPITestCase, self).test_detail(**kwargs)
                        self.assertEqual(response.data.get('code'), self.object.code)

        Using a callable in ``attributes_to_check``:

        .. admonition:: example

            .. code:: python

                class TaggedFoodRESTAPITestCase(DetailAPITestCaseMixin, BaseRESTAPITestCase):

                    attributes_to_check = ['name', ('similar', lambda obj: obj.tags.similar_objects())]

        :param kwargs: Extra arguments that are passed to the client's ``get()`` call.
        :returns: The view's response.
        """
        response = self.get_detail_response(**kwargs)
        # Non-DRF responses have no ``data``; fall back so the status failure is reported.
        self.assertEqual(response.status_code, status.HTTP_200_OK, getattr(response, 'data', response))
        self._check_attributes(response.data)
        return response

    def _check_attributes(self, data):
        """Assert each entry of ``attributes_to_check`` matches ``data``.

        Plain names are read from the main object with ``getattr``. ``(name, callable)``
        pairs compute the expected value by calling ``callable(self.object)``.
        Both sides are compared as text.

        :param data: The response data to check against.
        """
        for attr in self.attributes_to_check:
            # Only ordered sequences can hold a (name, callable) pair; a set
            # is not subscriptable, so it is not accepted here.
            if isinstance(attr, (tuple, list)):
                expected = text_type(attr[1](self.object))
                attr = attr[0]
            else:
                expected = text_type(getattr(self.object, attr))
            # Report a missing key as a test failure rather than a KeyError.
            self.assertIn(attr, data, 'Attribute %r is missing from the response data.' % (attr,))
            self.assertEqual(expected, text_type(data[attr]), attr)
class CreateAPITestCaseMixin(object):
    """Adds a create view test to the test case."""

    #: *required*: Dictionary of data to use as the POST request's body.
    create_data = None
    #: The name of the field in the response data for looking up the created object in DB.
    response_lookup_field = 'id'

    def get_create_data(self):
        """Return the data used for the create request.

        By default gets the ``create_data`` attribute of this class.

        :returns: The data dictionary.
        """
        return self.create_data

    def get_create_url(self):
        """Return the create endpoint url.

        The view name is the ``create_name`` attribute if the test case defines
        one, otherwise ``base_name`` followed by ``LIST_SUFFIX``.

        :returns: The url of create endpoint.
        """
        return reverse(self._get_create_name())

    def get_create_response(self, data=None, **kwargs):
        """Send the create request and return the response.

        :param data: A dictionary of the data to use for the create request.
            Defaults to :meth:`get_create_data`. A falsy value is sent as ``{}``.
        :param kwargs: Extra arguments that are passed to the client's ``post()`` call.
        :returns: The response object.
        """
        if data is None:
            data = self.get_create_data()
        return self.client.post(self.get_create_url(), data or {}, **kwargs)

    def get_lookup_from_response(self, data):
        """Return value for looking up the created object in DB.

        :Note: The created object will be looked up using the ``lookup_field``
            attribute as key, which defaults to ``pk``.

        :param data: A dictionary of the response data to lookup the field in.
        :returns: The value of ``response_lookup_field`` in ``data`` (``None`` if absent),
            used for looking up the created object in DB.
        """
        return data.get(self.response_lookup_field)

    def test_create(self, data=None, **kwargs):
        """Send request to the create view endpoint, verify and return the response.

        Also verifies that the object actually exists in the database.

        :param data: A dictionary of the data to use for the create request.
        :param kwargs: Extra arguments that are passed to the client's ``post()`` call.
        :returns: A tuple ``response, created`` of the view's response and the created instance.
        """
        response = self.get_create_response(data, **kwargs)
        # Non-DRF responses have no ``data``; fall back so the status failure is reported.
        self.assertEqual(response.status_code, status.HTTP_201_CREATED, getattr(response, 'data', response))

        # another sanity check:
        # getting the instance from database simply to see that it's found and does not raise any exception
        created = self.object.__class__.objects.get(
            **{self.lookup_field: self.get_lookup_from_response(response.data)})

        return response, created

    def _get_create_name(self):
        """Return the view name of the create endpoint (``create_name`` override or list view name)."""
        if hasattr(self, 'create_name'):
            view_name = self.create_name
        else:
            view_name = self.base_name + self.LIST_SUFFIX
        return view_name
class DestroyAPITestCaseMixin(object):
    """Adds a destroy view test to the test case."""

    def get_destroy_url(self):
        """Return the destroy endpoint url.

        Also stores the main object's lookup value in ``self.object_id``, which
        :meth:`test_destroy` uses to check the object was removed from the database.

        :returns: The url of destroy endpoint.
        """
        self.object_id = getattr(self.object, self.lookup_field)
        return reverse(self._get_destroy_name(), args=(self.object_id,))

    def get_destroy_response(self, **kwargs):
        """Send the destroy request and return the response.

        :param kwargs: Extra arguments that are passed to the client's ``delete()`` call.
        :returns: The view's response.
        """
        return self.client.delete(self.get_destroy_url(), **kwargs)

    def test_destroy(self, **kwargs):
        """Send request to the destroy view endpoint, verify and return the response.

        Also verifies the object does not exist anymore in the database.

        :param kwargs: Extra arguments that are passed to the client's ``delete()`` call.
        :returns: The view's response.
        """
        response = self.get_destroy_response(**kwargs)
        # Non-DRF responses have no ``data``; fall back so the status failure
        # itself is reported instead of an AttributeError.
        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT, getattr(response, 'data', response))

        # Another sanity check:
        # see that the instance is removed from the database.
        self.assertRaises(ObjectDoesNotExist, self.object.__class__.objects.get,
                          **{self.lookup_field: self.object_id})

        return response

    def _get_destroy_name(self):
        """Return the view name of the destroy endpoint (``destroy_name`` override or detail view name)."""
        if hasattr(self, 'destroy_name'):
            view_name = self.destroy_name
        else:
            view_name = self.base_name + self.DETAIL_SUFFIX
        return view_name
class UpdateAPITestCaseMixin(object):
    """Adds an update view test to the test case."""

    #: Whether to send a PATCH request instead of PUT. Defaults to ``True``.
    use_patch = True
    #: *required*: Dictionary of data to use as the update request's body.
    update_data = None
    #: Dictionary mapping attributes to values to check against the updated instance in the database.
    #: Defaults to ``update_data``.
    update_results = None
    #: The name of the field on related objects whose value represents a relationship
    #: (see :meth:`get_relationship_value`). Defaults to ``'id'``.
    relationship_lookup_field = 'id'

    def get_update_url(self):
        """Return the update endpoint url.

        Also stores the main object's lookup value in ``self.object_id``, which
        :meth:`test_update` uses to re-fetch the instance from the database.

        :returns: The url of update endpoint.
        """
        self.object_id = getattr(self.object, self.lookup_field)
        return reverse(self._get_update_name(), args=(self.object_id,))

    def get_update_response(self, data=None, results=None, use_patch=None, **kwargs):
        """Send the update request and return the response.

        :param data: Data dictionary for the update request. Defaults to :meth:`get_update_data`.
        :param results: Dictionary mapping instance properties to expected values.
            Defaults to :meth:`get_update_results`.
        :param use_patch: Send a PATCH request if true, PUT otherwise. Defaults to ``use_patch``.
        :param kwargs: Extra arguments that are passed to the client's ``put()`` or ``patch()`` call.
        :returns: The response object.
        """
        if data is None:
            data = self.get_update_data()
        # Remembered so ``_update_check_db()`` can fall back to them.
        self.__data = data

        if results is None:
            results = self.get_update_results(data)
        self.__results = results

        args = [self.get_update_url(), data]

        if use_patch is None:
            use_patch = self.use_patch

        return self.client.patch(*args, **kwargs) if use_patch else self.client.put(*args, **kwargs)

    def get_update_data(self):
        """Return the data used for the update request.

        By default gets the ``update_data`` attribute of this class.

        :returns: Data dictionary for the update request.
        """
        return self.update_data

    def get_update_results(self, data=None):
        """Return a dictionary of the expected results of the instance.

        By default gets the ``update_results`` attribute of this class.
        If that isn't set (is ``None``) defaults to the data.

        :param data: The update request's data dictionary.
        :returns: Dictionary mapping instance properties to expected values.
        """
        # ``update_results`` always exists as a class attribute (default None),
        # so a getattr() default would never apply; test for None explicitly.
        results = getattr(self, 'update_results', None)
        return data if results is None else results

    def get_relationship_value(self, related_obj, key):
        """Return a value representing a relation to a related model instance.

        By default gets the ``relationship_lookup_field`` attribute of this class
        which defaults to ``id``, and converts it to a ``string``.

        :param related_obj: The related model instance to convert to a value.
        :param key: A ``string`` representing the name of the relation, or the key on the updated object.
        :returns: Value representing the relation to assert against.
        """
        return text_type(getattr(related_obj, self.relationship_lookup_field))

    def test_update(self, data=None, results=None, use_patch=None, **kwargs):
        """Send request to the update view endpoint, verify and return the response.

        :param data: Data dictionary for the update request.
        :param results: Dictionary mapping instance properties to expected values.
        :param use_patch: Send a PATCH request if true, PUT otherwise. Defaults to ``use_patch``.
        :param kwargs: Extra arguments that are passed to the client's ``put()`` or ``patch()`` call.
        :returns: A tuple ``response, updated`` of the view's response and the updated instance.
        """
        response = self.get_update_response(data, results, use_patch, **kwargs)
        # Non-DRF responses have no ``data``; fall back so the status failure is reported.
        self.assertEqual(response.status_code, status.HTTP_200_OK, getattr(response, 'data', response))

        # getting a fresh copy of the object from DB
        updated = self.object.__class__.objects.get(**{self.lookup_field: self.object_id})
        # Sanity check:
        # check that the copy in the database was updated as expected.
        self._update_check_db(updated, data, results)
        return response, updated

    def _get_update_name(self):
        """Return the view name of the update endpoint (``update_name`` override or detail view name)."""
        if hasattr(self, 'update_name'):
            view_name = self.update_name
        else:
            view_name = self.base_name + self.DETAIL_SUFFIX
        return view_name

    def _update_check_db(self, obj, data=None, results=None):
        """Assert that ``obj`` reflects the update request's data.

        :param obj: The updated model instance, or a dict (allowing overrides to
            check a serialized object instead).
        :param data: The update request's data. Defaults to the data last sent by
            :meth:`get_update_response`.
        :param results: Expected values keyed by attribute. Defaults to the results
            stored by :meth:`get_update_response`. Keys missing here fall back to the
            value from ``data``.
        """
        if data is None:
            data = self.__data
        if results is None:
            results = self.__results or {}

        for key, value in six.iteritems(data):
            # check if ``obj`` is a dict to allow overriding ``_update_check_db()``
            # and perform checks on a serialized object
            if isinstance(obj, dict):
                attribute = obj.get(key)
                if isinstance(attribute, list):
                    self.assertListEqual(attribute, value, key)
                    continue
            else:
                # check for foreign key
                if hasattr(obj, '%s_id' % key):
                    related = getattr(obj, key)
                    # A nullable foreign key may have been cleared by the update;
                    # compare None directly instead of dereferencing it.
                    if related is None:
                        attribute = None
                    else:
                        attribute = self.get_relationship_value(related, key)
                else:
                    attribute = getattr(obj, key)

            # Handle case of a ManyToMany relation
            if isinstance(attribute, Manager):
                items = {self.get_relationship_value(item, key) for item in attribute.all()}
                self.assertTrue(set(value).issubset(items), key)
                continue

            self.assertEqual(attribute, results.get(key, value), key)
class ReadRESTAPITestCaseMixin(ListAPITestCaseMixin, DetailAPITestCaseMixin):
    """Mixin bundling the read-only CRUD tests (list and detail).

    Combines :class:`ListAPITestCaseMixin` and :class:`DetailAPITestCaseMixin`.
    """
class WriteRESTAPITestCaseMixin(CreateAPITestCaseMixin,
                                UpdateAPITestCaseMixin,
                                DestroyAPITestCaseMixin):
    """Mixin bundling the write CRUD tests (create, update and destroy).

    Combines :class:`CreateAPITestCaseMixin`, :class:`UpdateAPITestCaseMixin`
    and :class:`DestroyAPITestCaseMixin`.
    """
class ReadWriteRESTAPITestCaseMixin(ReadRESTAPITestCaseMixin, WriteRESTAPITestCaseMixin):
    """Complete mixin covering every successful CRUD operation request.

    Combines :class:`ReadRESTAPITestCaseMixin` and :class:`WriteRESTAPITestCaseMixin`.
    """