diff --git a/rest_framework_docs/api_endpoint.py b/rest_framework_docs/api_endpoint.py index f37155c..8eb07c9 100644 --- a/rest_framework_docs/api_endpoint.py +++ b/rest_framework_docs/api_endpoint.py @@ -37,22 +37,27 @@ def __get_permissions_class__(self): def __get_serializer_fields__(self): fields = [] + serializer = None - if hasattr(self.callback.cls, 'serializer_class') and hasattr(self.callback.cls.serializer_class, 'get_fields'): + if hasattr(self.callback.cls, 'serializer_class'): serializer = self.callback.cls.serializer_class - if hasattr(serializer, 'get_fields'): - try: - fields = [{ - "name": key, - "type": str(field.__class__.__name__), - "required": field.required - } for key, field in serializer().get_fields().items()] - except KeyError as e: - self.errors = e - fields = [] - - # FIXME: - # Show more attibutes of `field`? + + elif hasattr(self.callback.cls, 'get_serializer_class'): + serializer = self.callback.cls.get_serializer_class(self.pattern.callback.cls()) + + if hasattr(serializer, 'get_fields'): + try: + fields = [{ + "name": key, + "type": str(field.__class__.__name__), + "required": field.required + } for key, field in serializer().get_fields().items()] + except KeyError as e: + self.errors = e + fields = [] + + # FIXME: + # Show more attibutes of `field`? return fields diff --git a/tests/tests.py b/tests/tests.py index afb58d0..57028b7 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -27,7 +27,7 @@ def test_index_view_with_endpoints(self): response = self.client.get(reverse('drfdocs')) self.assertEqual(response.status_code, 200) - self.assertEqual(len(response.context["endpoints"]), 10) + self.assertEqual(len(response.context["endpoints"]), 11) # Test the login view self.assertEqual(response.context["endpoints"][0].name_parent, "accounts") @@ -38,8 +38,16 @@ def test_index_view_with_endpoints(self): self.assertEqual(response.context["endpoints"][0].fields[0]["type"], "CharField") self.assertTrue(response.context["endpoints"][0].fields[0]["required"]) + self.assertEqual(response.context["endpoints"][1].name_parent, "accounts") + self.assertEqual(response.context["endpoints"][1].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][1].path, "/accounts/login2/") + self.assertEqual(response.context["endpoints"][1].docstring, "A view that allows users to login providing their username and password. Without serializer_class") + self.assertEqual(len(response.context["endpoints"][1].fields), 2) + self.assertEqual(response.context["endpoints"][1].fields[0]["type"], "CharField") + self.assertTrue(response.context["endpoints"][1].fields[0]["required"]) + # The view "OrganisationErroredView" (organisations/(?P[\w-]+)/errored/) should contain an error. - self.assertEqual(str(response.context["endpoints"][8].errors), "'test_value'") + self.assertEqual(str(response.context["endpoints"][9].errors), "'test_value'") def test_index_search_with_endpoints(self): response = self.client.get("%s?search=reset-password" % reverse("drfdocs")) diff --git a/tests/urls.py b/tests/urls.py index b226620..092973c 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -6,6 +6,7 @@ accounts_urls = [ url(r'^login/$', views.LoginView.as_view(), name="login"), + url(r'^login2/$', views.LoginWithSerilaizerClassView.as_view(), name="login2"), url(r'^register/$', views.UserRegistrationView.as_view(), name="register"), url(r'^reset-password/$', view=views.PasswordResetView.as_view(), name="reset-password"), url(r'^reset-password/confirm/$', views.PasswordResetConfirmView.as_view(), name="reset-password-confirm"), diff --git a/tests/views.py b/tests/views.py index c6987d0..64e0bec 100644 --- a/tests/views.py +++ b/tests/views.py @@ -111,3 +111,24 @@ def delete(self, request, *args, **kwargs): class OrganisationErroredView(generics.ListAPIView): serializer_class = serializers.OrganisationErroredSerializer + + +class LoginWithSerilaizerClassView(APIView): + """ + A view that allows users to login providing their username and password. Without serializer_class + """ + + throttle_classes = () + permission_classes = () + parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) + renderer_classes = (renderers.JSONRenderer,) + + def post(self, request): + serializer = self.serializer_class(data=request.data) + serializer.is_valid(raise_exception=True) + user = serializer.validated_data['user'] + token, created = Token.objects.get_or_create(user=user) + return Response({'token': token.key}) + + def get_serializer_class(self): + return AuthTokenSerializer