From 9b5f995c1094b527b532dc012fbcd7c45ab388f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Diemer?= Date: Mon, 13 Jun 2016 20:54:32 +0200 Subject: [PATCH] add viewset support --- rest_framework_docs/api_docs.py | 5 ++-- rest_framework_docs/api_endpoint.py | 37 +++++++++++++++++++++++++++-- rest_framework_docs/views.py | 3 ++- tests/tests.py | 13 +++++++++- tests/urls.py | 8 ++++++- tests/views.py | 12 ++++++++++ 6 files changed, 71 insertions(+), 7 deletions(-) diff --git a/rest_framework_docs/api_docs.py b/rest_framework_docs/api_docs.py index c51c205..d22dd4c 100644 --- a/rest_framework_docs/api_docs.py +++ b/rest_framework_docs/api_docs.py @@ -8,8 +8,9 @@ class ApiDocumentation(object): - def __init__(self): + def __init__(self, drf_router=None): self.endpoints = [] + self.drf_router = drf_router try: root_urlconf = import_string(settings.ROOT_URLCONF) except ImportError: @@ -26,7 +27,7 @@ def get_all_view_names(self, urlpatterns, parent_pattern=None): parent_pattern = None if pattern._regex == "^" else pattern self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=parent_pattern) elif isinstance(pattern, RegexURLPattern) and self._is_drf_view(pattern) and not self._is_format_endpoint(pattern): - api_endpoint = ApiEndpoint(pattern, parent_pattern) + api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router) self.endpoints.append(api_endpoint) def _is_drf_view(self, pattern): diff --git a/rest_framework_docs/api_endpoint.py b/rest_framework_docs/api_endpoint.py index 8eb07c9..f302aa8 100644 --- a/rest_framework_docs/api_endpoint.py +++ b/rest_framework_docs/api_endpoint.py @@ -6,7 +6,8 @@ class ApiEndpoint(object): - def __init__(self, pattern, parent_pattern=None): + def __init__(self, pattern, parent_pattern=None, drf_router=None): + self.drf_router = drf_router self.pattern = pattern self.callback = pattern.callback # self.name = pattern.name @@ -26,7 +27,39 @@ def __get_path__(self, parent_pattern): return simplify_regex(self.pattern.regex.pattern) def __get_allowed_methods__(self): - return [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)] + + viewset_methods = [] + if self.drf_router: + for prefix, viewset, basename in self.drf_router.registry: + if self.callback.cls != viewset: + continue + + lookup = self.drf_router.get_lookup_regex(viewset) + routes = self.drf_router.get_routes(viewset) + + for route in routes: + + # Only actions which actually exist on the viewset will be bound + mapping = self.drf_router.get_method_map(viewset, route.mapping) + if not mapping: + continue + + # Build the url pattern + regex = route.url.format( + prefix=prefix, + lookup=lookup, + trailing_slash=self.drf_router.trailing_slash + ) + if self.pattern.regex.pattern == regex: + funcs, viewset_methods = zip( + *[(mapping[m], m.upper()) for m in self.callback.cls.http_method_names if m in mapping] + ) + viewset_methods = list(viewset_methods) + if len(set(funcs)) == 1: + self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0])) + + view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)] + return viewset_methods + view_methods def __get_docstring__(self): return inspect.getdoc(self.callback) diff --git a/rest_framework_docs/views.py b/rest_framework_docs/views.py index 04074cc..50400d4 100644 --- a/rest_framework_docs/views.py +++ b/rest_framework_docs/views.py @@ -7,6 +7,7 @@ class DRFDocsView(TemplateView): template_name = "rest_framework_docs/home.html" + drf_router = None def get_context_data(self, **kwargs): settings = DRFSettings().settings @@ -14,7 +15,7 @@ def get_context_data(self, **kwargs): raise Http404("Django Rest Framework Docs are hidden. Check your settings.") context = super(DRFDocsView, self).get_context_data(**kwargs) - docs = ApiDocumentation() + docs = ApiDocumentation(drf_router=self.drf_router) endpoints = docs.get_endpoints() query = self.request.GET.get("search", "") diff --git a/tests/tests.py b/tests/tests.py index 57028b7..935510d 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -27,7 +27,7 @@ def test_index_view_with_endpoints(self): response = self.client.get(reverse('drfdocs')) self.assertEqual(response.status_code, 200) - self.assertEqual(len(response.context["endpoints"]), 11) + self.assertEqual(len(response.context["endpoints"]), 14) # Test the login view self.assertEqual(response.context["endpoints"][0].name_parent, "accounts") @@ -67,3 +67,14 @@ def test_index_view_docs_hidden(self): self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase.upper(), "NOT FOUND") + + def test_model_viewset(self): + response = self.client.get(reverse('drfdocs')) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.context["endpoints"][10].path, '/organisation-model-viewsets/') + self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets//') + self.assertEqual(response.context["endpoints"][10].allowed_methods, ['GET', 'POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][12].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][12].docstring, 'This is a test.') diff --git a/tests/urls.py b/tests/urls.py index 092973c..bc1c167 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -2,6 +2,8 @@ from django.conf.urls import include, url from django.contrib import admin +from rest_framework.routers import SimpleRouter +from rest_framework_docs.views import DRFDocsView from tests import views accounts_urls = [ @@ -23,13 +25,17 @@ url(r'^(?P[\w-]+)/errored/$', view=views.OrganisationErroredView.as_view(), name="errored") ] +router = SimpleRouter() +router.register('organisation-model-viewsets', views.TestModelViewSet, base_name='organisation') + urlpatterns = [ url(r'^admin/', include(admin.site.urls)), - url(r'^docs/', include('rest_framework_docs.urls')), + url(r'^docs/', DRFDocsView.as_view(drf_router=router), name='drfdocs'), # API url(r'^accounts/', view=include(accounts_urls, namespace='accounts')), url(r'^organisations/', view=include(organisations_urls, namespace='organisations')), + url(r'^', include(router.urls)), # Endpoints without parents/namespaces url(r'^another-login/$', views.LoginView.as_view(), name="login"), diff --git a/tests/views.py b/tests/views.py index 64e0bec..e786d5e 100644 --- a/tests/views.py +++ b/tests/views.py @@ -5,9 +5,11 @@ from rest_framework import parsers, renderers, generics, status from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer +from rest_framework.decorators import detail_route from rest_framework.permissions import AllowAny from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework.viewsets import ModelViewSet from tests.models import User, Organisation, Membership from tests import serializers @@ -132,3 +134,13 @@ def post(self, request): def get_serializer_class(self): return AuthTokenSerializer + + +class TestModelViewSet(ModelViewSet): + queryset = Organisation.objects.all() + serializer_class = serializers.OrganisationMembersSerializer + + @detail_route(methods=['post']) + def test_route(self, request): + """This is a test.""" + return Response()