From abd73b9fc13f8305fea98295c7fd592c2db7818b Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 26 Jul 2018 17:25:14 +1200 Subject: [PATCH 01/72] Django Channels v2 --- README.md | 81 +++++- README.rst | 180 +++++++++++++- examples/django_channels2/Pipfile | 14 ++ examples/django_channels2/Pipfile.lock | 234 ++++++++++++++++++ .../django_channels2/__init__.py | 0 .../django_channels2/routing.py | 9 + .../django_channels2/schema.py | 24 ++ .../django_channels2/settings.py | 121 +++++++++ .../django_channels2/django_channels2/urls.py | 8 + .../django_channels2/django_channels2/wsgi.py | 16 ++ examples/django_channels2/manage.py | 15 ++ graphql_ws/django/__init__.py | 0 graphql_ws/django/graphql_channels.py | 129 ++++++++++ .../django/templates/graphene/graphiql.html | 135 ++++++++++ 14 files changed, 963 insertions(+), 3 deletions(-) create mode 100644 examples/django_channels2/Pipfile create mode 100644 examples/django_channels2/Pipfile.lock create mode 100644 examples/django_channels2/django_channels2/__init__.py create mode 100644 examples/django_channels2/django_channels2/routing.py create mode 100644 examples/django_channels2/django_channels2/schema.py create mode 100644 examples/django_channels2/django_channels2/settings.py create mode 100644 examples/django_channels2/django_channels2/urls.py create mode 100644 examples/django_channels2/django_channels2/wsgi.py create mode 100755 examples/django_channels2/manage.py create mode 100644 graphql_ws/django/__init__.py create mode 100644 graphql_ws/django/graphql_channels.py create mode 100644 graphql_ws/django/templates/graphene/graphiql.html diff --git a/README.md b/README.md index 338bb9b..a45a1aa 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,6 @@ You can see a full example here: https://github.com/graphql-python/graphql-ws/tr ### Django Channels - First `pip install channels` and it to your django apps Then add the following to your settings.py @@ -202,4 +201,82 @@ from graphql_ws.django_channels import GraphQLSubscriptionConsumer channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), ] -``` \ No newline at end of file +``` + +### Django Channels 2 + +Set up with Django Channels just takes three steps: + +1. Install the apps +2. Set up schema +3. Set up channels Router + + +First `pip install channels` and it to your `INSTALLED_APPS`. If you want +graphiQL, install `graphql_ws.django` app before `graphene_django` to serve a +graphiql template that will work with websockets: + +```python +INSTALLED_APPS = [ + "channels", + "graphql_ws.django", + "graphene_django", + # ... +] +``` + + +Next, set up your graphql schema: + +```python +import graphene +from rx import Observable + + +class Query(graphene.ObjectType): + hello = graphene.String() + + def resolve_hello(self, info, **kwargs): + return "world" + + +class Subscription(graphene.ObjectType): + + count_seconds = graphene.Int(up_to=graphene.Int()) + + def resolve_count_seconds(root, info, up_to=5): + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) + + +schema = graphene.Schema(query=Query, subscription=Subscription) +``` + +...and point to your schema in Django settings +```python +GRAPHENE = { + 'SCHEMA': 'yourproject.schema' +} +``` + + +Finally, configure channels routing (it'll be served from `/subscriptions`): + +```python +from channels.routing import ProtocolTypeRouter, URLRouter +from graphql_ws.django.graphql_channels import ( + websocket_urlpatterns as graphql_urlpatterns +) + +application = ProtocolTypeRouter({"websocket": URLRouter(graphql_urlpatterns)}) +``` + +...and point to the application in Django settings +```python +ASGI_APPLICATION = 'yourproject.schema' +``` + +Run `./manage.py runserver` and go to `http://localhost:8000/graphql` to test! diff --git a/README.rst b/README.rst index 6a909a1..7dc6f3c 100644 --- a/README.rst +++ b/README.rst @@ -5,7 +5,9 @@ Websocket server for GraphQL subscriptions. Currently supports: \* `aiohttp `__ \* -`Gevent `__ +`Gevent `__ \* +Sanic (uses `websockets `__ +library) Installation instructions ========================= @@ -44,6 +46,29 @@ For setting up, just plug into your aiohttp server. web.run_app(app, port=8000) +Sanic +~~~~~ + +Works with any framework that uses the websockets library for it's +websocket implementation. For this example, plug in your Sanic server. + +.. code:: python + + from graphql_ws.websockets_lib import WsLibSubscriptionServer + + + app = Sanic(__name__) + + subscription_server = WsLibSubscriptionServer(schema) + + @app.websocket('/subscriptions', subprotocols=['graphql-ws']) + async def subscriptions(request, ws): + await subscription_server.handle(ws) + return ws + + + app.run(host="0.0.0.0", port=8000) + And then, plug into a subscribable schema: .. code:: python @@ -111,3 +136,156 @@ And then, plug into a subscribable schema: You can see a full example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/flask\_gevent + +Django Channels +~~~~~~~~~~~~~~~ + +First ``pip install channels`` and it to your django apps + +Then add the following to your settings.py + +.. code:: python + + CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] + CHANNEL_LAYERS = { + "default": { + "BACKEND": "asgiref.inmemory.ChannelLayer", + "ROUTING": "django_subscriptions.urls.channel_routing", + }, + + } + +Setup your graphql schema + +.. code:: python + + import graphene + from rx import Observable + + + class Query(graphene.ObjectType): + hello = graphene.String() + + def resolve_hello(self, info, **kwargs): + return 'world' + + class Subscription(graphene.ObjectType): + + count_seconds = graphene.Int(up_to=graphene.Int()) + + + def resolve_count_seconds( + root, + info, + up_to=5 + ): + return Observable.interval(1000)\ + .map(lambda i: "{0}".format(i))\ + .take_while(lambda i: int(i) <= up_to) + + + + schema = graphene.Schema( + query=Query, + subscription=Subscription + ) + +Setup your schema in settings.py + +.. code:: python + + GRAPHENE = { + 'SCHEMA': 'path.to.schema' + } + +and finally add the channel routes + +.. code:: python + + from channels.routing import route_class + from graphql_ws.django_channels import GraphQLSubscriptionConsumer + + channel_routing = [ + route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), + ] + +Django Channels 2 +~~~~~~~~~~~~~~~~~ + +Set up with Django Channels just takes three steps: + +1. Install the apps +2. Set up schema +3. Set up channels Router + +First ``pip install channels`` and it to your ``INSTALLED_APPS``. If you +want graphiQL, install ``graphql_ws.django`` app before +``graphene_django`` to serve a graphiql template that will work with +websockets: + +.. code:: python + + INSTALLED_APPS = [ + "channels", + "graphql_ws.django", + "graphene_django", + # ... + ] + +Next, set up your graphql schema: + +.. code:: python + + import graphene + from rx import Observable + + + class Query(graphene.ObjectType): + hello = graphene.String() + + def resolve_hello(self, info, **kwargs): + return "world" + + + class Subscription(graphene.ObjectType): + + count_seconds = graphene.Int(up_to=graphene.Int()) + + def resolve_count_seconds(root, info, up_to=5): + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) + + + schema = graphene.Schema(query=Query, subscription=Subscription) + +...and point to your schema in Django settings + +.. code:: python + + GRAPHENE = { + 'SCHEMA': 'yourproject.schema' + } + +Finally, configure channels routing (it'll be served from +``/subscriptions``): + +.. code:: python + + from channels.routing import ProtocolTypeRouter, URLRouter + from graphql_ws.django.graphql_channels import ( + websocket_urlpatterns as graphql_urlpatterns + ) + + application = ProtocolTypeRouter({"websocket": URLRouter(graphql_urlpatterns)}) + +...and point to the application in Django settings + +.. code:: python + + ASGI_APPLICATION = 'yourproject.schema' + +Run ``./manage.py runserver`` and go to +``http://localhost:8000/graphql`` to test! diff --git a/examples/django_channels2/Pipfile b/examples/django_channels2/Pipfile new file mode 100644 index 0000000..025e468 --- /dev/null +++ b/examples/django_channels2/Pipfile @@ -0,0 +1,14 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[dev-packages] + +[packages] +graphql-ws = {path = "./../..", editable = true} +channels = "*" +graphene-django = "*" + +[requires] +python_version = "3.6" diff --git a/examples/django_channels2/Pipfile.lock b/examples/django_channels2/Pipfile.lock new file mode 100644 index 0000000..1c5b40c --- /dev/null +++ b/examples/django_channels2/Pipfile.lock @@ -0,0 +1,234 @@ +{ + "_meta": { + "hash": { + "sha256": "e13b3565bfb7dfc05c521ba15ff30a2e1dc76fbee4f629b216d61117dcbd686c" + }, + "pipfile-spec": 6, + "requires": { + "python_version": "3.6" + }, + "sources": [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": true + } + ] + }, + "default": { + "aniso8601": { + "hashes": [ + "sha256:7849749cf00ae0680ad2bdfe4419c7a662bef19c03691a19e008c8b9a5267802", + "sha256:94f90871fcd314a458a3d4eca1c84448efbd200e86f55fe4c733c7a40149ef50" + ], + "version": "==3.0.2" + }, + "asgiref": { + "hashes": [ + "sha256:9b05dcd41a6a89ca8c6e7f7e4089c3f3e76b5af60aebb81ae6d455ad81989c97", + "sha256:b21dc4c43d7aba5a844f4c48b8f49d56277bc34937fd9f9cb93ec97fde7e3082" + ], + "markers": "python_version >= '3.5.3'", + "version": "==2.3.2" + }, + "async-timeout": { + "hashes": [ + "sha256:474d4bc64cee20603e225eb1ece15e248962958b45a3648a9f5cc29e827a610c", + "sha256:b3c0ddc416736619bd4a95ca31de8da6920c3b9a140c64dbef2b2fa7bf521287" + ], + "markers": "python_version >= '3.5.3'", + "version": "==3.0.0" + }, + "attrs": { + "hashes": [ + "sha256:4b90b09eeeb9b88c35bc642cbac057e45a5fd85367b985bd2809c62b7b939265", + "sha256:e0d0eb91441a3b53dab4d9b743eafc1ac44476296a2053b6ca3af0b139faf87b" + ], + "version": "==18.1.0" + }, + "autobahn": { + "hashes": [ + "sha256:2f41bfc512ec482044fa8cfa74182118dedd87e03b3494472d9ff1b5a1e27d24", + "sha256:83e050f0e0783646dbf6da60fe837f4f825c241080d2b5f080002ae6885b036f" + ], + "version": "==18.6.1" + }, + "automat": { + "hashes": [ + "sha256:cbd78b83fa2d81fe2a4d23d258e1661dd7493c9a50ee2f1a5b2cac61c1793b0e", + "sha256:fdccab66b68498af9ecfa1fa43693abe546014dd25cf28543cbe9d1334916a58" + ], + "version": "==0.7.0" + }, + "channels": { + "hashes": [ + "sha256:173441ccf2ac3a93c3b4f86135db301cbe95be58f5815c1e071f2e7154c192a2", + "sha256:3c308108161596ddaa1b9e9f0ed9568a34ee4ebefaa33bc9cc4e941561363add" + ], + "index": "pypi", + "version": "==2.1.2" + }, + "constantly": { + "hashes": [ + "sha256:586372eb92059873e29eba4f9dec8381541b4d3834660707faf8ba59146dfc35", + "sha256:dd2fa9d6b1a51a83f0d7dd76293d734046aa176e384bf6e33b7e44880eb37c5d" + ], + "version": "==15.1.0" + }, + "daphne": { + "hashes": [ + "sha256:bc49584532b2d52116f9a99af2d45d92092de93ccf2fc36a433eb7155d48b2a3", + "sha256:da19b36605cc64d1e3a888a95ff90495dcbb75a25cfee173606a7d86112ebbca" + ], + "version": "==2.2.1" + }, + "django": { + "hashes": [ + "sha256:97886b8a13bbc33bfeba2ff133035d3eca014e2309dff2b6da0bdfc0b8656613", + "sha256:e900b73beee8977c7b887d90c6c57d68af10066b9dac898e1eaf0f82313de334" + ], + "version": "==2.0.7" + }, + "graphene": { + "hashes": [ + "sha256:b8ec446d17fa68721636eaad3d6adc1a378cb6323e219814c8f98c9928fc9642", + "sha256:faa26573b598b22ffd274e2fd7a4c52efa405dcca96e01a62239482246248aa3" + ], + "version": "==2.1.3" + }, + "graphene-django": { + "hashes": [ + "sha256:6abc3ec4f1dcbd91faeb3ce772b428e431807b8ec474f9dae918cff74bf7f6b1", + "sha256:b336eecbf03e6fa12a53288d22015c7035727ffaa8fdd89c93fd41d9b942dd91" + ], + "index": "pypi", + "version": "==2.1.0" + }, + "graphql-core": { + "hashes": [ + "sha256:889e869be5574d02af77baf1f30b5db9ca2959f1c9f5be7b2863ead5a3ec6181", + "sha256:9462e22e32c7f03b667373ec0a84d95fba10e8ce2ead08f29fbddc63b671b0c1" + ], + "version": "==2.1" + }, + "graphql-relay": { + "hashes": [ + "sha256:2716b7245d97091af21abf096fabafac576905096d21ba7118fba722596f65db" + ], + "version": "==0.4.5" + }, + "graphql-ws": { + "editable": true, + "path": "./../.." + }, + "hyperlink": { + "hashes": [ + "sha256:98da4218a56b448c7ec7d2655cb339af1f7d751cf541469bb4fc28c4a4245b34", + "sha256:f01b4ff744f14bc5d0a22a6b9f1525ab7d6312cb0ff967f59414bbac52f0a306" + ], + "version": "==18.0.0" + }, + "idna": { + "hashes": [ + "sha256:156a6814fb5ac1fc6850fb002e0852d56c0c8d2531923a51032d1b70760e186e", + "sha256:684a38a6f903c1d71d6d5fac066b58d7768af4de2b832e426ec79c30daa94a16" + ], + "version": "==2.7" + }, + "incremental": { + "hashes": [ + "sha256:717e12246dddf231a349175f48d74d93e2897244939173b01974ab6661406b9f", + "sha256:7b751696aaf36eebfab537e458929e194460051ccad279c72b755a167eebd4b3" + ], + "version": "==17.5.0" + }, + "iso8601": { + "hashes": [ + "sha256:210e0134677cc0d02f6028087fee1df1e1d76d372ee1db0bf30bf66c5c1c89a3", + "sha256:49c4b20e1f38aa5cf109ddcd39647ac419f928512c869dc01d5c7098eddede82", + "sha256:bbbae5fb4a7abfe71d4688fd64bff70b91bbd74ef6a99d964bab18f7fdf286dd" + ], + "version": "==0.1.12" + }, + "promise": { + "hashes": [ + "sha256:0bca4ed933e3d50e3d18fb54fc1432fa84b0564838cd093e824abcd718ab9304", + "sha256:95506bac89df7a495e0b8c813fd782dd1ae590decb52f95248e316c6659ca49b" + ], + "version": "==2.1" + }, + "pyhamcrest": { + "hashes": [ + "sha256:6b672c02fdf7470df9674ab82263841ce8333fb143f32f021f6cb26f0e512420", + "sha256:8ffaa0a53da57e89de14ced7185ac746227a8894dbd5a3c718bf05ddbd1d56cd" + ], + "version": "==1.9.0" + }, + "pytz": { + "hashes": [ + "sha256:a061aa0a9e06881eb8b3b2b43f05b9439d6583c206d0a6c340ff72a7b6669053", + "sha256:ffb9ef1de172603304d9d2819af6f5ece76f2e85ec10692a524dd876e72bf277" + ], + "version": "==2018.5" + }, + "rx": { + "hashes": [ + "sha256:13a1d8d9e252625c173dc795471e614eadfe1cf40ffc684e08b8fff0d9748c23", + "sha256:7357592bc7e881a95e0c2013b73326f704953301ab551fbc8133a6fadab84105" + ], + "version": "==1.6.1" + }, + "singledispatch": { + "hashes": [ + "sha256:5b06af87df13818d14f08a028e42f566640aef80805c3b50c5056b086e3c2b9c", + "sha256:833b46966687b3de7f438c761ac475213e53b306740f1abfaa86e1d1aae56aa8" + ], + "version": "==3.4.0.3" + }, + "six": { + "hashes": [ + "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", + "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" + ], + "version": "==1.11.0" + }, + "twisted": { + "hashes": [ + "sha256:5de7b79b26aee64efe63319bb8f037af88c21287d1502b39706c818065b3d5a4", + "sha256:95ae985716e8107816d8d0df249d558dbaabb677987cc2ace45272c166b267e4" + ], + "version": "==18.7.0" + }, + "txaio": { + "hashes": [ + "sha256:4797f9f6a9866fe887c96abc0110a226dd5744c894dc3630870542597ad30853", + "sha256:c25acd6c2ef7005a0cd50fa2b65deac409be2f3886e2fcd04f99fae827b179e4" + ], + "version": "==2.10.0" + }, + "typing": { + "hashes": [ + "sha256:3a887b021a77b292e151afb75323dea88a7bc1b3dfa92176cff8e44c8b68bddf", + "sha256:b2c689d54e1144bbcfd191b0832980a21c2dbcf7b5ff7a66248a60c90e951eb8", + "sha256:d400a9344254803a2368533e4533a4200d21eb7b6b729c173bc38201a74db3f2" + ], + "version": "==3.6.4" + }, + "zope.interface": { + "hashes": [ + "sha256:21506674d30c009271fe68a242d330c83b1b9d76d62d03d87e1e9528c61beea6", + "sha256:3d184aff0756c44fff7de69eb4cd5b5311b6f452d4de28cb08343b3f21993763", + "sha256:467d364b24cb398f76ad5e90398d71b9325eb4232be9e8a50d6a3b3c7a1c8789", + "sha256:57c38470d9f57e37afb460c399eb254e7193ac7fb8042bd09bdc001981a9c74c", + "sha256:9ada83f4384bbb12dedc152bcdd46a3ac9f5f7720d43ac3ce3e8e8b91d733c10", + "sha256:a1daf9c5120f3cc6f2b5fef8e1d2a3fb7bbbb20ed4bfdc25bc8364bc62dcf54b", + "sha256:e6b77ae84f2b8502d99a7855fa33334a1eb6159de45626905cb3e454c023f339", + "sha256:e881ef610ff48aece2f4ee2af03d2db1a146dc7c705561bd6089b2356f61641f", + "sha256:f41037260deaacb875db250021fe883bf536bf6414a4fd25b25059b02e31b120" + ], + "markers": "python_version != '3.3.*' and python_version != '3.1.*' and python_version != '3.0.*' and python_version != '3.2.*' and python_version >= '2.7'", + "version": "==4.5.0" + } + }, + "develop": {} +} diff --git a/examples/django_channels2/django_channels2/__init__.py b/examples/django_channels2/django_channels2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/django_channels2/django_channels2/routing.py b/examples/django_channels2/django_channels2/routing.py new file mode 100644 index 0000000..75effe2 --- /dev/null +++ b/examples/django_channels2/django_channels2/routing.py @@ -0,0 +1,9 @@ +from channels.routing import ProtocolTypeRouter, URLRouter +from channels.auth import AuthMiddlewareStack +from graphql_ws.django.graphql_channels import ( + websocket_urlpatterns as graphql_urlpatterns +) + +application = ProtocolTypeRouter( + {"websocket": AuthMiddlewareStack(URLRouter(graphql_urlpatterns))} +) diff --git a/examples/django_channels2/django_channels2/schema.py b/examples/django_channels2/django_channels2/schema.py new file mode 100644 index 0000000..db6893c --- /dev/null +++ b/examples/django_channels2/django_channels2/schema.py @@ -0,0 +1,24 @@ +import graphene +from rx import Observable + + +class Query(graphene.ObjectType): + hello = graphene.String() + + def resolve_hello(self, info, **kwargs): + return "world" + + +class Subscription(graphene.ObjectType): + + count_seconds = graphene.Int(up_to=graphene.Int()) + + def resolve_count_seconds(root, info, up_to=5): + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) + + +schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/django_channels2/django_channels2/settings.py b/examples/django_channels2/django_channels2/settings.py new file mode 100644 index 0000000..0fc8224 --- /dev/null +++ b/examples/django_channels2/django_channels2/settings.py @@ -0,0 +1,121 @@ +""" +Django settings for django_channels2 project. + +Generated by 'django-admin startproject' using Django 2.0.7. + +For more information on this file, see +https://docs.djangoproject.com/en/2.0/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/2.0/ref/settings/ +""" + +import os + +# Build paths inside the project like this: os.path.join(BASE_DIR, ...) +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = "0%1c709jhmggqhk&=tci06iy+%jedfxpcoai69jd8wjzm+k2f0" + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + "channels", + "graphql_ws.django", + "graphene_django", + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", +] + +MIDDLEWARE = [ + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", +] + +ROOT_URLCONF = "django_channels2.urls" + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ] + }, + } +] + +WSGI_APPLICATION = "django_channels2.wsgi.application" +ASGI_APPLICATION = "django_channels2.routing.application" + + +# Database +# https://docs.djangoproject.com/en/2.0/ref/settings/#databases + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db.sqlite3"), + } +} + + +# Password validation +# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator" + }, + {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, + {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, + {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, +] + + +# Internationalization +# https://docs.djangoproject.com/en/2.0/topics/i18n/ + +LANGUAGE_CODE = "en-us" + +TIME_ZONE = "UTC" + +USE_I18N = True + +USE_L10N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/2.0/howto/static-files/ + +STATIC_URL = "/static/" + + +GRAPHENE = {"MIDDLEWARE": [], "SCHEMA": "django_channels2.schema.schema"} diff --git a/examples/django_channels2/django_channels2/urls.py b/examples/django_channels2/django_channels2/urls.py new file mode 100644 index 0000000..addcdfe --- /dev/null +++ b/examples/django_channels2/django_channels2/urls.py @@ -0,0 +1,8 @@ +from django.conf.urls import url +from django.contrib import admin +from graphene_django.views import GraphQLView + +urlpatterns = [ + url(r"^admin/", admin.site.urls), + url(r"^graphql/?$", GraphQLView.as_view(graphiql=True)), +] diff --git a/examples/django_channels2/django_channels2/wsgi.py b/examples/django_channels2/django_channels2/wsgi.py new file mode 100644 index 0000000..49b09bf --- /dev/null +++ b/examples/django_channels2/django_channels2/wsgi.py @@ -0,0 +1,16 @@ +""" +WSGI config for django_channels2 project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/2.0/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_channels2.settings") + +application = get_wsgi_application() diff --git a/examples/django_channels2/manage.py b/examples/django_channels2/manage.py new file mode 100755 index 0000000..1b65bb4 --- /dev/null +++ b/examples/django_channels2/manage.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +import os +import sys + +if __name__ == "__main__": + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_channels2.settings") + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) diff --git a/graphql_ws/django/__init__.py b/graphql_ws/django/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphql_ws/django/graphql_channels.py b/graphql_ws/django/graphql_channels.py new file mode 100644 index 0000000..dec6799 --- /dev/null +++ b/graphql_ws/django/graphql_channels.py @@ -0,0 +1,129 @@ +from asgiref.sync import async_to_sync +from channels.generic.websocket import AsyncJsonWebsocketConsumer +from graphene_django.settings import graphene_settings +from graphql.execution.executors.asyncio import AsyncioExecutor +from django.urls import path +from rx import Observer, Observable +from ..base import BaseConnectionContext, BaseSubscriptionServer +from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, WS_PROTOCOL + + +class ChannelsConnectionContext(BaseConnectionContext): + async def send(self, data): + await self.ws.send_json(data) + + async def close(self, code): + await self.ws.close(code=code) + + +class ChannelsSubscriptionServer(BaseSubscriptionServer): + def get_graphql_params(self, connection_context, payload): + payload["context"] = connection_context.request_context + params = super(ChannelsSubscriptionServer, self).get_graphql_params( + connection_context, payload + ) + return dict(params, return_promise=True, executor=AsyncioExecutor()) + + async def handle(self, ws, request_context=None): + connection_context = ChannelsConnectionContext(ws, request_context) + await self.on_open(connection_context) + return connection_context + + async def send_message( + self, connection_context, op_id=None, op_type=None, payload=None + ): + message = {} + if op_id is not None: + message["id"] = op_id + if op_type is not None: + message["type"] = op_type + if payload is not None: + message["payload"] = payload + + assert message, "You need to send at least one thing" + return await connection_context.send(message) + + async def on_open(self, connection_context): + pass + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + try: + execution_result = await self.execute( + connection_context.request_context, params + ) + assert isinstance( + execution_result, Observable + ), "A subscription must return an observable" + execution_result.subscribe( + SubscriptionObserver( + connection_context, + op_id, + async_to_sync(self.send_execution_result), + async_to_sync(self.send_error), + async_to_sync(self.on_close), + ) + ) + except Exception as e: + self.send_error(connection_context, op_id, str(e)) + + async def on_close(self, connection_context): + remove_operations = list(connection_context.operations.keys()) + for op_id in remove_operations: + self.unsubscribe(connection_context, op_id) + + async def on_stop(self, connection_context, op_id): + self.unsubscribe(connection_context, op_id) + + +subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) + + +class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): + async def connect(self): + self.connection_context = None + if WS_PROTOCOL in self.scope["subprotocols"]: + self.connection_context = await subscription_server.handle(self, self.scope) + await self.accept(subprotocol=WS_PROTOCOL) + else: + await self.close() + + async def disconnect(self, code): + if self.connection_context: + await subscription_server.on_close(self.connection_context) + + async def receive_json(self, content): + await subscription_server.on_message(self.connection_context, content) + + +class SubscriptionObserver(Observer): + def __init__( + self, connection_context, op_id, send_execution_result, send_error, on_close + ): + self.connection_context = connection_context + self.op_id = op_id + self.send_execution_result = send_execution_result + self.send_error = send_error + self.on_close = on_close + + def on_next(self, value): + self.send_execution_result(self.connection_context, self.op_id, value) + + def on_completed(self): + self.on_close(self.connection_context) + + def on_error(self, error): + self.send_error(self.connection_context, self.op_id, error) + + +websocket_urlpatterns = [path("subscriptions", GraphQLSubscriptionConsumer)] diff --git a/graphql_ws/django/templates/graphene/graphiql.html b/graphql_ws/django/templates/graphene/graphiql.html new file mode 100644 index 0000000..dce2683 --- /dev/null +++ b/graphql_ws/django/templates/graphene/graphiql.html @@ -0,0 +1,135 @@ + + + + + + + + + + + + + + + + + From 86762d3e1ff71252df132be0b66b119b13a32d4a Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 27 Jul 2018 11:15:23 +1200 Subject: [PATCH 02/72] Make the django app more uniquely named --- graphql_ws/django/__init__.py | 1 + graphql_ws/django/apps.py | 6 ++++++ 2 files changed, 7 insertions(+) create mode 100644 graphql_ws/django/apps.py diff --git a/graphql_ws/django/__init__.py b/graphql_ws/django/__init__.py index e69de29..d08b0f3 100644 --- a/graphql_ws/django/__init__.py +++ b/graphql_ws/django/__init__.py @@ -0,0 +1 @@ +default_app_config = "graphql_ws.django.apps.GraphQLChannelsApp" diff --git a/graphql_ws/django/apps.py b/graphql_ws/django/apps.py new file mode 100644 index 0000000..eb65bb2 --- /dev/null +++ b/graphql_ws/django/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class GraphQLChannelsApp(AppConfig): + name = "graphql_ws.django" + label = "graphql_channels" From d365c72505e8cf13365a50da517c7cb1d8780224 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 27 Jul 2018 11:19:26 +1200 Subject: [PATCH 03/72] Split the channels 2 app into different modules --- graphql_ws/django/consumers.py | 20 ++++++ graphql_ws/django/routing.py | 19 ++++++ .../{graphql_channels.py => subscriptions.py} | 64 ++++++------------- 3 files changed, 60 insertions(+), 43 deletions(-) create mode 100644 graphql_ws/django/consumers.py create mode 100644 graphql_ws/django/routing.py rename graphql_ws/django/{graphql_channels.py => subscriptions.py} (82%) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py new file mode 100644 index 0000000..4b203df --- /dev/null +++ b/graphql_ws/django/consumers.py @@ -0,0 +1,20 @@ +from channels.generic.websocket import AsyncJsonWebsocketConsumer +from ..constants import WS_PROTOCOL +from .subscriptions import subscription_server + + +class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): + async def connect(self): + self.connection_context = None + if WS_PROTOCOL in self.scope["subprotocols"]: + self.connection_context = await subscription_server.handle(self, self.scope) + await self.accept(subprotocol=WS_PROTOCOL) + else: + await self.close() + + async def disconnect(self, code): + if self.connection_context: + await subscription_server.on_close(self.connection_context) + + async def receive_json(self, content): + await subscription_server.on_message(self.connection_context, content) diff --git a/graphql_ws/django/routing.py b/graphql_ws/django/routing.py new file mode 100644 index 0000000..d30955d --- /dev/null +++ b/graphql_ws/django/routing.py @@ -0,0 +1,19 @@ +from channels.routing import ProtocolTypeRouter, URLRouter +from django.apps import apps +from django.urls import path +from .consumers import GraphQLSubscriptionConsumer + +if apps.is_installed("django.contrib.auth"): + from channels.auth import AuthMiddlewareStack +else: + AuthMiddlewareStack = None + + +websocket_urlpatterns = [path("subscriptions", GraphQLSubscriptionConsumer)] + +application = ProtocolTypeRouter({"websocket": URLRouter(websocket_urlpatterns)}) + +if AuthMiddlewareStack: + auth_application = ProtocolTypeRouter( + {"websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns))} + ) diff --git a/graphql_ws/django/graphql_channels.py b/graphql_ws/django/subscriptions.py similarity index 82% rename from graphql_ws/django/graphql_channels.py rename to graphql_ws/django/subscriptions.py index dec6799..cabea51 100644 --- a/graphql_ws/django/graphql_channels.py +++ b/graphql_ws/django/subscriptions.py @@ -1,11 +1,29 @@ from asgiref.sync import async_to_sync -from channels.generic.websocket import AsyncJsonWebsocketConsumer from graphene_django.settings import graphene_settings from graphql.execution.executors.asyncio import AsyncioExecutor -from django.urls import path from rx import Observer, Observable from ..base import BaseConnectionContext, BaseSubscriptionServer -from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, WS_PROTOCOL +from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + + +class SubscriptionObserver(Observer): + def __init__( + self, connection_context, op_id, send_execution_result, send_error, on_close + ): + self.connection_context = connection_context + self.op_id = op_id + self.send_execution_result = send_execution_result + self.send_error = send_error + self.on_close = on_close + + def on_next(self, value): + self.send_execution_result(self.connection_context, self.op_id, value) + + def on_completed(self): + self.on_close(self.connection_context) + + def on_error(self, error): + self.send_error(self.connection_context, self.op_id, error) class ChannelsConnectionContext(BaseConnectionContext): @@ -87,43 +105,3 @@ async def on_stop(self, connection_context, op_id): subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) - - -class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): - async def connect(self): - self.connection_context = None - if WS_PROTOCOL in self.scope["subprotocols"]: - self.connection_context = await subscription_server.handle(self, self.scope) - await self.accept(subprotocol=WS_PROTOCOL) - else: - await self.close() - - async def disconnect(self, code): - if self.connection_context: - await subscription_server.on_close(self.connection_context) - - async def receive_json(self, content): - await subscription_server.on_message(self.connection_context, content) - - -class SubscriptionObserver(Observer): - def __init__( - self, connection_context, op_id, send_execution_result, send_error, on_close - ): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) - - -websocket_urlpatterns = [path("subscriptions", GraphQLSubscriptionConsumer)] From c9cda692aebe352093396517ac926087fd64b355 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 27 Jul 2018 11:20:02 +1200 Subject: [PATCH 04/72] Update and simplify the channels2 example --- README.md | 26 ++--- README.rst | 29 +++--- examples/django_channels2/Pipfile | 2 +- examples/django_channels2/Pipfile.lock | 2 +- .../django_channels2/routing.py | 9 -- .../django_channels2/schema.py | 1 - .../django_channels2/settings.py | 98 +------------------ .../django_channels2/django_channels2/urls.py | 8 +- .../django_channels2/django_channels2/wsgi.py | 16 --- 9 files changed, 28 insertions(+), 163 deletions(-) delete mode 100644 examples/django_channels2/django_channels2/routing.py delete mode 100644 examples/django_channels2/django_channels2/wsgi.py diff --git a/README.md b/README.md index a45a1aa..e19a53b 100644 --- a/README.md +++ b/README.md @@ -208,13 +208,13 @@ channel_routing = [ Set up with Django Channels just takes three steps: 1. Install the apps -2. Set up schema -3. Set up channels Router +2. Set up your schema +3. Configure the channels router application First `pip install channels` and it to your `INSTALLED_APPS`. If you want -graphiQL, install `graphql_ws.django` app before `graphene_django` to serve a -graphiql template that will work with websockets: +graphiQL, install the `graphql_ws.django` app before `graphene_django` to serve +a graphiQL template that will work with websockets: ```python INSTALLED_APPS = [ @@ -263,20 +263,14 @@ GRAPHENE = { ``` -Finally, configure channels routing (it'll be served from `/subscriptions`): +Finally, you can set up channels routing yourself (maybe using +`graphql_ws.django.routing.websocket_urlpatterns` in your `URLRouter`), or you +can just use one of the preset channels applications: ```python -from channels.routing import ProtocolTypeRouter, URLRouter -from graphql_ws.django.graphql_channels import ( - websocket_urlpatterns as graphql_urlpatterns -) - -application = ProtocolTypeRouter({"websocket": URLRouter(graphql_urlpatterns)}) -``` - -...and point to the application in Django settings -```python -ASGI_APPLICATION = 'yourproject.schema' +ASGI_APPLICATION = 'graphql_ws.django.routing.application' +# or +ASGI_APPLICATION = 'graphql_ws.django.routing.auth_application' ``` Run `./manage.py runserver` and go to `http://localhost:8000/graphql` to test! diff --git a/README.rst b/README.rst index 7dc6f3c..828aa8d 100644 --- a/README.rst +++ b/README.rst @@ -215,12 +215,12 @@ Django Channels 2 Set up with Django Channels just takes three steps: 1. Install the apps -2. Set up schema -3. Set up channels Router +2. Set up your schema +3. Configure the channels router application First ``pip install channels`` and it to your ``INSTALLED_APPS``. If you -want graphiQL, install ``graphql_ws.django`` app before -``graphene_django`` to serve a graphiql template that will work with +want graphiQL, install the ``graphql_ws.django`` app before +``graphene_django`` to serve a graphiQL template that will work with websockets: .. code:: python @@ -269,23 +269,16 @@ Next, set up your graphql schema: 'SCHEMA': 'yourproject.schema' } -Finally, configure channels routing (it'll be served from -``/subscriptions``): +Finally, you can set up channels routing yourself (maybe using +``graphql_ws.django.routing.websocket_urlpatterns`` in your +``URLRouter``), or you can just use one of the preset channels +applications: .. code:: python - from channels.routing import ProtocolTypeRouter, URLRouter - from graphql_ws.django.graphql_channels import ( - websocket_urlpatterns as graphql_urlpatterns - ) - - application = ProtocolTypeRouter({"websocket": URLRouter(graphql_urlpatterns)}) - -...and point to the application in Django settings - -.. code:: python - - ASGI_APPLICATION = 'yourproject.schema' + ASGI_APPLICATION = 'graphql_ws.django.routing.application' + # or + ASGI_APPLICATION = 'graphql_ws.django.routing.auth_application' Run ``./manage.py runserver`` and go to ``http://localhost:8000/graphql`` to test! diff --git a/examples/django_channels2/Pipfile b/examples/django_channels2/Pipfile index 025e468..ba4a0a4 100644 --- a/examples/django_channels2/Pipfile +++ b/examples/django_channels2/Pipfile @@ -7,7 +7,7 @@ name = "pypi" [packages] graphql-ws = {path = "./../..", editable = true} -channels = "*" +channels = "==2.*" graphene-django = "*" [requires] diff --git a/examples/django_channels2/Pipfile.lock b/examples/django_channels2/Pipfile.lock index 1c5b40c..6d65d10 100644 --- a/examples/django_channels2/Pipfile.lock +++ b/examples/django_channels2/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "e13b3565bfb7dfc05c521ba15ff30a2e1dc76fbee4f629b216d61117dcbd686c" + "sha256": "75a0ce53afdb6d8ea231f82ff73a9de06bad3a6dba8263f658d14abe9f6cf9f9" }, "pipfile-spec": 6, "requires": { diff --git a/examples/django_channels2/django_channels2/routing.py b/examples/django_channels2/django_channels2/routing.py deleted file mode 100644 index 75effe2..0000000 --- a/examples/django_channels2/django_channels2/routing.py +++ /dev/null @@ -1,9 +0,0 @@ -from channels.routing import ProtocolTypeRouter, URLRouter -from channels.auth import AuthMiddlewareStack -from graphql_ws.django.graphql_channels import ( - websocket_urlpatterns as graphql_urlpatterns -) - -application = ProtocolTypeRouter( - {"websocket": AuthMiddlewareStack(URLRouter(graphql_urlpatterns))} -) diff --git a/examples/django_channels2/django_channels2/schema.py b/examples/django_channels2/django_channels2/schema.py index db6893c..2e87181 100644 --- a/examples/django_channels2/django_channels2/schema.py +++ b/examples/django_channels2/django_channels2/schema.py @@ -10,7 +10,6 @@ def resolve_hello(self, info, **kwargs): class Subscription(graphene.ObjectType): - count_seconds = graphene.Int(up_to=graphene.Int()) def resolve_count_seconds(root, info, up_to=5): diff --git a/examples/django_channels2/django_channels2/settings.py b/examples/django_channels2/django_channels2/settings.py index 0fc8224..0931a27 100644 --- a/examples/django_channels2/django_channels2/settings.py +++ b/examples/django_channels2/django_channels2/settings.py @@ -1,58 +1,11 @@ """ Django settings for django_channels2 project. - -Generated by 'django-admin startproject' using Django 2.0.7. - -For more information on this file, see -https://docs.djangoproject.com/en/2.0/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/2.0/ref/settings/ """ - -import os - -# Build paths inside the project like this: os.path.join(BASE_DIR, ...) -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = "0%1c709jhmggqhk&=tci06iy+%jedfxpcoai69jd8wjzm+k2f0" - -# SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -ALLOWED_HOSTS = [] - -# Application definition - -INSTALLED_APPS = [ - "channels", - "graphql_ws.django", - "graphene_django", - "django.contrib.admin", - "django.contrib.auth", - "django.contrib.contenttypes", - "django.contrib.sessions", - "django.contrib.messages", - "django.contrib.staticfiles", -] - -MIDDLEWARE = [ - "django.middleware.security.SecurityMiddleware", - "django.contrib.sessions.middleware.SessionMiddleware", - "django.middleware.common.CommonMiddleware", - "django.middleware.csrf.CsrfViewMiddleware", - "django.contrib.auth.middleware.AuthenticationMiddleware", - "django.contrib.messages.middleware.MessageMiddleware", - "django.middleware.clickjacking.XFrameOptionsMiddleware", -] - -ROOT_URLCONF = "django_channels2.urls" +INSTALLED_APPS = ["channels", "graphql_ws.django", "graphene_django"] TEMPLATES = [ { @@ -63,59 +16,14 @@ "context_processors": [ "django.template.context_processors.debug", "django.template.context_processors.request", - "django.contrib.auth.context_processors.auth", - "django.contrib.messages.context_processors.messages", ] }, } ] -WSGI_APPLICATION = "django_channels2.wsgi.application" -ASGI_APPLICATION = "django_channels2.routing.application" - - -# Database -# https://docs.djangoproject.com/en/2.0/ref/settings/#databases -DATABASES = { - "default": { - "ENGINE": "django.db.backends.sqlite3", - "NAME": os.path.join(BASE_DIR, "db.sqlite3"), - } -} - - -# Password validation -# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators - -AUTH_PASSWORD_VALIDATORS = [ - { - "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator" - }, - {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, - {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, - {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, -] - - -# Internationalization -# https://docs.djangoproject.com/en/2.0/topics/i18n/ - -LANGUAGE_CODE = "en-us" - -TIME_ZONE = "UTC" - -USE_I18N = True - -USE_L10N = True - -USE_TZ = True - - -# Static files (CSS, JavaScript, Images) -# https://docs.djangoproject.com/en/2.0/howto/static-files/ - -STATIC_URL = "/static/" +ROOT_URLCONF = "django_channels2.urls" +ASGI_APPLICATION = "graphql_ws.django.routing.application" GRAPHENE = {"MIDDLEWARE": [], "SCHEMA": "django_channels2.schema.schema"} diff --git a/examples/django_channels2/django_channels2/urls.py b/examples/django_channels2/django_channels2/urls.py index addcdfe..f4470a6 100644 --- a/examples/django_channels2/django_channels2/urls.py +++ b/examples/django_channels2/django_channels2/urls.py @@ -1,8 +1,4 @@ -from django.conf.urls import url -from django.contrib import admin +from django.urls import path from graphene_django.views import GraphQLView -urlpatterns = [ - url(r"^admin/", admin.site.urls), - url(r"^graphql/?$", GraphQLView.as_view(graphiql=True)), -] +urlpatterns = [path("graphql", GraphQLView.as_view(graphiql=True))] diff --git a/examples/django_channels2/django_channels2/wsgi.py b/examples/django_channels2/django_channels2/wsgi.py deleted file mode 100644 index 49b09bf..0000000 --- a/examples/django_channels2/django_channels2/wsgi.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -WSGI config for django_channels2 project. - -It exposes the WSGI callable as a module-level variable named ``application``. - -For more information on this file, see -https://docs.djangoproject.com/en/2.0/howto/deployment/wsgi/ -""" - -import os - -from django.core.wsgi import get_wsgi_application - -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_channels2.settings") - -application = get_wsgi_application() From 2ce92acde64bd6bd31459cc8d42453c63ae89b7a Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sat, 28 Jul 2018 22:05:33 +1200 Subject: [PATCH 05/72] Support async generator responses in django channels --- .../django_channels2/schema.py | 34 ++++++++++++++- .../django_channels2/settings.py | 1 + graphql_ws/django/subscriptions.py | 42 ++++++++++--------- 3 files changed, 56 insertions(+), 21 deletions(-) diff --git a/examples/django_channels2/django_channels2/schema.py b/examples/django_channels2/django_channels2/schema.py index 2e87181..66e15ee 100644 --- a/examples/django_channels2/django_channels2/schema.py +++ b/examples/django_channels2/django_channels2/schema.py @@ -1,5 +1,9 @@ import graphene from rx import Observable +from channels.layers import get_channel_layer +from asgiref.sync import async_to_sync + +channel_layer = get_channel_layer() class Query(graphene.ObjectType): @@ -9,15 +13,41 @@ def resolve_hello(self, info, **kwargs): return "world" +class TestMessageMutation(graphene.Mutation): + class Arguments: + input_text = graphene.String() + + output_text = graphene.String() + + def mutate(self, info, input_text): + async_to_sync(channel_layer.group_send)("new_message", {"data": input_text}) + return TestMessageMutation(output_text=input_text) + + +class Mutations(graphene.ObjectType): + test_message = TestMessageMutation.Field() + + class Subscription(graphene.ObjectType): count_seconds = graphene.Int(up_to=graphene.Int()) + new_message = graphene.String() - def resolve_count_seconds(root, info, up_to=5): + def resolve_count_seconds(self, info, up_to=5): return ( Observable.interval(1000) .map(lambda i: "{0}".format(i)) .take_while(lambda i: int(i) <= up_to) ) + async def resolve_new_message(self, info): + channel_name = await channel_layer.new_channel() + await channel_layer.group_add("new_message", channel_name) + try: + while True: + message = await channel_layer.receive(channel_name) + yield message["data"] + finally: + await channel_layer.group_discard("new_message", channel_name) + -schema = graphene.Schema(query=Query, subscription=Subscription) +schema = graphene.Schema(query=Query, mutation=Mutations, subscription=Subscription) diff --git a/examples/django_channels2/django_channels2/settings.py b/examples/django_channels2/django_channels2/settings.py index 0931a27..a635965 100644 --- a/examples/django_channels2/django_channels2/settings.py +++ b/examples/django_channels2/django_channels2/settings.py @@ -26,4 +26,5 @@ ASGI_APPLICATION = "graphql_ws.django.routing.application" +CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}} GRAPHENE = {"MIDDLEWARE": [], "SCHEMA": "django_channels2.schema.schema"} diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index cabea51..d638d2f 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,9 +1,12 @@ -from asgiref.sync import async_to_sync +from inspect import isawaitable from graphene_django.settings import graphene_settings from graphql.execution.executors.asyncio import AsyncioExecutor -from rx import Observer, Observable +from rx import Observer from ..base import BaseConnectionContext, BaseSubscriptionServer -from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE +from ..observable_aiter import setup_observable_extension + +setup_observable_extension() class SubscriptionObserver(Observer): @@ -76,24 +79,25 @@ async def on_connection_init(self, connection_context, op_id, payload): await connection_context.close(1011) async def on_start(self, connection_context, op_id, params): - try: - execution_result = await self.execute( - connection_context.request_context, params + execution_result = self.execute(connection_context.request_context, params) + + if isawaitable(execution_result): + execution_result = await execution_result + + if not hasattr(execution_result, "__aiter__"): + await self.send_execution_result( + connection_context, op_id, execution_result ) - assert isinstance( - execution_result, Observable - ), "A subscription must return an observable" - execution_result.subscribe( - SubscriptionObserver( - connection_context, - op_id, - async_to_sync(self.send_execution_result), - async_to_sync(self.send_error), - async_to_sync(self.on_close), + else: + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result ) - ) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) + await self.send_message(connection_context, op_id, GQL_COMPLETE) async def on_close(self, connection_context): remove_operations = list(connection_context.operations.keys()) From d46f6bde0d846f506ea62611d47620569edf034d Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sun, 12 Aug 2018 21:57:57 +1200 Subject: [PATCH 06/72] Ensure graphql_ws submodules are installed --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 99844fc..14c2335 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ author="Syrus Akbary", author_email='me@syrusakbary.com', url='https://github.com/graphql-python/graphql-ws', - packages=find_packages(include=['graphql_ws']), + packages=find_packages(include=['graphql_ws', 'graphql_ws.*']), include_package_data=True, install_requires=requirements, license="MIT license", From 5829ff27dca1df4c86615a67b73d3aa400c71f94 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sun, 12 Aug 2018 21:58:31 +1200 Subject: [PATCH 07/72] Ensure the graphiql template is added in the distribution Also add the (unexpectedly missing) examples to the distribution --- MANIFEST.in | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index 965b2dd..16c6f35 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,9 @@ include HISTORY.rst include LICENSE include README.rst +graft graphql_ws/django/templates +graft examples +prune examples/django_channels2/.cache recursive-include tests * recursive-exclude * __pycache__ recursive-exclude * *.py[co] From bc16ee557bf47d5e574e97bd8555a441ff564ca5 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 21 Aug 2018 12:03:11 +1200 Subject: [PATCH 08/72] Remove some unused observer code --- graphql_ws/django/subscriptions.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index d638d2f..26c07fb 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,7 +1,6 @@ from inspect import isawaitable from graphene_django.settings import graphene_settings from graphql.execution.executors.asyncio import AsyncioExecutor -from rx import Observer from ..base import BaseConnectionContext, BaseSubscriptionServer from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE from ..observable_aiter import setup_observable_extension @@ -9,27 +8,8 @@ setup_observable_extension() -class SubscriptionObserver(Observer): - def __init__( - self, connection_context, op_id, send_execution_result, send_error, on_close - ): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) - - class ChannelsConnectionContext(BaseConnectionContext): + async def send(self, data): await self.ws.send_json(data) From 5ddcb1bc21bf132ac7e82b24c0eeb0b85852da06 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 21 Aug 2018 12:03:38 +1200 Subject: [PATCH 09/72] Execute iterable operations as a separate task Fixes future operation requests that were being blocked --- graphql_ws/django/consumers.py | 3 ++- graphql_ws/django/subscriptions.py | 24 +++++++++++++++--------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 4b203df..59f3595 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -7,7 +7,8 @@ class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): async def connect(self): self.connection_context = None if WS_PROTOCOL in self.scope["subprotocols"]: - self.connection_context = await subscription_server.handle(self, self.scope) + self.connection_context = await subscription_server.handle( + ws=self, request_context=self.scope) await self.accept(subprotocol=WS_PROTOCOL) else: await self.close() diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 26c07fb..dd26b87 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,3 +1,4 @@ +from asyncio import ensure_future from inspect import isawaitable from graphene_django.settings import graphene_settings from graphql.execution.executors.asyncio import AsyncioExecutor @@ -68,16 +69,21 @@ async def on_start(self, connection_context, op_id, params): await self.send_execution_result( connection_context, op_id, execution_result ) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result - ) await self.send_message(connection_context, op_id, GQL_COMPLETE) + return + + iterator = await execution_result.__aiter__() + ensure_future(self.run_op(connection_context, op_id, iterator)) + + async def run_op(self, connection_context, op_id, iterator): + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + await self.send_message(connection_context, op_id, GQL_COMPLETE) async def on_close(self, connection_context): remove_operations = list(connection_context.operations.keys()) From dc478beb1b495328bb220640e1762eed7083b5d2 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 27 Aug 2018 13:16:10 +1200 Subject: [PATCH 10/72] Add session_application helper to django routing --- graphql_ws/django/routing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/graphql_ws/django/routing.py b/graphql_ws/django/routing.py index d30955d..15a1356 100644 --- a/graphql_ws/django/routing.py +++ b/graphql_ws/django/routing.py @@ -1,4 +1,5 @@ from channels.routing import ProtocolTypeRouter, URLRouter +from channels.sessions import SessionMiddlewareStack from django.apps import apps from django.urls import path from .consumers import GraphQLSubscriptionConsumer @@ -13,6 +14,10 @@ application = ProtocolTypeRouter({"websocket": URLRouter(websocket_urlpatterns)}) +session_application = ProtocolTypeRouter( + {"websocket": SessionMiddlewareStack(URLRouter(websocket_urlpatterns))} +) + if AuthMiddlewareStack: auth_application = ProtocolTypeRouter( {"websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns))} From f0a2727d1a9afda21829acdb0f60ee14ab03837c Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 27 Aug 2018 13:48:29 +1200 Subject: [PATCH 11/72] Behave correctly by cancelling async tasks --- graphql_ws/django/subscriptions.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index dd26b87..69df107 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,4 +1,4 @@ -from asyncio import ensure_future +import asyncio from inspect import isawaitable from graphene_django.settings import graphene_settings from graphql.execution.executors.asyncio import AsyncioExecutor @@ -10,7 +10,6 @@ class ChannelsConnectionContext(BaseConnectionContext): - async def send(self, data): await self.ws.send_json(data) @@ -73,25 +72,38 @@ async def on_start(self, connection_context, op_id, params): return iterator = await execution_result.__aiter__() - ensure_future(self.run_op(connection_context, op_id, iterator)) + task = asyncio.ensure_future(self.run_op(connection_context, op_id, iterator)) + connection_context.register_operation(op_id, task) async def run_op(self, connection_context, op_id, iterator): - connection_context.register_operation(op_id, iterator) async for single_result in iterator: if not connection_context.has_operation(op_id): break - await self.send_execution_result( - connection_context, op_id, single_result - ) + await self.send_execution_result(connection_context, op_id, single_result) await self.send_message(connection_context, op_id, GQL_COMPLETE) async def on_close(self, connection_context): remove_operations = list(connection_context.operations.keys()) + cancelled_tasks = [] for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) + task = await self.unsubscribe(connection_context, op_id) + if task: + cancelled_tasks.append(task) + # Wait around for all the tasks to actually cancel. + await asyncio.gather(*cancelled_tasks, return_exceptions=True) async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) + task = await self.unsubscribe(connection_context, op_id) + await asyncio.gather(task, return_exceptions=True) + + async def unsubscribe(self, connection_context, op_id): + op = None + if connection_context.has_operation(op_id): + op = connection_context.get_operation(op_id) + op.cancel() + connection_context.remove_operation(op_id) + self.on_operation_complete(connection_context, op_id) + return op subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) From fe91dbb0e18224f4e50f4215440bf35bec109739 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 31 Aug 2018 10:24:35 +1200 Subject: [PATCH 12/72] Use lower level asyncio.wait, abstract the on_complete command --- graphql_ws/django/subscriptions.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 69df107..3ad544e 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -68,7 +68,7 @@ async def on_start(self, connection_context, op_id, params): await self.send_execution_result( connection_context, op_id, execution_result ) - await self.send_message(connection_context, op_id, GQL_COMPLETE) + await self.on_operation_complete(connection_context, op_id) return iterator = await execution_result.__aiter__() @@ -80,7 +80,7 @@ async def run_op(self, connection_context, op_id, iterator): if not connection_context.has_operation(op_id): break await self.send_execution_result(connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) + await self.on_operation_complete(connection_context, op_id) async def on_close(self, connection_context): remove_operations = list(connection_context.operations.keys()) @@ -90,11 +90,11 @@ async def on_close(self, connection_context): if task: cancelled_tasks.append(task) # Wait around for all the tasks to actually cancel. - await asyncio.gather(*cancelled_tasks, return_exceptions=True) + await asyncio.wait(cancelled_tasks) async def on_stop(self, connection_context, op_id): task = await self.unsubscribe(connection_context, op_id) - await asyncio.gather(task, return_exceptions=True) + await asyncio.wait([task]) async def unsubscribe(self, connection_context, op_id): op = None @@ -102,8 +102,11 @@ async def unsubscribe(self, connection_context, op_id): op = connection_context.get_operation(op_id) op.cancel() connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) + await self.on_operation_complete(connection_context, op_id) return op + async def on_operation_complete(self, connection_context, op_id): + await self.send_message(connection_context, op_id, GQL_COMPLETE) + subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) From 187ee7848f1350a7fe906ee995a73c57992e80e9 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 31 Aug 2018 17:16:18 +1200 Subject: [PATCH 13/72] Optimize unsubscribing to ops --- graphql_ws/django/subscriptions.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 3ad544e..a78e349 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -83,18 +83,20 @@ async def run_op(self, connection_context, op_id, iterator): await self.on_operation_complete(connection_context, op_id) async def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - cancelled_tasks = [] - for op_id in remove_operations: - task = await self.unsubscribe(connection_context, op_id) - if task: - cancelled_tasks.append(task) + # Unsubscribe from all the connection's current operations in parallel. + unsubscribes = [ + self.unsubscribe(connection_context, op_id) + for op_id in connection_context.operations + ] + cancelled_tasks = [task for task in await asyncio.gather(*unsubscribes) if task] # Wait around for all the tasks to actually cancel. - await asyncio.wait(cancelled_tasks) + if cancelled_tasks: + await asyncio.wait(cancelled_tasks) async def on_stop(self, connection_context, op_id): task = await self.unsubscribe(connection_context, op_id) - await asyncio.wait([task]) + if task: + await asyncio.wait([task]) async def unsubscribe(self, connection_context, op_id): op = None From 99bc3a8576b84f1a745cf43c5fa8eebee8b6da80 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 31 Aug 2018 17:30:58 +1200 Subject: [PATCH 14/72] Cleaner iterable op running --- graphql_ws/django/subscriptions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index a78e349..4094675 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -71,12 +71,13 @@ async def on_start(self, connection_context, op_id, params): await self.on_operation_complete(connection_context, op_id) return - iterator = await execution_result.__aiter__() - task = asyncio.ensure_future(self.run_op(connection_context, op_id, iterator)) + task = asyncio.ensure_future( + self.run_op(connection_context, op_id, execution_result) + ) connection_context.register_operation(op_id, task) - async def run_op(self, connection_context, op_id, iterator): - async for single_result in iterator: + async def run_op(self, connection_context, op_id, aiterable): + async for single_result in aiterable: if not connection_context.has_operation(op_id): break await self.send_execution_result(connection_context, op_id, single_result) From 7879f324b6c8d09d6d50efede7066fdfc029c90e Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 4 Sep 2018 15:46:28 +1200 Subject: [PATCH 15/72] Simplify the django async futures Promise observers are already futures, the only thing that needs to be a future is the on_message call from the receive_json Django consumer --- graphql_ws/django/consumers.py | 26 ++++++++++++++++-- graphql_ws/django/subscriptions.py | 42 +++++++++--------------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 59f3595..9a67006 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -1,14 +1,30 @@ +import asyncio +import json + from channels.generic.websocket import AsyncJsonWebsocketConsumer +from promise import Promise + from ..constants import WS_PROTOCOL from .subscriptions import subscription_server +class JSONPromiseEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Promise): + return o.value + return super(JSONPromiseEncoder, self).default(o) + + +json_promise_encoder = JSONPromiseEncoder() + + class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): async def connect(self): self.connection_context = None if WS_PROTOCOL in self.scope["subprotocols"]: self.connection_context = await subscription_server.handle( - ws=self, request_context=self.scope) + ws=self, request_context=self.scope + ) await self.accept(subprotocol=WS_PROTOCOL) else: await self.close() @@ -18,4 +34,10 @@ async def disconnect(self, code): await subscription_server.on_close(self.connection_context) async def receive_json(self, content): - await subscription_server.on_message(self.connection_context, content) + asyncio.ensure_future( + subscription_server.on_message(self.connection_context, content) + ) + + @classmethod + async def encode_json(cls, content): + return json_promise_encoder.encode(content) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 4094675..96883c4 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,4 +1,3 @@ -import asyncio from inspect import isawaitable from graphene_django.settings import graphene_settings from graphql.execution.executors.asyncio import AsyncioExecutor @@ -64,49 +63,34 @@ async def on_start(self, connection_context, op_id, params): if isawaitable(execution_result): execution_result = await execution_result - if not hasattr(execution_result, "__aiter__"): + if hasattr(execution_result, "__aiter__"): + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + else: await self.send_execution_result( connection_context, op_id, execution_result ) - await self.on_operation_complete(connection_context, op_id) - return - - task = asyncio.ensure_future( - self.run_op(connection_context, op_id, execution_result) - ) - connection_context.register_operation(op_id, task) - - async def run_op(self, connection_context, op_id, aiterable): - async for single_result in aiterable: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result(connection_context, op_id, single_result) await self.on_operation_complete(connection_context, op_id) async def on_close(self, connection_context): - # Unsubscribe from all the connection's current operations in parallel. - unsubscribes = [ + for op_id in connection_context.operations: self.unsubscribe(connection_context, op_id) - for op_id in connection_context.operations - ] - cancelled_tasks = [task for task in await asyncio.gather(*unsubscribes) if task] - # Wait around for all the tasks to actually cancel. - if cancelled_tasks: - await asyncio.wait(cancelled_tasks) async def on_stop(self, connection_context, op_id): - task = await self.unsubscribe(connection_context, op_id) - if task: - await asyncio.wait([task]) + await self.unsubscribe(connection_context, op_id) async def unsubscribe(self, connection_context, op_id): - op = None if connection_context.has_operation(op_id): op = connection_context.get_operation(op_id) - op.cancel() + op.dispose() connection_context.remove_operation(op_id) await self.on_operation_complete(connection_context, op_id) - return op async def on_operation_complete(self, connection_context, op_id): await self.send_message(connection_context, op_id, GQL_COMPLETE) From 1eb443ea0fca65395903a2b99fe32966285cc54b Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 20 Nov 2018 15:57:38 +1300 Subject: [PATCH 16/72] When closing a channels subscription, correctly await for the unsubscribes --- graphql_ws/django/subscriptions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 96883c4..fa4910b 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,3 +1,4 @@ +import asyncio from inspect import isawaitable from graphene_django.settings import graphene_settings from graphql.execution.executors.asyncio import AsyncioExecutor @@ -79,8 +80,12 @@ async def on_start(self, connection_context, op_id, params): await self.on_operation_complete(connection_context, op_id) async def on_close(self, connection_context): - for op_id in connection_context.operations: + unsubscribes = [ self.unsubscribe(connection_context, op_id) + for op_id in connection_context.operations + ] + if unsubscribes: + await asyncio.wait(unsubscribes) async def on_stop(self, connection_context, op_id): await self.unsubscribe(connection_context, op_id) From 1eb83a5a30e303b69f8975d4480d25ce5a7f4d46 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 10 Jan 2019 09:49:45 +1300 Subject: [PATCH 17/72] Await pending promises in the payload --- graphql_ws/django/consumers.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 9a67006..2cd3888 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -9,15 +9,18 @@ class JSONPromiseEncoder(json.JSONEncoder): + def encode(self, *args, **kwargs): + self.pending_promises = [] + return super(JSONPromiseEncoder, self).encode(*args, **kwargs) + def default(self, o): if isinstance(o, Promise): + if o.is_pending: + self.pending_promises.append(o) return o.value return super(JSONPromiseEncoder, self).default(o) -json_promise_encoder = JSONPromiseEncoder() - - class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): async def connect(self): self.connection_context = None @@ -40,4 +43,10 @@ async def receive_json(self, content): @classmethod async def encode_json(cls, content): - return json_promise_encoder.encode(content) + json_promise_encoder = JSONPromiseEncoder() + e = json_promise_encoder.encode(content) + while json_promise_encoder.pending_promises: + # Wait for pending promises to complete, then try encoding again. + await asyncio.wait(json_promise_encoder.pending_promises) + e = json_promise_encoder.encode(content) + return e From a99c1cfa709c75a372af5e866cf07d497de8256e Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 29 Mar 2019 13:17:59 +1300 Subject: [PATCH 18/72] Update pipfile.lock versions for django channels2 example --- examples/django_channels2/Pipfile.lock | 237 ++++++++++++++++--------- 1 file changed, 153 insertions(+), 84 deletions(-) diff --git a/examples/django_channels2/Pipfile.lock b/examples/django_channels2/Pipfile.lock index 6d65d10..e33bf6c 100644 --- a/examples/django_channels2/Pipfile.lock +++ b/examples/django_channels2/Pipfile.lock @@ -18,40 +18,38 @@ "default": { "aniso8601": { "hashes": [ - "sha256:7849749cf00ae0680ad2bdfe4419c7a662bef19c03691a19e008c8b9a5267802", - "sha256:94f90871fcd314a458a3d4eca1c84448efbd200e86f55fe4c733c7a40149ef50" + "sha256:b8a6a9b24611fc50cf2d9b45d371bfdc4fd0581d1cc52254f5502130a776d4af", + "sha256:bb167645c79f7a438f9dfab6161af9bed75508c645b1f07d1158240841d22673" ], - "version": "==3.0.2" + "version": "==6.0.0" }, "asgiref": { "hashes": [ - "sha256:9b05dcd41a6a89ca8c6e7f7e4089c3f3e76b5af60aebb81ae6d455ad81989c97", - "sha256:b21dc4c43d7aba5a844f4c48b8f49d56277bc34937fd9f9cb93ec97fde7e3082" + "sha256:865b7ccce5a6e815607b08d9059fe9c058cd75c77f896f5e0b74ff6c1ba81818", + "sha256:b718a9d35e204a96e2456c2271b0ef12e36124c363b3a8fd1d626744f23192aa" ], - "markers": "python_version >= '3.5.3'", - "version": "==2.3.2" + "version": "==3.1.4" }, - "async-timeout": { + "asn1crypto": { "hashes": [ - "sha256:474d4bc64cee20603e225eb1ece15e248962958b45a3648a9f5cc29e827a610c", - "sha256:b3c0ddc416736619bd4a95ca31de8da6920c3b9a140c64dbef2b2fa7bf521287" + "sha256:2f1adbb7546ed199e3c90ef23ec95c5cf3585bac7d11fb7eb562a3fe89c64e87", + "sha256:9d5c20441baf0cb60a4ac34cc447c6c189024b6b4c6cd7877034f4965c464e49" ], - "markers": "python_version >= '3.5.3'", - "version": "==3.0.0" + "version": "==0.24.0" }, "attrs": { "hashes": [ - "sha256:4b90b09eeeb9b88c35bc642cbac057e45a5fd85367b985bd2809c62b7b939265", - "sha256:e0d0eb91441a3b53dab4d9b743eafc1ac44476296a2053b6ca3af0b139faf87b" + "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79", + "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399" ], - "version": "==18.1.0" + "version": "==19.1.0" }, "autobahn": { "hashes": [ - "sha256:2f41bfc512ec482044fa8cfa74182118dedd87e03b3494472d9ff1b5a1e27d24", - "sha256:83e050f0e0783646dbf6da60fe837f4f825c241080d2b5f080002ae6885b036f" + "sha256:70f0cfb8005b5429df5709acf5d66a8eba00669765547029371648dffd4a0470", + "sha256:89f94a1535673b1655df28ef91e96b7f34faea76da04a5e56441c9ac779a2f9f" ], - "version": "==18.6.1" + "version": "==19.7.1" }, "automat": { "hashes": [ @@ -60,13 +58,46 @@ ], "version": "==0.7.0" }, + "cffi": { + "hashes": [ + "sha256:041c81822e9f84b1d9c401182e174996f0bae9991f33725d059b771744290774", + "sha256:046ef9a22f5d3eed06334d01b1e836977eeef500d9b78e9ef693f9380ad0b83d", + "sha256:066bc4c7895c91812eff46f4b1c285220947d4aa46fa0a2651ff85f2afae9c90", + "sha256:066c7ff148ae33040c01058662d6752fd73fbc8e64787229ea8498c7d7f4041b", + "sha256:2444d0c61f03dcd26dbf7600cf64354376ee579acad77aef459e34efcb438c63", + "sha256:300832850b8f7967e278870c5d51e3819b9aad8f0a2c8dbe39ab11f119237f45", + "sha256:34c77afe85b6b9e967bd8154e3855e847b70ca42043db6ad17f26899a3df1b25", + "sha256:46de5fa00f7ac09f020729148ff632819649b3e05a007d286242c4882f7b1dc3", + "sha256:4aa8ee7ba27c472d429b980c51e714a24f47ca296d53f4d7868075b175866f4b", + "sha256:4d0004eb4351e35ed950c14c11e734182591465a33e960a4ab5e8d4f04d72647", + "sha256:4e3d3f31a1e202b0f5a35ba3bc4eb41e2fc2b11c1eff38b362de710bcffb5016", + "sha256:50bec6d35e6b1aaeb17f7c4e2b9374ebf95a8975d57863546fa83e8d31bdb8c4", + "sha256:55cad9a6df1e2a1d62063f79d0881a414a906a6962bc160ac968cc03ed3efcfb", + "sha256:5662ad4e4e84f1eaa8efce5da695c5d2e229c563f9d5ce5b0113f71321bcf753", + "sha256:59b4dc008f98fc6ee2bb4fd7fc786a8d70000d058c2bbe2698275bc53a8d3fa7", + "sha256:73e1ffefe05e4ccd7bcea61af76f36077b914f92b76f95ccf00b0c1b9186f3f9", + "sha256:a1f0fd46eba2d71ce1589f7e50a9e2ffaeb739fb2c11e8192aa2b45d5f6cc41f", + "sha256:a2e85dc204556657661051ff4bab75a84e968669765c8a2cd425918699c3d0e8", + "sha256:a5457d47dfff24882a21492e5815f891c0ca35fefae8aa742c6c263dac16ef1f", + "sha256:a8dccd61d52a8dae4a825cdbb7735da530179fea472903eb871a5513b5abbfdc", + "sha256:ae61af521ed676cf16ae94f30fe202781a38d7178b6b4ab622e4eec8cefaff42", + "sha256:b012a5edb48288f77a63dba0840c92d0504aa215612da4541b7b42d849bc83a3", + "sha256:d2c5cfa536227f57f97c92ac30c8109688ace8fa4ac086d19d0af47d134e2909", + "sha256:d42b5796e20aacc9d15e66befb7a345454eef794fdb0737d1af593447c6c8f45", + "sha256:dee54f5d30d775f525894d67b1495625dd9322945e7fee00731952e0368ff42d", + "sha256:e070535507bd6aa07124258171be2ee8dfc19119c28ca94c9dfb7efd23564512", + "sha256:e1ff2748c84d97b065cc95429814cdba39bcbd77c9c85c89344b317dc0d9cbff", + "sha256:ed851c75d1e0e043cbf5ca9a8e1b13c4c90f3fbd863dacb01c0808e2b5204201" + ], + "version": "==1.12.3" + }, "channels": { "hashes": [ - "sha256:173441ccf2ac3a93c3b4f86135db301cbe95be58f5815c1e071f2e7154c192a2", - "sha256:3c308108161596ddaa1b9e9f0ed9568a34ee4ebefaa33bc9cc4e941561363add" + "sha256:9191a85800673b790d1d74666fb7676f430600b71b662581e97dd69c9aedd29a", + "sha256:af7cdba9efb3f55b939917d1b15defb5d40259936013e60660e5e9aff98db4c5" ], "index": "pypi", - "version": "==2.1.2" + "version": "==2.2.0" }, "constantly": { "hashes": [ @@ -75,47 +106,70 @@ ], "version": "==15.1.0" }, + "cryptography": { + "hashes": [ + "sha256:24b61e5fcb506424d3ec4e18bca995833839bf13c59fc43e530e488f28d46b8c", + "sha256:25dd1581a183e9e7a806fe0543f485103232f940fcfc301db65e630512cce643", + "sha256:3452bba7c21c69f2df772762be0066c7ed5dc65df494a1d53a58b683a83e1216", + "sha256:41a0be220dd1ed9e998f5891948306eb8c812b512dc398e5a01846d855050799", + "sha256:5751d8a11b956fbfa314f6553d186b94aa70fdb03d8a4d4f1c82dcacf0cbe28a", + "sha256:5f61c7d749048fa6e3322258b4263463bfccefecb0dd731b6561cb617a1d9bb9", + "sha256:72e24c521fa2106f19623a3851e9f89ddfdeb9ac63871c7643790f872a305dfc", + "sha256:7b97ae6ef5cba2e3bb14256625423413d5ce8d1abb91d4f29b6d1a081da765f8", + "sha256:961e886d8a3590fd2c723cf07be14e2a91cf53c25f02435c04d39e90780e3b53", + "sha256:96d8473848e984184b6728e2c9d391482008646276c3ff084a1bd89e15ff53a1", + "sha256:ae536da50c7ad1e002c3eee101871d93abdc90d9c5f651818450a0d3af718609", + "sha256:b0db0cecf396033abb4a93c95d1602f268b3a68bb0a9cc06a7cff587bb9a7292", + "sha256:cfee9164954c186b191b91d4193989ca994703b2fff406f71cf454a2d3c7327e", + "sha256:e6347742ac8f35ded4a46ff835c60e68c22a536a8ae5c4422966d06946b6d4c6", + "sha256:f27d93f0139a3c056172ebb5d4f9056e770fdf0206c2f422ff2ebbad142e09ed", + "sha256:f57b76e46a58b63d1c6375017f4564a28f19a5ca912691fd2e4261b3414b618d" + ], + "version": "==2.7" + }, "daphne": { "hashes": [ - "sha256:bc49584532b2d52116f9a99af2d45d92092de93ccf2fc36a433eb7155d48b2a3", - "sha256:da19b36605cc64d1e3a888a95ff90495dcbb75a25cfee173606a7d86112ebbca" + "sha256:2329b7a74b5559f7ea012879c10ba945c3a53df7d8d2b5932a904e3b4c9abcc2", + "sha256:3cae286a995ae5b127d7de84916f0480cb5be19f81125b6a150b8326250dadd5" ], - "version": "==2.2.1" + "version": "==2.3.0" }, "django": { "hashes": [ - "sha256:97886b8a13bbc33bfeba2ff133035d3eca014e2309dff2b6da0bdfc0b8656613", - "sha256:e900b73beee8977c7b887d90c6c57d68af10066b9dac898e1eaf0f82313de334" + "sha256:4d23f61b26892bac785f07401bc38cbf8fa4cec993f400e9cd9ddf28fd51c0ea", + "sha256:6e974d4b57e3b29e4882b244d40171d6a75202ab8d2402b8e8adbd182e25cf0c" ], - "version": "==2.0.7" + "version": "==2.2.3" }, "graphene": { "hashes": [ - "sha256:b8ec446d17fa68721636eaad3d6adc1a378cb6323e219814c8f98c9928fc9642", - "sha256:faa26573b598b22ffd274e2fd7a4c52efa405dcca96e01a62239482246248aa3" + "sha256:77d61618132ccd084c343e64c22d806cee18dce73cc86e0f427378dbdeeac287", + "sha256:acf808d50d053b94f7958414d511489a9e490a7f9563b9be80f6875fc5723d2a" ], - "version": "==2.1.3" + "version": "==2.1.7" }, "graphene-django": { "hashes": [ - "sha256:6abc3ec4f1dcbd91faeb3ce772b428e431807b8ec474f9dae918cff74bf7f6b1", - "sha256:b336eecbf03e6fa12a53288d22015c7035727ffaa8fdd89c93fd41d9b942dd91" + "sha256:3101e8a8353c6b13f33261f5b0437deb3d3614d1c44b2d56932b158e3660c0cd", + "sha256:5714c5dd1200800ddc12d0782b0d82db70aedf387575e5b57ee2cdee4f25c681" ], "index": "pypi", - "version": "==2.1.0" + "version": "==2.4.0" }, "graphql-core": { "hashes": [ - "sha256:889e869be5574d02af77baf1f30b5db9ca2959f1c9f5be7b2863ead5a3ec6181", - "sha256:9462e22e32c7f03b667373ec0a84d95fba10e8ce2ead08f29fbddc63b671b0c1" + "sha256:1488f2a5c2272dc9ba66e3042a6d1c30cea0db4c80bd1e911c6791ad6187d91b", + "sha256:da64c472d720da4537a2e8de8ba859210b62841bd47a9be65ca35177f62fe0e4" ], - "version": "==2.1" + "version": "==2.2.1" }, "graphql-relay": { "hashes": [ - "sha256:2716b7245d97091af21abf096fabafac576905096d21ba7118fba722596f65db" + "sha256:0e94201af4089e1f81f07d7bd8f84799768e39d70fa1ea16d1df505b46cc6335", + "sha256:75aa0758971e252964cb94068a4decd472d2a8295229f02189e3cbca1f10dbb5", + "sha256:7fa74661246e826ef939ee92e768f698df167a7617361ab399901eaebf80dce6" ], - "version": "==0.4.5" + "version": "==2.0.0" }, "graphql-ws": { "editable": true, @@ -123,17 +177,17 @@ }, "hyperlink": { "hashes": [ - "sha256:98da4218a56b448c7ec7d2655cb339af1f7d751cf541469bb4fc28c4a4245b34", - "sha256:f01b4ff744f14bc5d0a22a6b9f1525ab7d6312cb0ff967f59414bbac52f0a306" + "sha256:4288e34705da077fada1111a24a0aa08bb1e76699c9ce49876af722441845654", + "sha256:ab4a308feb039b04f855a020a6eda3b18ca5a68e6d8f8c899cbe9e653721d04f" ], - "version": "==18.0.0" + "version": "==19.0.0" }, "idna": { "hashes": [ - "sha256:156a6814fb5ac1fc6850fb002e0852d56c0c8d2531923a51032d1b70760e186e", - "sha256:684a38a6f903c1d71d6d5fac066b58d7768af4de2b832e426ec79c30daa94a16" + "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", + "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" ], - "version": "==2.7" + "version": "==2.8" }, "incremental": { "hashes": [ @@ -142,20 +196,18 @@ ], "version": "==17.5.0" }, - "iso8601": { + "promise": { "hashes": [ - "sha256:210e0134677cc0d02f6028087fee1df1e1d76d372ee1db0bf30bf66c5c1c89a3", - "sha256:49c4b20e1f38aa5cf109ddcd39647ac419f928512c869dc01d5c7098eddede82", - "sha256:bbbae5fb4a7abfe71d4688fd64bff70b91bbd74ef6a99d964bab18f7fdf286dd" + "sha256:2ebbfc10b7abf6354403ed785fe4f04b9dfd421eb1a474ac8d187022228332af", + "sha256:348f5f6c3edd4fd47c9cd65aed03ac1b31136d375aa63871a57d3e444c85655c" ], - "version": "==0.1.12" + "version": "==2.2.1" }, - "promise": { + "pycparser": { "hashes": [ - "sha256:0bca4ed933e3d50e3d18fb54fc1432fa84b0564838cd093e824abcd718ab9304", - "sha256:95506bac89df7a495e0b8c813fd782dd1ae590decb52f95248e316c6659ca49b" + "sha256:a988718abfad80b6b157acce7bf130a30876d27603738ac39f140993246b25b3" ], - "version": "==2.1" + "version": "==2.19" }, "pyhamcrest": { "hashes": [ @@ -166,10 +218,10 @@ }, "pytz": { "hashes": [ - "sha256:a061aa0a9e06881eb8b3b2b43f05b9439d6583c206d0a6c340ff72a7b6669053", - "sha256:ffb9ef1de172603304d9d2819af6f5ece76f2e85ec10692a524dd876e72bf277" + "sha256:303879e36b721603cc54604edcac9d20401bdbe31e1e4fdee5b9f98d5d31dfda", + "sha256:d747dd3d23d77ef44c6a3526e274af6efeb0a6f1afd5a69ba4d5be4098c8e141" ], - "version": "==2018.5" + "version": "==2019.1" }, "rx": { "hashes": [ @@ -187,47 +239,64 @@ }, "six": { "hashes": [ - "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", - "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" + "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", + "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" ], - "version": "==1.11.0" + "version": "==1.12.0" }, - "twisted": { + "sqlparse": { "hashes": [ - "sha256:5de7b79b26aee64efe63319bb8f037af88c21287d1502b39706c818065b3d5a4", - "sha256:95ae985716e8107816d8d0df249d558dbaabb677987cc2ace45272c166b267e4" + "sha256:40afe6b8d4b1117e7dff5504d7a8ce07d9a1b15aeeade8a2d10f130a834f8177", + "sha256:7c3dca29c022744e95b547e867cee89f4fce4373f3549ccd8797d8eb52cdb873" ], - "version": "==18.7.0" + "version": "==0.3.0" }, - "txaio": { + "twisted": { "hashes": [ - "sha256:4797f9f6a9866fe887c96abc0110a226dd5744c894dc3630870542597ad30853", - "sha256:c25acd6c2ef7005a0cd50fa2b65deac409be2f3886e2fcd04f99fae827b179e4" + "sha256:fa2c04c2d68a9be7fc3975ba4947f653a57a656776f24be58ff0fe4b9aaf3e52" ], - "version": "==2.10.0" + "version": "==19.2.1" }, - "typing": { + "txaio": { "hashes": [ - "sha256:3a887b021a77b292e151afb75323dea88a7bc1b3dfa92176cff8e44c8b68bddf", - "sha256:b2c689d54e1144bbcfd191b0832980a21c2dbcf7b5ff7a66248a60c90e951eb8", - "sha256:d400a9344254803a2368533e4533a4200d21eb7b6b729c173bc38201a74db3f2" + "sha256:67e360ac73b12c52058219bb5f8b3ed4105d2636707a36a7cdafb56fe06db7fe", + "sha256:b6b235d432cc58ffe111b43e337db71a5caa5d3eaa88f0eacf60b431c7626ef5" ], - "version": "==3.6.4" + "version": "==18.8.1" }, "zope.interface": { "hashes": [ - "sha256:21506674d30c009271fe68a242d330c83b1b9d76d62d03d87e1e9528c61beea6", - "sha256:3d184aff0756c44fff7de69eb4cd5b5311b6f452d4de28cb08343b3f21993763", - "sha256:467d364b24cb398f76ad5e90398d71b9325eb4232be9e8a50d6a3b3c7a1c8789", - "sha256:57c38470d9f57e37afb460c399eb254e7193ac7fb8042bd09bdc001981a9c74c", - "sha256:9ada83f4384bbb12dedc152bcdd46a3ac9f5f7720d43ac3ce3e8e8b91d733c10", - "sha256:a1daf9c5120f3cc6f2b5fef8e1d2a3fb7bbbb20ed4bfdc25bc8364bc62dcf54b", - "sha256:e6b77ae84f2b8502d99a7855fa33334a1eb6159de45626905cb3e454c023f339", - "sha256:e881ef610ff48aece2f4ee2af03d2db1a146dc7c705561bd6089b2356f61641f", - "sha256:f41037260deaacb875db250021fe883bf536bf6414a4fd25b25059b02e31b120" - ], - "markers": "python_version != '3.3.*' and python_version != '3.1.*' and python_version != '3.0.*' and python_version != '3.2.*' and python_version >= '2.7'", - "version": "==4.5.0" + "sha256:086707e0f413ff8800d9c4bc26e174f7ee4c9c8b0302fbad68d083071822316c", + "sha256:1157b1ec2a1f5bf45668421e3955c60c610e31913cc695b407a574efdbae1f7b", + "sha256:11ebddf765bff3bbe8dbce10c86884d87f90ed66ee410a7e6c392086e2c63d02", + "sha256:14b242d53f6f35c2d07aa2c0e13ccb710392bcd203e1b82a1828d216f6f6b11f", + "sha256:1b3d0dcabc7c90b470e59e38a9acaa361be43b3a6ea644c0063951964717f0e5", + "sha256:20a12ab46a7e72b89ce0671e7d7a6c3c1ca2c2766ac98112f78c5bddaa6e4375", + "sha256:298f82c0ab1b182bd1f34f347ea97dde0fffb9ecf850ecf7f8904b8442a07487", + "sha256:2f6175722da6f23dbfc76c26c241b67b020e1e83ec7fe93c9e5d3dd18667ada2", + "sha256:3b877de633a0f6d81b600624ff9137312d8b1d0f517064dfc39999352ab659f0", + "sha256:4265681e77f5ac5bac0905812b828c9fe1ce80c6f3e3f8574acfb5643aeabc5b", + "sha256:550695c4e7313555549aa1cdb978dc9413d61307531f123558e438871a883d63", + "sha256:5f4d42baed3a14c290a078e2696c5f565501abde1b2f3f1a1c0a94fbf6fbcc39", + "sha256:62dd71dbed8cc6a18379700701d959307823b3b2451bdc018594c48956ace745", + "sha256:7040547e5b882349c0a2cc9b50674b1745db551f330746af434aad4f09fba2cc", + "sha256:7e099fde2cce8b29434684f82977db4e24f0efa8b0508179fce1602d103296a2", + "sha256:7e5c9a5012b2b33e87980cee7d1c82412b2ebabcb5862d53413ba1a2cfde23aa", + "sha256:81295629128f929e73be4ccfdd943a0906e5fe3cdb0d43ff1e5144d16fbb52b1", + "sha256:95cc574b0b83b85be9917d37cd2fad0ce5a0d21b024e1a5804d044aabea636fc", + "sha256:968d5c5702da15c5bf8e4a6e4b67a4d92164e334e9c0b6acf080106678230b98", + "sha256:9e998ba87df77a85c7bed53240a7257afe51a07ee6bc3445a0bf841886da0b97", + "sha256:a0c39e2535a7e9c195af956610dba5a1073071d2d85e9d2e5d789463f63e52ab", + "sha256:a15e75d284178afe529a536b0e8b28b7e107ef39626a7809b4ee64ff3abc9127", + "sha256:a6a6ff82f5f9b9702478035d8f6fb6903885653bff7ec3a1e011edc9b1a7168d", + "sha256:b639f72b95389620c1f881d94739c614d385406ab1d6926a9ffe1c8abbea23fe", + "sha256:bad44274b151d46619a7567010f7cde23a908c6faa84b97598fd2f474a0c6891", + "sha256:bbcef00d09a30948756c5968863316c949d9cedbc7aabac5e8f0ffbdb632e5f1", + "sha256:d788a3999014ddf416f2dc454efa4a5dbeda657c6aba031cf363741273804c6b", + "sha256:eed88ae03e1ef3a75a0e96a55a99d7937ed03e53d0cffc2451c208db445a2966", + "sha256:f99451f3a579e73b5dd58b1b08d1179791d49084371d9a47baad3b22417f0317" + ], + "version": "==4.6.0" } }, "develop": {} From 2e1ddfdb38c969bef3a61d2dddf8f663fa518614 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 22 Jul 2019 11:05:16 +1200 Subject: [PATCH 19/72] Get rid of rx.Observable django channels example, async is simpler --- .../django_channels2/django_channels2/schema.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/django_channels2/django_channels2/schema.py b/examples/django_channels2/django_channels2/schema.py index 66e15ee..158ba13 100644 --- a/examples/django_channels2/django_channels2/schema.py +++ b/examples/django_channels2/django_channels2/schema.py @@ -1,7 +1,8 @@ +import asyncio + import graphene -from rx import Observable -from channels.layers import get_channel_layer from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer channel_layer = get_channel_layer() @@ -32,12 +33,12 @@ class Subscription(graphene.ObjectType): count_seconds = graphene.Int(up_to=graphene.Int()) new_message = graphene.String() - def resolve_count_seconds(self, info, up_to=5): - return ( - Observable.interval(1000) - .map(lambda i: "{0}".format(i)) - .take_while(lambda i: int(i) <= up_to) - ) + async def resolve_count_seconds(self, info, up_to=5): + i = 1 + while i <= up_to: + yield str(i) + await asyncio.sleep(1) + i += 1 async def resolve_new_message(self, info): channel_name = await channel_layer.new_channel() From 8d6fe9360c3c531209c4b62eb85be20ddd13c5dc Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 22 Jul 2019 13:03:38 +1200 Subject: [PATCH 20/72] Use common middleware in channels2 example --- examples/django_channels2/django_channels2/settings.py | 3 +++ examples/django_channels2/django_channels2/urls.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/django_channels2/django_channels2/settings.py b/examples/django_channels2/django_channels2/settings.py index a635965..6c7b22b 100644 --- a/examples/django_channels2/django_channels2/settings.py +++ b/examples/django_channels2/django_channels2/settings.py @@ -21,6 +21,9 @@ } ] +MIDDLEWARE = [ + 'django.middleware.common.CommonMiddleware', +] ROOT_URLCONF = "django_channels2.urls" ASGI_APPLICATION = "graphql_ws.django.routing.application" diff --git a/examples/django_channels2/django_channels2/urls.py b/examples/django_channels2/django_channels2/urls.py index f4470a6..d0e41c4 100644 --- a/examples/django_channels2/django_channels2/urls.py +++ b/examples/django_channels2/django_channels2/urls.py @@ -1,4 +1,4 @@ from django.urls import path from graphene_django.views import GraphQLView -urlpatterns = [path("graphql", GraphQLView.as_view(graphiql=True))] +urlpatterns = [path("graphql/", GraphQLView.as_view(graphiql=True))] From 7e969a290452eb7b2fa5efb216b6ad0ef41b1f83 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 24 Jul 2019 10:26:12 +1200 Subject: [PATCH 21/72] Change the graphql-core requirement to avoid 3.0.0a0 install in pip --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bee9ce6..a44dc06 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ history = history_file.read() requirements = [ - "graphql-core>=2.0<3", + "graphql-core==2.*", # TODO: put package requirements here ] From b1b15e0db5e050b64f013ab3ca36437116ea6f86 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 7 Jan 2020 10:10:18 +1300 Subject: [PATCH 22/72] Safer django socket sending Sometimes the connection gets closed just before a message is sent This will avoid the ConnectionClosedOK: code = 1001 exception --- graphql_ws/django/subscriptions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index fa4910b..8473a74 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -11,6 +11,8 @@ class ChannelsConnectionContext(BaseConnectionContext): async def send(self, data): + if self.closed: + return await self.ws.send_json(data) async def close(self, code): From d7689b7d5460146e04253af2b0a16dd8eb52f7a1 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 7 Jan 2020 10:29:27 +1300 Subject: [PATCH 23/72] Add a django connection_context.closed prop --- graphql_ws/django/subscriptions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 8473a74..9b216b8 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -15,6 +15,10 @@ async def send(self, data): return await self.ws.send_json(data) + @property + def closed(self): + return self.ws.closed + async def close(self, code): await self.ws.close(code=code) From faea9cf168cc43ce4b98de5d8e560d9fe8beb52f Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 8 Jan 2020 09:53:42 +1300 Subject: [PATCH 24/72] Django Channels connection closed fix Channels can't tell if the socket is closed since different engines can be used, so explicitly set our own closed prop to work with. --- graphql_ws/django/consumers.py | 1 + graphql_ws/django/subscriptions.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 2cd3888..77ccb49 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -34,6 +34,7 @@ async def connect(self): async def disconnect(self, code): if self.connection_context: + self.connection_context.closed = True await subscription_server.on_close(self.connection_context) async def receive_json(self, content): diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 9b216b8..e16a574 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -10,15 +10,15 @@ class ChannelsConnectionContext(BaseConnectionContext): + def __init__(self, *args, **kwargs): + super(ChannelsConnectionContext, self).__init__(*args, **kwargs) + self.closed = False + async def send(self, data): if self.closed: return await self.ws.send_json(data) - @property - def closed(self): - return self.ws.closed - async def close(self, code): await self.ws.close(code=code) From 184367b4660f085a7d8534550acef3a8d9b255d3 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 8 Jan 2020 11:43:36 +1300 Subject: [PATCH 25/72] Fix a property collision in django connectioncontext --- graphql_ws/django/consumers.py | 2 +- graphql_ws/django/subscriptions.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 77ccb49..7b56233 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -34,7 +34,7 @@ async def connect(self): async def disconnect(self, code): if self.connection_context: - self.connection_context.closed = True + self.connection_context.socket_closed = True await subscription_server.on_close(self.connection_context) async def receive_json(self, content): diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index e16a574..7a4f4dc 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -12,13 +12,17 @@ class ChannelsConnectionContext(BaseConnectionContext): def __init__(self, *args, **kwargs): super(ChannelsConnectionContext, self).__init__(*args, **kwargs) - self.closed = False + self.socket_closed = False async def send(self, data): if self.closed: return await self.ws.send_json(data) + @property + def closed(self): + return self.socket_closed + async def close(self, code): await self.ws.close(code=code) From f805f3e11a3419d3ed55bd5b32e9d0aac5c2efc8 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 4 May 2020 16:27:21 +1200 Subject: [PATCH 26/72] Ensure the Django subscription consumer cleans up any dangling tasks when disconnecting --- graphql_ws/django/consumers.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 7b56233..2a449fd 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -22,6 +22,10 @@ def default(self, o): class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.futures = [] + async def connect(self): self.connection_context = None if WS_PROTOCOL in self.scope["subprotocols"]: @@ -33,14 +37,22 @@ async def connect(self): await self.close() async def disconnect(self, code): + for future in self.futures: + # Ensure any running message tasks are cancelled. + future.cancel() if self.connection_context: self.connection_context.socket_closed = True - await subscription_server.on_close(self.connection_context) + close_future = subscription_server.on_close(self.connection_context) + await asyncio.gather(close_future, *self.futures) async def receive_json(self, content): - asyncio.ensure_future( - subscription_server.on_message(self.connection_context, content) + self.futures.append( + asyncio.ensure_future( + subscription_server.on_message(self.connection_context, content) + ) ) + # Clean up any completed futures. + self.futures = [future for future in self.futures if not future.done()] @classmethod async def encode_json(cls, content): From c00066babc8de09af521e3ffcbf360c74808e6b4 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 18 May 2020 02:54:17 +1200 Subject: [PATCH 27/72] Split base classes into sync and async Deduplicate code --- graphql_ws/__init__.py | 3 - graphql_ws/aiohttp.py | 93 +++----------------- graphql_ws/base.py | 150 ++++++++++++++------------------- graphql_ws/base_async.py | 118 ++++++++++++++++++++++++++ graphql_ws/base_sync.py | 88 +++++++++++++++++++ graphql_ws/django_channels.py | 97 ++------------------- graphql_ws/gevent.py | 79 ++--------------- graphql_ws/observable_aiter.py | 47 +---------- graphql_ws/websockets_lib.py | 93 +++----------------- 9 files changed, 312 insertions(+), 456 deletions(-) create mode 100644 graphql_ws/base_async.py create mode 100644 graphql_ws/base_sync.py diff --git a/graphql_ws/__init__.py b/graphql_ws/__init__.py index 44c7dc3..793831a 100644 --- a/graphql_ws/__init__.py +++ b/graphql_ws/__init__.py @@ -5,6 +5,3 @@ __author__ = """Syrus Akbary""" __email__ = 'me@syrusakbary.com' __version__ = '0.3.1' - - -from .base import BaseConnectionContext, BaseSubscriptionServer # noqa: F401 diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 363ca67..49e0a5e 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,23 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield +import json +from asyncio import ensure_future, shield from aiohttp import WSMsgType -from graphql.execution.executors.asyncio import AsyncioExecutor -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -from .constants import ( - GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR, - GQL_COMPLETE -) -setup_observable_extension() - - -class AiohttpConnectionContext(BaseConnectionContext): +class AiohttpConnectionContext(BaseAsyncConnectionContext): async def receive(self): msg = await self.ws.receive() if msg.type == WSMsgType.TEXT: @@ -32,7 +22,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send_str(data) + await self.ws.send_str(json.dumps(data)) @property def closed(self): @@ -42,21 +32,10 @@ async def close(self, code): await self.ws.close(code=code) -class AiohttpSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(AiohttpSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class AiohttpSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context=None): connection_context = AiohttpConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -64,59 +43,13 @@ async def _handle(self, ws, request_context=None): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - - self.on_close(connection_context) - for task in pending: - task.cancel() + connection_context.remember_task( + ensure_future( + self.on_message(connection_context, message), loop=self.loop + ) + ) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index f3aa1e7..d146419 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -1,16 +1,16 @@ import json from collections import OrderedDict -from graphql import graphql, format_error +from graphql import format_error from .constants import ( + GQL_CONNECTION_ERROR, GQL_CONNECTION_INIT, GQL_CONNECTION_TERMINATE, + GQL_DATA, + GQL_ERROR, GQL_START, GQL_STOP, - GQL_ERROR, - GQL_CONNECTION_ERROR, - GQL_DATA ) @@ -51,33 +51,16 @@ def close(self, code): class BaseSubscriptionServer(object): + graphql_executor = None def __init__(self, schema, keep_alive=True): self.schema = schema self.keep_alive = keep_alive - def get_graphql_params(self, connection_context, payload): - return { - 'request_string': payload.get('query'), - 'variable_values': payload.get('variables'), - 'operation_name': payload.get('operationName'), - 'context_value': payload.get('context'), - } - - def build_message(self, id, op_type, payload): - message = {} - if id is not None: - message['id'] = id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - return message - def process_message(self, connection_context, parsed_message): - op_id = parsed_message.get('id') - op_type = parsed_message.get('type') - payload = parsed_message.get('payload') + op_id = parsed_message.get("id") + op_type = parsed_message.get("type") + payload = parsed_message.get("payload") if op_type == GQL_CONNECTION_INIT: return self.on_connection_init(connection_context, op_id, payload) @@ -92,7 +75,8 @@ def process_message(self, connection_context, parsed_message): if not isinstance(params, dict): error = Exception( "Invalid params returned from get_graphql_params!" - " Return values must be a dict.") + " Return values must be a dict." + ) return self.send_error(connection_context, op_id, error) # If we already have a subscription with this id, unsubscribe from @@ -100,14 +84,54 @@ def process_message(self, connection_context, parsed_message): if connection_context.has_operation(op_id): self.unsubscribe(connection_context, op_id) + params = self.get_graphql_params(connection_context, payload) return self.on_start(connection_context, op_id, params) elif op_type == GQL_STOP: return self.on_stop(connection_context, op_id) else: - return self.send_error(connection_context, op_id, Exception( - "Invalid message type: {}.".format(op_type))) + return self.send_error( + connection_context, + op_id, + Exception("Invalid message type: {}.".format(op_type)), + ) + + def on_connection_init(self, connection_context, op_id, payload): + raise NotImplementedError("on_connection_init method not implemented") + + def on_connection_terminate(self, connection_context, op_id): + return connection_context.close(1011) + + def get_graphql_params(self, connection_context, payload): + return { + "request_string": payload.get("query"), + "variable_values": payload.get("variables"), + "operation_name": payload.get("operationName"), + "context_value": payload.get("context"), + "executor": self.graphql_executor(), + } + + def on_open(self, connection_context): + raise NotImplementedError("on_open method not implemented") + + def on_stop(self, connection_context, op_id): + raise NotImplementedError("on_stop method not implemented") + + def send_message(self, connection_context, op_id=None, op_type=None, payload=None): + message = self.build_message(op_id, op_type, payload) + assert message, "You need to send at least one thing" + return connection_context.send(message) + + def build_message(self, id, op_type, payload): + message = {} + if id is not None: + message["id"] = id + if op_type is not None: + message["type"] = op_type + if payload is not None: + message["payload"] = payload + return message def send_execution_result(self, connection_context, op_id, execution_result): result = self.execution_result_to_dict(execution_result) @@ -116,86 +140,34 @@ def send_execution_result(self, connection_context, op_id, execution_result): def execution_result_to_dict(self, execution_result): result = OrderedDict() if execution_result.data: - result['data'] = execution_result.data + result["data"] = execution_result.data if execution_result.errors: - result['errors'] = [format_error(error) - for error in execution_result.errors] + result["errors"] = [ + format_error(error) for error in execution_result.errors + ] return result - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = self.build_message(op_id, op_type, payload) - assert message, "You need to send at least one thing" - json_message = json.dumps(message) - return connection_context.send(json_message) - def send_error(self, connection_context, op_id, error, error_type=None): if error_type is None: error_type = GQL_ERROR assert error_type in [GQL_CONNECTION_ERROR, GQL_ERROR], ( - 'error_type should be one of the allowed error messages' - ' GQL_CONNECTION_ERROR or GQL_ERROR' - ) - - error_payload = { - 'message': str(error) - } - - return self.send_message( - connection_context, - op_id, - error_type, - error_payload + "error_type should be one of the allowed error messages" + " GQL_CONNECTION_ERROR or GQL_ERROR" ) - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) + error_payload = {"message": str(error)} - def on_operation_complete(self, connection_context, op_id): - pass - - def on_connection_terminate(self, connection_context, op_id): - return connection_context.close(1011) - - def execute(self, request_context, params): - return graphql( - self.schema, **dict(params, allow_subscriptions=True)) - - def handle(self, ws, request_context=None): - raise NotImplementedError("handle method not implemented") + return self.send_message(connection_context, op_id, error_type, error_payload) def on_message(self, connection_context, message): try: if not isinstance(message, dict): parsed_message = json.loads(message) - assert isinstance( - parsed_message, dict), "Payload must be an object." + assert isinstance(parsed_message, dict), "Payload must be an object." else: parsed_message = message except Exception as e: return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) - - def on_open(self, connection_context): - raise NotImplementedError("on_open method not implemented") - - def on_connect(self, connection_context, payload): - raise NotImplementedError("on_connect method not implemented") - - def on_close(self, connection_context): - raise NotImplementedError("on_close method not implemented") - - def on_connection_init(self, connection_context, op_id, payload): - raise NotImplementedError("on_connection_init method not implemented") - - def on_stop(self, connection_context, op_id): - raise NotImplementedError("on_stop method not implemented") - - def on_start(self, connection_context, op_id, params): - raise NotImplementedError("on_start method not implemented") diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py new file mode 100644 index 0000000..d067a8d --- /dev/null +++ b/graphql_ws/base_async.py @@ -0,0 +1,118 @@ +import asyncio +from abc import ABC, abstractmethod +from inspect import isawaitable +from weakref import WeakSet + +from graphql.execution.executors.asyncio import AsyncioExecutor + +from graphql_ws import base + +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .observable_aiter import setup_observable_extension + +setup_observable_extension() + + +class BaseAsyncConnectionContext(base.BaseConnectionContext, ABC): + def __init__(self, ws, request_context=None): + super().__init__(ws, request_context=request_context) + self.pending_tasks = WeakSet() + + @abstractmethod + async def receive(self): + raise NotImplementedError("receive method not implemented") + + @abstractmethod + async def send(self, data): + ... + + @property + @abstractmethod + def closed(self): + ... + + @abstractmethod + async def close(self, code): + ... + + def remember_task(self, task): + self.pending_tasks.add(asyncio.ensure_future(task)) + # Clear completed tasks + self.pending_tasks -= WeakSet( + task for task in self.pending_tasks if task.done() + ) + + +class BaseAsyncSubscriptionServer(base.BaseSubscriptionServer, ABC): + graphql_executor = AsyncioExecutor + + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) + + @abstractmethod + async def handle(self, ws, request_context=None): + ... + + def process_message(self, connection_context, parsed_message): + task = asyncio.ensure_future( + super().process_message(connection_context, parsed_message) + ) + connection_context.pending.add(task) + return task + + async def send_message(self, *args, **kwargs): + await super().send_message(*args, **kwargs) + + async def on_open(self, connection_context): + pass + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + execution_result = self.execute(connection_context.request_context, params) + + if isawaitable(execution_result): + execution_result = await execution_result + + if hasattr(execution_result, "__aiter__"): + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + else: + await self.send_execution_result( + connection_context, op_id, execution_result + ) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + await self.on_operation_complete(connection_context, op_id) + + async def on_close(self, connection_context): + awaitables = tuple( + self.unsubscribe(connection_context, op_id) + for op_id in connection_context.operations + ) + tuple(task.cancel() for task in connection_context.pending_tasks) + if awaitables: + await asyncio.gather(*awaitables, loop=self.loop) + + async def on_stop(self, connection_context, op_id): + await self.unsubscribe(connection_context, op_id) + + async def unsubscribe(self, connection_context, op_id): + await super().unsubscribe(connection_context, op_id) + + async def on_operation_complete(self, connection_context, op_id): + pass diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py new file mode 100644 index 0000000..b7cb412 --- /dev/null +++ b/graphql_ws/base_sync.py @@ -0,0 +1,88 @@ +from graphql import graphql +from graphql.execution.executors.sync import SyncExecutor +from rx import Observable, Observer + +from .base import BaseSubscriptionServer +from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + + +class BaseSyncSubscriptionServer(BaseSubscriptionServer): + graphql_executor = SyncExecutor + + def unsubscribe(self, connection_context, op_id): + if connection_context.has_operation(op_id): + # Close async iterator + connection_context.get_operation(op_id).dispose() + # Close operation + connection_context.remove_operation(op_id) + self.on_operation_complete(connection_context, op_id) + + def on_operation_complete(self, connection_context, op_id): + pass + + def execute(self, request_context, params): + return graphql(self.schema, **dict(params, allow_subscriptions=True)) + + def handle(self, ws, request_context=None): + raise NotImplementedError("handle method not implemented") + + def on_open(self, connection_context): + pass + + def on_connect(self, connection_context, payload): + pass + + def on_close(self, connection_context): + remove_operations = list(connection_context.operations.keys()) + for op_id in remove_operations: + self.unsubscribe(connection_context, op_id) + + def on_connection_init(self, connection_context, op_id, payload): + try: + self.on_connect(connection_context, payload) + self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + + except Exception as e: + self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + connection_context.close(1011) + + def on_stop(self, connection_context, op_id): + self.unsubscribe(connection_context, op_id) + + def on_start(self, connection_context, op_id, params): + try: + execution_result = self.execute(connection_context.request_context, params) + assert isinstance( + execution_result, Observable + ), "A subscription must return an observable" + execution_result.subscribe( + SubscriptionObserver( + connection_context, + op_id, + self.send_execution_result, + self.send_error, + self.on_close, + ) + ) + except Exception as e: + self.send_error(connection_context, op_id, str(e)) + + +class SubscriptionObserver(Observer): + def __init__( + self, connection_context, op_id, send_execution_result, send_error, on_close + ): + self.connection_context = connection_context + self.op_id = op_id + self.send_execution_result = send_execution_result + self.send_error = send_error + self.on_close = on_close + + def on_next(self, value): + self.send_execution_result(self.connection_context, self.op_id, value) + + def on_completed(self): + self.on_close(self.connection_context) + + def on_error(self, error): + self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/django_channels.py b/graphql_ws/django_channels.py index 61a7247..fbee47b 100644 --- a/graphql_ws/django_channels.py +++ b/graphql_ws/django_channels.py @@ -1,94 +1,30 @@ import json -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor - from channels.generic.websockets import JsonWebsocketConsumer from graphene_django.settings import graphene_settings -from .base import BaseConnectionContext, BaseSubscriptionServer -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .base import BaseConnectionContext +from .base_sync import BaseSyncSubscriptionServer class DjangoChannelConnectionContext(BaseConnectionContext): - def __init__(self, message, request_context=None): self.message = message self.operations = {} self.request_context = request_context def send(self, data): - self.message.reply_channel.send(data) + self.message.reply_channel.send({"text": json.dumps(data)}) def close(self, reason): - data = { - 'close': True, - 'text': reason - } + data = {"close": True, "text": reason} self.message.reply_channel.send(data) -class DjangoChannelSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(DjangoChannelSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) - +class DjangoChannelSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, message, connection_context): self.on_message(connection_context, message) - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = {} - if op_id is not None: - message['id'] = op_id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - - assert message, "You need to send at least one thing" - return connection_context.send({'text': json.dumps(message)}) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - class GraphQLSubscriptionConsumer(JsonWebsocketConsumer): http_user_and_session = True @@ -104,26 +40,7 @@ def receive(self, content, **_kwargs): """ self.connection_context = DjangoChannelConnectionContext(self.message) self.subscription_server = DjangoChannelSubscriptionServer( - graphene_settings.SCHEMA) + graphene_settings.SCHEMA + ) self.subscription_server.on_open(self.connection_context) self.subscription_server.handle(content, self.connection_context) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/gevent.py b/graphql_ws/gevent.py index aadbe64..b7d6849 100644 --- a/graphql_ws/gevent.py +++ b/graphql_ws/gevent.py @@ -1,15 +1,15 @@ from __future__ import absolute_import -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor +import json from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + BaseConnectionContext, + ConnectionClosedException, +) +from .base_sync import BaseSyncSubscriptionServer class GeventConnectionContext(BaseConnectionContext): - def receive(self): msg = self.ws.receive() return msg @@ -17,7 +17,7 @@ def receive(self): def send(self, data): if self.closed: return - self.ws.send(data) + self.ws.send(json.dumps(data)) @property def closed(self): @@ -27,13 +27,7 @@ def close(self, code): self.ws.close(code) -class GeventSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(GeventSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) - +class GeventSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, ws, request_context=None): connection_context = GeventConnectionContext(ws, request_context) self.on_open(connection_context) @@ -46,62 +40,3 @@ def handle(self, ws, request_context=None): self.on_close(connection_context) return self.on_message(connection_context, message) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/observable_aiter.py b/graphql_ws/observable_aiter.py index 0bd1a59..424d95f 100644 --- a/graphql_ws/observable_aiter.py +++ b/graphql_ws/observable_aiter.py @@ -1,7 +1,7 @@ from asyncio import Future -from rx.internal import extensionmethod from rx.core import Observable +from rx.internal import extensionmethod async def __aiter__(self): @@ -13,15 +13,11 @@ def __init__(self): self.future = Future() self.disposable = source.materialize().subscribe(self.on_next) - # self.disposed = False def __aiter__(self): return self def dispose(self): - # self.future.cancel() - # self.disposed = True - # self.future.set_exception(StopAsyncIteration) self.disposable.dispose() def feeder(self): @@ -30,11 +26,11 @@ def feeder(self): notification = self.notifications.pop(0) kind = notification.kind - if kind == 'N': + if kind == "N": self.future.set_result(notification.value) - if kind == 'E': + if kind == "E": self.future.set_exception(notification.exception) - if kind == 'C': + if kind == "C": self.future.set_exception(StopAsyncIteration) def on_next(self, notification): @@ -42,8 +38,6 @@ def on_next(self, notification): self.feeder() async def __anext__(self): - # if self.disposed: - # raise StopAsyncIteration self.feeder() value = await self.future @@ -53,38 +47,5 @@ async def __anext__(self): return AIterator() -# def __aiter__(self, sentinel=None): -# loop = get_event_loop() -# future = [Future()] -# notifications = [] - -# def feeder(): -# if not len(notifications) or future[0].done(): -# return -# notification = notifications.pop(0) -# if notification.kind == "E": -# future[0].set_exception(notification.exception) -# elif notification.kind == "C": -# future[0].set_exception(StopIteration(sentinel)) -# else: -# future[0].set_result(notification.value) - -# def on_next(value): -# """Takes on_next values and appends them to the notification queue""" -# notifications.append(value) -# loop.call_soon(feeder) - -# self.materialize().subscribe(on_next) - -# @asyncio.coroutine -# def gen(): -# """Generator producing futures""" -# loop.call_soon(feeder) -# future[0] = Future() -# return future[0] - -# return gen - - def setup_observable_extension(): extensionmethod(Observable)(__aiter__) diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 7e78d5d..93ad76f 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,19 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield -from websockets import ConnectionClosed -from graphql.execution.executors.asyncio import AsyncioExecutor - -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +import json +from asyncio import ensure_future, shield -from .constants import ( - GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE) +from websockets import ConnectionClosed -setup_observable_extension() +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -class WsLibConnectionContext(BaseConnectionContext): +class WsLibConnectionContext(BaseAsyncConnectionContext): async def receive(self): try: msg = await self.ws.recv() @@ -24,7 +18,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send(data) + await self.ws.send(json.dumps(data)) @property def closed(self): @@ -34,21 +28,10 @@ async def close(self, code): await self.ws.close(code) -class WsLibSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(WsLibSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class WsLibSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context): connection_context = WsLibConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -56,61 +39,13 @@ async def _handle(self, ws, request_context): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - self.on_close(connection_context) - for task in pending: - task.cancel() + connection_context.remember_task( + ensure_future( + self.on_message(connection_context, message), loop=self.loop + ) + ) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message( - connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error( - connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) From b6951ed404aa7bf4f8c79e0c4a1c20a3aca435f9 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 18 May 2020 02:54:24 +1200 Subject: [PATCH 28/72] Fix tests to match deduplication changes --- tests/test_aiohttp.py | 2 +- tests/test_django_channels.py | 2 +- tests/test_gevent.py | 4 ++-- tests/test_graphql_ws.py | 35 +++++++---------------------------- 4 files changed, 11 insertions(+), 32 deletions(-) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index f20ca15..88a48d1 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -55,7 +55,7 @@ async def test_receive_closed(self, mock_ws): async def test_send(self, mock_ws): connection_context = AiohttpConnectionContext(ws=mock_ws) await connection_context.send("test") - mock_ws.send_str.assert_called_with("test") + mock_ws.send_str.assert_called_with('"test"') async def test_send_closed(self, mock_ws): mock_ws.closed = True diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index e7b054c..51ef6ae 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -14,7 +14,7 @@ def test_send(self): msg = mock.Mock() connection_context = DjangoChannelConnectionContext(message=msg) connection_context.send("test") - msg.reply_channel.send.assert_called_with("test") + msg.reply_channel.send.assert_called_with({'text': '"test"'}) def test_close(self): msg = mock.Mock() diff --git a/tests/test_gevent.py b/tests/test_gevent.py index f766c5a..a734970 100644 --- a/tests/test_gevent.py +++ b/tests/test_gevent.py @@ -17,8 +17,8 @@ def test_send(self): ws = mock.Mock() ws.closed = False connection_context = GeventConnectionContext(ws=ws) - connection_context.send("test") - ws.send.assert_called_with("test") + connection_context.send({"text": "test"}) + ws.send.assert_called_with('{"text": "test"}') def test_send_closed(self): ws = mock.Mock() diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 3ba1120..65cbf91 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -5,8 +5,9 @@ import mock import pytest +from graphql.execution.executors.sync import SyncExecutor -from graphql_ws import base, constants +from graphql_ws import base, base_sync, constants @pytest.fixture @@ -18,7 +19,7 @@ def cc(): @pytest.fixture def ss(): - return base.BaseSubscriptionServer(schema=None) + return base_sync.BaseSyncSubscriptionServer(schema=None) class TestConnectionContextOperation: @@ -137,7 +138,9 @@ def test_get_graphql_params(ss, cc): "operationName": "query", "context": "ctx", } - assert ss.get_graphql_params(cc, payload) == { + params = ss.get_graphql_params(cc, payload) + assert isinstance(params.pop("executor"), SyncExecutor) + assert params == { "request_string": "req", "variable_values": "vars", "operation_name": "query", @@ -189,34 +192,10 @@ def test_send_message(ss, cc): cc.send = mock.Mock() cc.send.return_value = "returned" assert "returned" == ss.send_message(cc) - cc.send.assert_called_with('{"mess": "age"}') + cc.send.assert_called_with({"mess": "age"}) class TestSSNotImplemented: def test_handle(self, ss): with pytest.raises(NotImplementedError): ss.handle(ws=None, request_context=None) - - def test_on_open(self, ss): - with pytest.raises(NotImplementedError): - ss.on_open(connection_context=None) - - def test_on_connect(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connect(connection_context=None, payload=None) - - def test_on_close(self, ss): - with pytest.raises(NotImplementedError): - ss.on_close(connection_context=None) - - def test_on_connection_init(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connection_init(connection_context=None, op_id=None, payload=None) - - def test_on_stop(self, ss): - with pytest.raises(NotImplementedError): - ss.on_stop(connection_context=None, op_id=None) - - def test_on_start(self, ss): - with pytest.raises(NotImplementedError): - ss.on_start(connection_context=None, op_id=None, params=None) From 4fe4736896b88e7ee3652941a32d3becbeed455c Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 1 Jun 2020 23:48:10 +1200 Subject: [PATCH 29/72] Add some base tests --- setup.cfg | 5 ++++ tests/test_base.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++ tox.ini | 1 + 3 files changed, 68 insertions(+) create mode 100644 tests/test_base.py diff --git a/setup.cfg b/setup.cfg index df50b23..b921bca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,3 +90,8 @@ ignore = W503 [coverage:run] omit = .tox/* + +[coverage:report] +exclude_lines = + pragma: no cover + @abstract \ No newline at end of file diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..2e78459 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,62 @@ +try: + from unittest import mock +except ImportError: + import mock + +import json + +import pytest + +from graphql_ws import base + + +def test_not_implemented(): + server = base.BaseSubscriptionServer(schema=None) + with pytest.raises(NotImplementedError): + server.on_connection_init(connection_context=None, op_id=1, payload={}) + with pytest.raises(NotImplementedError): + server.on_open(connection_context=None) + with pytest.raises(NotImplementedError): + server.on_stop(connection_context=None, op_id=1) + + +def test_terminate(): + server = base.BaseSubscriptionServer(schema=None) + + context = mock.Mock() + server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +def test_send_error(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +def test_message(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +def test_message_str(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +def test_message_invalid(): + server = base.BaseSubscriptionServer(schema=None) + server.send_error = mock.Mock() + server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called diff --git a/tox.ini b/tox.ini index 6de6deb..42d13b4 100644 --- a/tox.ini +++ b/tox.ini @@ -31,5 +31,6 @@ skip_install = true deps = coverage commands = coverage html + coverage xml coverage report --include="tests/*" --fail-under=100 -m coverage report --omit="tests/*" # --fail-under=90 -m \ No newline at end of file From f7ef09fef63b9a5408496988fe5292a78d6a5475 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sat, 6 Jun 2020 17:08:24 +1200 Subject: [PATCH 30/72] Add base async tests --- tests/conftest.py | 4 +-- tests/test_base_async.py | 59 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 tests/test_base_async.py diff --git a/tests/conftest.py b/tests/conftest.py index e551557..fa905b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,6 @@ if sys.version_info > (3,): collect_ignore = ["test_django_channels.py"] if sys.version_info < (3, 6): - collect_ignore.append('test_gevent.py') + collect_ignore.append("test_gevent.py") else: - collect_ignore = ["test_aiohttp.py"] + collect_ignore = ["test_aiohttp.py", "test_base_async.py"] diff --git a/tests/test_base_async.py b/tests/test_base_async.py new file mode 100644 index 0000000..902acc7 --- /dev/null +++ b/tests/test_base_async.py @@ -0,0 +1,59 @@ +from unittest import mock + +import json + +import pytest + +from graphql_ws import base + + +def test_not_implemented(): + server = base.BaseSubscriptionServer(schema=None) + with pytest.raises(NotImplementedError): + server.on_connection_init(connection_context=None, op_id=1, payload={}) + with pytest.raises(NotImplementedError): + server.on_open(connection_context=None) + with pytest.raises(NotImplementedError): + server.on_stop(connection_context=None, op_id=1) + + +def test_terminate(): + server = base.BaseSubscriptionServer(schema=None) + + context = mock.Mock() + server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +def test_send_error(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +def test_message(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +def test_message_str(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +def test_message_invalid(): + server = base.BaseSubscriptionServer(schema=None) + server.send_error = mock.Mock() + server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called From 944d94982abbe6588273b3cca9c186a701149f33 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sat, 6 Jun 2020 17:08:34 +1200 Subject: [PATCH 31/72] Add django_channels tests --- tests/django_routing.py | 6 +++++ tests/test_django_channels.py | 46 +++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 tests/django_routing.py diff --git a/tests/django_routing.py b/tests/django_routing.py new file mode 100644 index 0000000..9d01766 --- /dev/null +++ b/tests/django_routing.py @@ -0,0 +1,6 @@ +from channels.routing import route +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + +channel_routing = [ + route("websocket.receive", GraphQLSubscriptionConsumer), +] diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index 51ef6ae..137d541 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -1,11 +1,35 @@ +from __future__ import unicode_literals + +import json + +import django import mock +from channels import Channel +from channels.test import ChannelTestCase, Client from django.conf import settings +from django.core.management import call_command -settings.configure() # noqa +settings.configure( + CHANNEL_LAYERS={ + "default": { + "BACKEND": "asgiref.inmemory.ChannelLayer", + "ROUTING": "tests.django_routing.channel_routing", + }, + }, + INSTALLED_APPS=[ + "django.contrib.sessions", + "django.contrib.contenttypes", + "django.contrib.auth", + ], + DATABASES={"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}}, +) +django.setup() +from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT from graphql_ws.django_channels import ( DjangoChannelConnectionContext, DjangoChannelSubscriptionServer, + GraphQLSubscriptionConsumer, ) @@ -14,7 +38,7 @@ def test_send(self): msg = mock.Mock() connection_context = DjangoChannelConnectionContext(message=msg) connection_context.send("test") - msg.reply_channel.send.assert_called_with({'text': '"test"'}) + msg.reply_channel.send.assert_called_with({"text": '"test"'}) def test_close(self): msg = mock.Mock() @@ -25,3 +49,21 @@ def test_close(self): def test_subscription_server_smoke(): DjangoChannelSubscriptionServer(schema=None) + + +class TestConsumer(ChannelTestCase): + def test_connect(self): + call_command("migrate") + Channel("websocket.receive").send( + { + "path": "/graphql", + "order": 0, + "reply_channel": "websocket.receive", + "text": json.dumps({"type": GQL_CONNECTION_INIT, "id": 1}), + } + ) + message = self.get_next_message("websocket.receive", require=True) + GraphQLSubscriptionConsumer(message) + result = self.get_next_message("websocket.receive", require=True) + result_content = json.loads(result.content["text"]) + assert result_content == {"type": GQL_CONNECTION_ACK} From e4b3d9f9b4c5c0d94a8717ac4694d454bf31aeb3 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 19 May 2020 15:39:11 +1200 Subject: [PATCH 32/72] Remove a redundant check for an internal detail It'll still cause an exception on .execute() if somehow a third party subscription server did the wrong thing anyway --- graphql_ws/base.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index d146419..ee82dec 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -71,14 +71,6 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_START: assert isinstance(payload, dict), "The payload must be a dict" - params = self.get_graphql_params(connection_context, payload) - if not isinstance(params, dict): - error = Exception( - "Invalid params returned from get_graphql_params!" - " Return values must be a dict." - ) - return self.send_error(connection_context, op_id, error) - # If we already have a subscription with this id, unsubscribe from # it first if connection_context.has_operation(op_id): From 7b21f0fe5235c548dca27ce7a718dbab8f85a1ab Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 19 May 2020 15:52:49 +1200 Subject: [PATCH 33/72] Move execute back to base --- graphql_ws/base.py | 9 +++++++-- graphql_ws/base_async.py | 2 +- graphql_ws/base_sync.py | 6 +----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index ee82dec..30bd766 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -1,7 +1,7 @@ import json from collections import OrderedDict -from graphql import format_error +from graphql import format_error, graphql from .constants import ( GQL_CONNECTION_ERROR, @@ -57,6 +57,9 @@ def __init__(self, schema, keep_alive=True): self.schema = schema self.keep_alive = keep_alive + def execute(self, params): + return graphql(self.schema, **dict(params, allow_subscriptions=True)) + def process_message(self, connection_context, parsed_message): op_id = parsed_message.get("id") op_type = parsed_message.get("type") @@ -96,11 +99,13 @@ def on_connection_terminate(self, connection_context, op_id): return connection_context.close(1011) def get_graphql_params(self, connection_context, payload): + context = payload.get('context') or {} + context.setdefault('request_context', connection_context.request_context) return { "request_string": payload.get("query"), "variable_values": payload.get("variables"), "operation_name": payload.get("operationName"), - "context_value": payload.get("context"), + "context_value": context, "executor": self.graphql_executor(), } diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index d067a8d..3252196 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -79,7 +79,7 @@ async def on_connection_init(self, connection_context, op_id, payload): await connection_context.close(1011) async def on_start(self, connection_context, op_id, params): - execution_result = self.execute(connection_context.request_context, params) + execution_result = self.execute(params) if isawaitable(execution_result): execution_result = await execution_result diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index b7cb412..70bdbfc 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -1,4 +1,3 @@ -from graphql import graphql from graphql.execution.executors.sync import SyncExecutor from rx import Observable, Observer @@ -20,9 +19,6 @@ def unsubscribe(self, connection_context, op_id): def on_operation_complete(self, connection_context, op_id): pass - def execute(self, request_context, params): - return graphql(self.schema, **dict(params, allow_subscriptions=True)) - def handle(self, ws, request_context=None): raise NotImplementedError("handle method not implemented") @@ -51,7 +47,7 @@ def on_stop(self, connection_context, op_id): def on_start(self, connection_context, op_id, params): try: - execution_result = self.execute(connection_context.request_context, params) + execution_result = self.execute(params) assert isinstance( execution_result, Observable ), "A subscription must return an observable" From 75cad357b9b29ac247efda975bfd2f8a167e7c31 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 26 Jun 2020 09:46:47 +1200 Subject: [PATCH 34/72] Move operation unsubscription to BaseSubscriptionServer --- graphql_ws/base.py | 8 ++++++++ graphql_ws/base_async.py | 3 --- graphql_ws/base_sync.py | 12 +++--------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 30bd766..9fb931d 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -168,3 +168,11 @@ def on_message(self, connection_context, message): return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) + + def unsubscribe(self, connection_context, op_id): + if connection_context.has_operation(op_id): + # Close async iterator + connection_context.get_operation(op_id).dispose() + # Close operation + connection_context.remove_operation(op_id) + self.on_operation_complete(connection_context, op_id) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 3252196..95d2f2b 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -111,8 +111,5 @@ async def on_close(self, connection_context): async def on_stop(self, connection_context, op_id): await self.unsubscribe(connection_context, op_id) - async def unsubscribe(self, connection_context, op_id): - await super().unsubscribe(connection_context, op_id) - async def on_operation_complete(self, connection_context, op_id): pass diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 70bdbfc..56b4d42 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -8,14 +8,6 @@ class BaseSyncSubscriptionServer(BaseSubscriptionServer): graphql_executor = SyncExecutor - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) - def on_operation_complete(self, connection_context, op_id): pass @@ -51,7 +43,7 @@ def on_start(self, connection_context, op_id, params): assert isinstance( execution_result, Observable ), "A subscription must return an observable" - execution_result.subscribe( + disposable = execution_result.subscribe( SubscriptionObserver( connection_context, op_id, @@ -60,6 +52,8 @@ def on_start(self, connection_context, op_id, params): self.on_close, ) ) + connection_context.register_operation(op_id, disposable) + except Exception as e: self.send_error(connection_context, op_id, str(e)) From e904b1513b8ca47b91c7d03ea92fc5fdf85c99a7 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 26 Jun 2020 09:59:07 +1200 Subject: [PATCH 35/72] Black the example code --- examples/aiohttp/app.py | 21 +++-- examples/aiohttp/schema.py | 4 +- examples/aiohttp/template.py | 15 +-- .../django_subscriptions/asgi.py | 2 +- .../django_subscriptions/schema.py | 15 +-- .../django_subscriptions/settings.py | 91 +++++++++---------- .../django_subscriptions/template.py | 15 +-- .../django_subscriptions/urls.py | 15 +-- examples/flask_gevent/app.py | 14 +-- examples/flask_gevent/schema.py | 12 ++- examples/flask_gevent/template.py | 15 +-- examples/websockets_lib/app.py | 14 +-- examples/websockets_lib/schema.py | 4 +- examples/websockets_lib/template.py | 15 +-- 14 files changed, 128 insertions(+), 124 deletions(-) diff --git a/examples/aiohttp/app.py b/examples/aiohttp/app.py index 56dcaff..336a0c6 100644 --- a/examples/aiohttp/app.py +++ b/examples/aiohttp/app.py @@ -10,24 +10,25 @@ async def graphql_view(request): payload = await request.json() - response = await schema.execute(payload.get('query', ''), return_promise=True) + response = await schema.execute(payload.get("query", ""), return_promise=True) data = {} if response.errors: - data['errors'] = [format_error(e) for e in response.errors] + data["errors"] = [format_error(e) for e in response.errors] if response.data: - data['data'] = response.data + data["data"] = response.data jsondata = json.dumps(data,) - return web.Response(text=jsondata, headers={'Content-Type': 'application/json'}) + return web.Response(text=jsondata, headers={"Content-Type": "application/json"}) async def graphiql_view(request): - return web.Response(text=render_graphiql(), headers={'Content-Type': 'text/html'}) + return web.Response(text=render_graphiql(), headers={"Content-Type": "text/html"}) + subscription_server = AiohttpSubscriptionServer(schema) async def subscriptions(request): - ws = web.WebSocketResponse(protocols=('graphql-ws',)) + ws = web.WebSocketResponse(protocols=("graphql-ws",)) await ws.prepare(request) await subscription_server.handle(ws) @@ -35,9 +36,9 @@ async def subscriptions(request): app = web.Application() -app.router.add_get('/subscriptions', subscriptions) -app.router.add_get('/graphiql', graphiql_view) -app.router.add_get('/graphql', graphql_view) -app.router.add_post('/graphql', graphql_view) +app.router.add_get("/subscriptions", subscriptions) +app.router.add_get("/graphiql", graphiql_view) +app.router.add_get("/graphql", graphql_view) +app.router.add_post("/graphql", graphql_view) web.run_app(app, port=8000) diff --git a/examples/aiohttp/schema.py b/examples/aiohttp/schema.py index 3c23d00..ae107c7 100644 --- a/examples/aiohttp/schema.py +++ b/examples/aiohttp/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/aiohttp/template.py b/examples/aiohttp/template.py index 0b74e96..709f7cf 100644 --- a/examples/aiohttp/template.py +++ b/examples/aiohttp/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_subscriptions/django_subscriptions/asgi.py b/examples/django_subscriptions/django_subscriptions/asgi.py index e6edd7d..35b4d4d 100644 --- a/examples/django_subscriptions/django_subscriptions/asgi.py +++ b/examples/django_subscriptions/django_subscriptions/asgi.py @@ -3,4 +3,4 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_subscriptions.settings") -channel_layer = get_channel_layer() \ No newline at end of file +channel_layer = get_channel_layer() diff --git a/examples/django_subscriptions/django_subscriptions/schema.py b/examples/django_subscriptions/django_subscriptions/schema.py index b55d76e..db6893c 100644 --- a/examples/django_subscriptions/django_subscriptions/schema.py +++ b/examples/django_subscriptions/django_subscriptions/schema.py @@ -6,18 +6,19 @@ class Query(graphene.ObjectType): hello = graphene.String() def resolve_hello(self, info, **kwargs): - return 'world' + return "world" + class Subscription(graphene.ObjectType): count_seconds = graphene.Int(up_to=graphene.Int()) - def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) -schema = graphene.Schema(query=Query, subscription=Subscription) \ No newline at end of file +schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/django_subscriptions/django_subscriptions/settings.py b/examples/django_subscriptions/django_subscriptions/settings.py index 45d0471..62cac69 100644 --- a/examples/django_subscriptions/django_subscriptions/settings.py +++ b/examples/django_subscriptions/django_subscriptions/settings.py @@ -20,7 +20,7 @@ # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = 'fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c' +SECRET_KEY = "fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c" # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True @@ -31,53 +31,53 @@ # Application definition INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'channels', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "channels", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", ] -ROOT_URLCONF = 'django_subscriptions.urls' +ROOT_URLCONF = "django_subscriptions.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, ] -WSGI_APPLICATION = 'django_subscriptions.wsgi.application' +WSGI_APPLICATION = "django_subscriptions.wsgi.application" # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db.sqlite3"), } } @@ -87,26 +87,20 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, + {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, + {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, + {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, ] # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -118,20 +112,17 @@ # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/1.11/howto/static-files/ -STATIC_URL = '/static/' -CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] +STATIC_URL = "/static/" +CHANNELS_WS_PROTOCOLS = [ + "graphql-ws", +] CHANNEL_LAYERS = { "default": { "BACKEND": "asgi_redis.RedisChannelLayer", - "CONFIG": { - "hosts": [("localhost", 6379)], - }, + "CONFIG": {"hosts": [("localhost", 6379)]}, "ROUTING": "django_subscriptions.urls.channel_routing", }, - } -GRAPHENE = { - 'SCHEMA': 'django_subscriptions.schema.schema' -} \ No newline at end of file +GRAPHENE = {"SCHEMA": "django_subscriptions.schema.schema"} diff --git a/examples/django_subscriptions/django_subscriptions/template.py b/examples/django_subscriptions/django_subscriptions/template.py index b067ae5..738d9e7 100644 --- a/examples/django_subscriptions/django_subscriptions/template.py +++ b/examples/django_subscriptions/django_subscriptions/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.11.10', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.11.10", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_subscriptions/django_subscriptions/urls.py b/examples/django_subscriptions/django_subscriptions/urls.py index 3848d22..caf790d 100644 --- a/examples/django_subscriptions/django_subscriptions/urls.py +++ b/examples/django_subscriptions/django_subscriptions/urls.py @@ -21,20 +21,21 @@ from graphene_django.views import GraphQLView from django.views.decorators.csrf import csrf_exempt +from channels.routing import route_class +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + def graphiql(request): response = HttpResponse(content=render_graphiql()) return response + urlpatterns = [ - url(r'^admin/', admin.site.urls), - url(r'^graphiql/', graphiql), - url(r'^graphql', csrf_exempt(GraphQLView.as_view(graphiql=True))) + url(r"^admin/", admin.site.urls), + url(r"^graphiql/", graphiql), + url(r"^graphql", csrf_exempt(GraphQLView.as_view(graphiql=True))), ] -from channels.routing import route_class -from graphql_ws.django_channels import GraphQLSubscriptionConsumer - channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), -] \ No newline at end of file +] diff --git a/examples/flask_gevent/app.py b/examples/flask_gevent/app.py index dbb0cca..efd145b 100644 --- a/examples/flask_gevent/app.py +++ b/examples/flask_gevent/app.py @@ -1,5 +1,3 @@ -import json - from flask import Flask, make_response from flask_graphql import GraphQLView from flask_sockets import Sockets @@ -14,19 +12,20 @@ sockets = Sockets(app) -@app.route('/graphiql') +@app.route("/graphiql") def graphql_view(): return make_response(render_graphiql()) app.add_url_rule( - '/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=False)) + "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=False) +) subscription_server = GeventSubscriptionServer(schema) -app.app_protocol = lambda environ_path_info: 'graphql-ws' +app.app_protocol = lambda environ_path_info: "graphql-ws" -@sockets.route('/subscriptions') +@sockets.route("/subscriptions") def echo_socket(ws): subscription_server.handle(ws) return [] @@ -35,5 +34,6 @@ def echo_socket(ws): if __name__ == "__main__": from gevent import pywsgi from geventwebsocket.handler import WebSocketHandler - server = pywsgi.WSGIServer(('', 5000), app, handler_class=WebSocketHandler) + + server = pywsgi.WSGIServer(("", 5000), app, handler_class=WebSocketHandler) server.serve_forever() diff --git a/examples/flask_gevent/schema.py b/examples/flask_gevent/schema.py index 6e6298c..eb48050 100644 --- a/examples/flask_gevent/schema.py +++ b/examples/flask_gevent/schema.py @@ -19,12 +19,16 @@ class Subscription(graphene.ObjectType): random_int = graphene.Field(RandomType) def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) def resolve_random_int(root, info): - return Observable.interval(1000).map(lambda i: RandomType(seconds=i, random_int=random.randint(0, 500))) + return Observable.interval(1000).map( + lambda i: RandomType(seconds=i, random_int=random.randint(0, 500)) + ) schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/flask_gevent/template.py b/examples/flask_gevent/template.py index 41f52e1..ea0438c 100644 --- a/examples/flask_gevent/template.py +++ b/examples/flask_gevent/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.12.0', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:5000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.12.0", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:5000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/websockets_lib/app.py b/examples/websockets_lib/app.py index 0de6988..7638f3d 100644 --- a/examples/websockets_lib/app.py +++ b/examples/websockets_lib/app.py @@ -8,21 +8,23 @@ app = Sanic(__name__) -@app.listener('before_server_start') +@app.listener("before_server_start") def init_graphql(app, loop): - app.add_route(GraphQLView.as_view(schema=schema, - executor=AsyncioExecutor(loop=loop)), - '/graphql') + app.add_route( + GraphQLView.as_view(schema=schema, executor=AsyncioExecutor(loop=loop)), + "/graphql", + ) -@app.route('/graphiql') +@app.route("/graphiql") async def graphiql_view(request): return response.html(render_graphiql()) + subscription_server = WsLibSubscriptionServer(schema) -@app.websocket('/subscriptions', subprotocols=['graphql-ws']) +@app.websocket("/subscriptions", subprotocols=["graphql-ws"]) async def subscriptions(request, ws): await subscription_server.handle(ws) return ws diff --git a/examples/websockets_lib/schema.py b/examples/websockets_lib/schema.py index 3c23d00..ae107c7 100644 --- a/examples/websockets_lib/schema.py +++ b/examples/websockets_lib/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/websockets_lib/template.py b/examples/websockets_lib/template.py index 03587bb..8f007b9 100644 --- a/examples/websockets_lib/template.py +++ b/examples/websockets_lib/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,9 +116,10 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', - endpointURL='/graphql', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", + endpointURL="/graphql", ) From 9499ae9154f682d055cfebe9c7f7cbc9e4359e3e Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 26 Jun 2020 09:59:21 +1200 Subject: [PATCH 36/72] Black the modules in graphql_ws root --- graphql_ws/__init__.py | 4 ++-- graphql_ws/base.py | 4 ++-- graphql_ws/constants.py | 22 +++++++++++----------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/graphql_ws/__init__.py b/graphql_ws/__init__.py index 793831a..0ffa258 100644 --- a/graphql_ws/__init__.py +++ b/graphql_ws/__init__.py @@ -3,5 +3,5 @@ """Top-level package for GraphQL WS.""" __author__ = """Syrus Akbary""" -__email__ = 'me@syrusakbary.com' -__version__ = '0.3.1' +__email__ = "me@syrusakbary.com" +__version__ = "0.3.1" diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 9fb931d..0a2577e 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -99,8 +99,8 @@ def on_connection_terminate(self, connection_context, op_id): return connection_context.close(1011) def get_graphql_params(self, connection_context, payload): - context = payload.get('context') or {} - context.setdefault('request_context', connection_context.request_context) + context = payload.get("context") or {} + context.setdefault("request_context", connection_context.request_context) return { "request_string": payload.get("query"), "variable_values": payload.get("variables"), diff --git a/graphql_ws/constants.py b/graphql_ws/constants.py index 4f9d2f1..8b57a60 100644 --- a/graphql_ws/constants.py +++ b/graphql_ws/constants.py @@ -1,15 +1,15 @@ -GRAPHQL_WS = 'graphql-ws' +GRAPHQL_WS = "graphql-ws" WS_PROTOCOL = GRAPHQL_WS -GQL_CONNECTION_INIT = 'connection_init' # Client -> Server -GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client -GQL_CONNECTION_ERROR = 'connection_error' # Server -> Client +GQL_CONNECTION_INIT = "connection_init" # Client -> Server +GQL_CONNECTION_ACK = "connection_ack" # Server -> Client +GQL_CONNECTION_ERROR = "connection_error" # Server -> Client # NOTE: This one here don't follow the standard due to connection optimization -GQL_CONNECTION_TERMINATE = 'connection_terminate' # Client -> Server -GQL_CONNECTION_KEEP_ALIVE = 'ka' # Server -> Client -GQL_START = 'start' # Client -> Server -GQL_DATA = 'data' # Server -> Client -GQL_ERROR = 'error' # Server -> Client -GQL_COMPLETE = 'complete' # Server -> Client -GQL_STOP = 'stop' # Client -> Server +GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server +GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client +GQL_START = "start" # Client -> Server +GQL_DATA = "data" # Server -> Client +GQL_ERROR = "error" # Server -> Client +GQL_COMPLETE = "complete" # Server -> Client +GQL_STOP = "stop" # Client -> Server From 738f447cf1d71191c8cd8a6f2865a54dd05cbe72 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 09:34:11 +1200 Subject: [PATCH 37/72] Skip flake8 false positives and remove unneeded import --- tests/test_django_channels.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index 137d541..0552c7b 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -5,7 +5,7 @@ import django import mock from channels import Channel -from channels.test import ChannelTestCase, Client +from channels.test import ChannelTestCase from django.conf import settings from django.core.management import call_command @@ -25,8 +25,8 @@ ) django.setup() -from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT -from graphql_ws.django_channels import ( +from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT # noqa: E402 +from graphql_ws.django_channels import ( # noqa: E402 DjangoChannelConnectionContext, DjangoChannelSubscriptionServer, GraphQLSubscriptionConsumer, From f641e584f0f32c2733ec986c033634b974c85f78 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 10:03:30 +1200 Subject: [PATCH 38/72] Update contributing doc --- CONTRIBUTING.rst | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 01d606e..a2315ad 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -68,7 +68,7 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. $ mkvirtualenv graphql_ws $ cd graphql_ws/ - $ python setup.py develop + $ pip install -e .[dev] 4. Create a branch for local development:: @@ -79,11 +79,8 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: $ flake8 graphql_ws tests - $ python setup.py test or py.test $ tox - To get flake8 and tox, just pip install them into your virtualenv. - 6. Commit your changes and push your branch to GitHub:: $ git add . @@ -101,14 +98,6 @@ Before you submit a pull request, check that it meets these guidelines: 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -3. The pull request should work for Python 2.6, 2.7, 3.3, 3.4 and 3.5, and for PyPy. Check +3. The pull request should work for Python 2.7, 3.5, 3.6, 3.7 and 3.8. Check https://travis-ci.org/graphql-python/graphql_ws/pull_requests and make sure that the tests pass for all supported Python versions. - -Tips ----- - -To run a subset of tests:: - -$ py.test tests.test_graphql_ws - From c007f58e2e3d5082aedf6a9ec1fac612dc9edf2d Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 10:03:42 +1200 Subject: [PATCH 39/72] Correctly test a bad graphql parameter --- tests/test_graphql_ws.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 65cbf91..4a7b845 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -111,7 +111,7 @@ def test_start_bad_graphql_params(self, ss, cc): ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() ss.process_message( - cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} + cc, {"id": "1", "type": None, "payload": {"a": "b"}} ) assert ss.send_error.called assert ss.send_error.call_args[0][:2] == (cc, "1") @@ -136,7 +136,7 @@ def test_get_graphql_params(ss, cc): "query": "req", "variables": "vars", "operationName": "query", - "context": "ctx", + "context": {}, } params = ss.get_graphql_params(cc, payload) assert isinstance(params.pop("executor"), SyncExecutor) @@ -144,7 +144,7 @@ def test_get_graphql_params(ss, cc): "request_string": "req", "variable_values": "vars", "operation_name": "query", - "context_value": "ctx", + "context_value": {'request_context': None}, } From bb4f1be10587b08aa85630e72018fd350410e12f Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:24:38 +1200 Subject: [PATCH 40/72] Make removing an operation from context fail silently --- graphql_ws/base.py | 5 ++++- tests/test_base.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 0a2577e..db4f675 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -34,7 +34,10 @@ def get_operation(self, op_id): return self.operations[op_id] def remove_operation(self, op_id): - del self.operations[op_id] + try: + del self.operations[op_id] + except KeyError: + pass def receive(self): raise NotImplementedError("receive method not implemented") diff --git a/tests/test_base.py b/tests/test_base.py index 2e78459..80de021 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -60,3 +60,15 @@ def test_message_invalid(): server.send_error = mock.Mock() server.on_message(connection_context=None, message="'not-json") assert server.send_error.called + + +def test_context_operations(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + assert not context.has_operation(1) + context.register_operation(1, None) + assert context.has_operation(1) + context.remove_operation(1) + assert not context.has_operation(1) + # Removing a non-existant operation fails silently. + context.remove_operation(999) From a4eef790606d7607724ae156fbdd1f7126e17126 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:30:42 +1200 Subject: [PATCH 41/72] Make async methods send an error if an operation raises an exception Also remove iteratable operations from the context when they complete --- graphql_ws/base_async.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 95d2f2b..29dfb08 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -87,16 +87,23 @@ async def on_start(self, connection_context, op_id, params): if hasattr(execution_result, "__aiter__"): iterator = await execution_result.__aiter__() connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break + try: + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + except Exception as e: + await self.send_error(connection_context, op_id, e) + connection_context.remove_operation(op_id) + else: + try: await self.send_execution_result( - connection_context, op_id, single_result + connection_context, op_id, execution_result ) - else: - await self.send_execution_result( - connection_context, op_id, execution_result - ) + except Exception as e: + await self.send_error(connection_context, op_id, e) await self.send_message(connection_context, op_id, GQL_COMPLETE) await self.on_operation_complete(connection_context, op_id) From 3e670e60cd7f0c8e55248dc7cbca09e12ed3be84 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:31:13 +1200 Subject: [PATCH 42/72] Send completion messages when the sync observer completes / errors out. --- graphql_ws/base_sync.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 56b4d42..0f15c01 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -2,7 +2,7 @@ from rx import Observable, Observer from .base import BaseSubscriptionServer -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR class BaseSyncSubscriptionServer(BaseSubscriptionServer): @@ -49,30 +49,33 @@ def on_start(self, connection_context, op_id, params): op_id, self.send_execution_result, self.send_error, - self.on_close, + self.send_message, ) ) connection_context.register_operation(op_id, disposable) except Exception as e: - self.send_error(connection_context, op_id, str(e)) + self.send_error(connection_context, op_id, e) + self.send_message(connection_context, op_id, GQL_COMPLETE) class SubscriptionObserver(Observer): def __init__( - self, connection_context, op_id, send_execution_result, send_error, on_close + self, connection_context, op_id, send_execution_result, send_error, send_message ): self.connection_context = connection_context self.op_id = op_id self.send_execution_result = send_execution_result self.send_error = send_error - self.on_close = on_close + self.send_message = send_message def on_next(self, value): self.send_execution_result(self.connection_context, self.op_id, value) def on_completed(self): - self.on_close(self.connection_context) + self.send_message(self.connection_context, self.op_id, GQL_COMPLETE) + self.connection_context.remove_operation(self.op_id) def on_error(self, error): self.send_error(self.connection_context, self.op_id, error) + self.on_completed() From 650db340831f35b21d5c4402edf0533391656e01 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:46:35 +1200 Subject: [PATCH 43/72] Cody tidy --- graphql_ws/base_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 0f15c01..a6d2efb 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -21,7 +21,7 @@ def on_connect(self, connection_context, payload): pass def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) + remove_operations = list(connection_context.operations) for op_id in remove_operations: self.unsubscribe(connection_context, op_id) From a8c2f33bea6fdc7ca134abfa0a1a34ac2fe94319 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 12:29:33 +1200 Subject: [PATCH 44/72] Abstract ensuring async task is a future --- graphql_ws/aiohttp.py | 6 ++---- graphql_ws/base_async.py | 4 ++-- graphql_ws/websockets_lib.py | 6 ++---- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 49e0a5e..d2162f2 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,5 +1,5 @@ import json -from asyncio import ensure_future, shield +from asyncio import shield from aiohttp import WSMsgType @@ -45,9 +45,7 @@ async def _handle(self, ws, request_context=None): break connection_context.remember_task( - ensure_future( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message), loop=self.loop ) await self.on_close(connection_context) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 29dfb08..8cdf31d 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -35,8 +35,8 @@ def closed(self): async def close(self, code): ... - def remember_task(self, task): - self.pending_tasks.add(asyncio.ensure_future(task)) + def remember_task(self, task, loop=None): + self.pending_tasks.add(asyncio.ensure_future(task, loop=loop)) # Clear completed tasks self.pending_tasks -= WeakSet( task for task in self.pending_tasks if task.done() diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 93ad76f..4d753a5 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,5 +1,5 @@ import json -from asyncio import ensure_future, shield +from asyncio import shield from websockets import ConnectionClosed @@ -41,9 +41,7 @@ async def _handle(self, ws, request_context): break connection_context.remember_task( - ensure_future( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message), loop=self.loop ) await self.on_close(connection_context) From 8d32f4b67fde158c6f6f2284dbe9854a777b470e Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 14:46:03 +1200 Subject: [PATCH 45/72] Tidy up django_channels (1) backend and example --- .../django_subscriptions/settings.py | 3 +- .../django_subscriptions/requirements.txt | 4 +++ graphql_ws/django_channels.py | 29 ++++++++++--------- 3 files changed, 20 insertions(+), 16 deletions(-) create mode 100644 examples/django_subscriptions/requirements.txt diff --git a/examples/django_subscriptions/django_subscriptions/settings.py b/examples/django_subscriptions/django_subscriptions/settings.py index 62cac69..7bb3f24 100644 --- a/examples/django_subscriptions/django_subscriptions/settings.py +++ b/examples/django_subscriptions/django_subscriptions/settings.py @@ -118,8 +118,7 @@ ] CHANNEL_LAYERS = { "default": { - "BACKEND": "asgi_redis.RedisChannelLayer", - "CONFIG": {"hosts": [("localhost", 6379)]}, + "BACKEND": "asgiref.inmemory.ChannelLayer", "ROUTING": "django_subscriptions.urls.channel_routing", }, } diff --git a/examples/django_subscriptions/requirements.txt b/examples/django_subscriptions/requirements.txt new file mode 100644 index 0000000..557e99f --- /dev/null +++ b/examples/django_subscriptions/requirements.txt @@ -0,0 +1,4 @@ +-e ../.. +django<2 +channels<2 +graphene_django<3 \ No newline at end of file diff --git a/graphql_ws/django_channels.py b/graphql_ws/django_channels.py index fbee47b..ddba58d 100644 --- a/graphql_ws/django_channels.py +++ b/graphql_ws/django_channels.py @@ -8,17 +8,18 @@ class DjangoChannelConnectionContext(BaseConnectionContext): - def __init__(self, message, request_context=None): - self.message = message - self.operations = {} - self.request_context = request_context + def __init__(self, message): + super(DjangoChannelConnectionContext, self).__init__( + message.reply_channel, + request_context={"user": message.user, "session": message.http_session}, + ) def send(self, data): - self.message.reply_channel.send({"text": json.dumps(data)}) + self.ws.send({"text": json.dumps(data)}) def close(self, reason): data = {"close": True, "text": reason} - self.message.reply_channel.send(data) + self.ws.send(data) class DjangoChannelSubscriptionServer(BaseSyncSubscriptionServer): @@ -26,21 +27,21 @@ def handle(self, message, connection_context): self.on_message(connection_context, message) +subscription_server = DjangoChannelSubscriptionServer(graphene_settings.SCHEMA) + + class GraphQLSubscriptionConsumer(JsonWebsocketConsumer): http_user_and_session = True strict_ordering = True - def connect(self, message, **_kwargs): + def connect(self, message, **kwargs): message.reply_channel.send({"accept": True}) - def receive(self, content, **_kwargs): + def receive(self, content, **kwargs): """ Called when a message is received with either text or bytes filled out. """ - self.connection_context = DjangoChannelConnectionContext(self.message) - self.subscription_server = DjangoChannelSubscriptionServer( - graphene_settings.SCHEMA - ) - self.subscription_server.on_open(self.connection_context) - self.subscription_server.handle(content, self.connection_context) + context = DjangoChannelConnectionContext(self.message) + subscription_server.on_open(context) + subscription_server.handle(content, context) From 7797c29d2ce155988c10a6c61f0b1ce30c866f31 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 15:26:45 +1200 Subject: [PATCH 46/72] Update readme --- README.rst | 199 ++++++++++++++++++++++++++--------------------------- setup.cfg | 2 +- 2 files changed, 100 insertions(+), 101 deletions(-) diff --git a/README.rst b/README.rst index 90ee500..0a871f0 100644 --- a/README.rst +++ b/README.rst @@ -1,14 +1,23 @@ +========== GraphQL WS ========== -Websocket server for GraphQL subscriptions. +Websocket backend for GraphQL subscriptions. + +Supports the following application servers: + +Python 3 application servers, using asyncio: + + * `aiohttp`_ + * `websockets compatible servers`_ such as Sanic + (via `websockets `__ library) -Currently supports: +Python 2 application servers: + + * `Gevent compatible servers`_ such as Flask + * `Django v1.x`_ + (via `channels v1.x `__) -* `aiohttp `__ -* `Gevent `__ -* Sanic (uses `websockets `__ - library) Installation instructions ========================= @@ -19,21 +28,54 @@ For instaling graphql-ws, just run this command in your shell pip install graphql-ws + Examples --------- +======== + +Python 3 servers +---------------- + +Create a subscribable schema like this: + +.. code:: python + + import asyncio + import graphene + + + class Query(graphene.ObjectType): + hello = graphene.String() + + @static_method + def resolve_hello(obj, info, **kwargs): + return "world" + + + class Subscription(graphene.ObjectType): + count_seconds = graphene.Float(up_to=graphene.Int()) + + async def resolve_count_seconds(root, info, up_to): + for i in range(up_to): + yield i + await asyncio.sleep(1.) + yield up_to + + + schema = graphene.Schema(query=Query, subscription=Subscription) aiohttp ~~~~~~~ -For setting up, just plug into your aiohttp server. +Then just plug into your aiohttp server. .. code:: python from graphql_ws.aiohttp import AiohttpSubscriptionServer - + from .schema import schema subscription_server = AiohttpSubscriptionServer(schema) + async def subscriptions(request): ws = web.WebSocketResponse(protocols=('graphql-ws',)) await ws.prepare(request) @@ -47,21 +89,26 @@ For setting up, just plug into your aiohttp server. web.run_app(app, port=8000) -Sanic -~~~~~ +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp + -Works with any framework that uses the websockets library for it’s -websocket implementation. For this example, plug in your Sanic server. +websockets compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Works with any framework that uses the websockets library for its websocket +implementation. For this example, plug in your Sanic server. .. code:: python from graphql_ws.websockets_lib import WsLibSubscriptionServer - + from . import schema app = Sanic(__name__) subscription_server = WsLibSubscriptionServer(schema) + @app.websocket('/subscriptions', subprotocols=['graphql-ws']) async def subscriptions(request, ws): await subscription_server.handle(ws) @@ -70,80 +117,73 @@ websocket implementation. For this example, plug in your Sanic server. app.run(host="0.0.0.0", port=8000) -And then, plug into a subscribable schema: + +Python 2 servers +----------------- + +Create a subscribable schema like this: .. code:: python - import asyncio import graphene + from rx import Observable class Query(graphene.ObjectType): - base = graphene.String() + hello = graphene.String() + + @static_method + def resolve_hello(obj, info, **kwargs): + return "world" class Subscription(graphene.ObjectType): count_seconds = graphene.Float(up_to=graphene.Int()) - async def resolve_count_seconds(root, info, up_to): - for i in range(up_to): - yield i - await asyncio.sleep(1.) - yield up_to + async def resolve_count_seconds(root, info, up_to=5): + return Observable.interval(1000)\ + .map(lambda i: "{0}".format(i))\ + .take_while(lambda i: int(i) <= up_to) schema = graphene.Schema(query=Query, subscription=Subscription) -You can see a full example here: -https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp - -Gevent -~~~~~~ +Gevent compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~ -For setting up, just plug into your Gevent server. +Then just plug into your Gevent server, for example, Flask: .. code:: python + from flask_sockets import Sockets + from graphql_ws.gevent import GeventSubscriptionServer + from schema import schema + subscription_server = GeventSubscriptionServer(schema) app.app_protocol = lambda environ_path_info: 'graphql-ws' + @sockets.route('/subscriptions') def echo_socket(ws): subscription_server.handle(ws) return [] -And then, plug into a subscribable schema: - -.. code:: python - - import graphene - from rx import Observable - - - class Query(graphene.ObjectType): - base = graphene.String() - - - class Subscription(graphene.ObjectType): - count_seconds = graphene.Float(up_to=graphene.Int()) - - async def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - - - schema = graphene.Schema(query=Query, subscription=Subscription) - You can see a full example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/flask_gevent -Django Channels -~~~~~~~~~~~~~~~ +Django v1.x +~~~~~~~~~~~ -First ``pip install channels`` and it to your django apps +For Django v1.x and Django Channels v1.x, setup your schema in ``settings.py`` -Then add the following to your settings.py +.. code:: python + + GRAPHENE = { + 'SCHEMA': 'yourproject.schema.schema' + } + +Then ``pip install "channels<1"`` and it to your django apps, adding the +following to your ``settings.py`` .. code:: python @@ -153,53 +193,9 @@ Then add the following to your settings.py "BACKEND": "asgiref.inmemory.ChannelLayer", "ROUTING": "django_subscriptions.urls.channel_routing", }, - } -Setup your graphql schema - -.. code:: python - - import graphene - from rx import Observable - - - class Query(graphene.ObjectType): - hello = graphene.String() - - def resolve_hello(self, info, **kwargs): - return 'world' - - class Subscription(graphene.ObjectType): - - count_seconds = graphene.Int(up_to=graphene.Int()) - - - def resolve_count_seconds( - root, - info, - up_to=5 - ): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - - - - schema = graphene.Schema( - query=Query, - subscription=Subscription - ) - -Setup your schema in settings.py - -.. code:: python - - GRAPHENE = { - 'SCHEMA': 'path.to.schema' - } - -and finally add the channel routes +And finally add the channel routes .. code:: python @@ -209,3 +205,6 @@ and finally add the channel routes channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), ] + +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/django_subscriptions diff --git a/setup.cfg b/setup.cfg index b921bca..1e7ea2a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ [metadata] name = graphql-ws version = 0.3.1 -description = Websocket server for GraphQL subscriptions +description = Websocket backend for GraphQL subscriptions long_description = file: README.rst, CHANGES.rst author = Syrus Akbary author_email = me@syrusakbary.com From d85df833273e5521561045b1b8bb64d584e175f2 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 15:59:30 +1200 Subject: [PATCH 47/72] Use new abstracted base code for django channels 2 --- graphql_ws/django/consumers.py | 17 ++---- graphql_ws/django/subscriptions.py | 87 ++---------------------------- 2 files changed, 6 insertions(+), 98 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 2a449fd..3373576 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -22,9 +22,6 @@ def default(self, o): class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.futures = [] async def connect(self): self.connection_context = None @@ -37,22 +34,14 @@ async def connect(self): await self.close() async def disconnect(self, code): - for future in self.futures: - # Ensure any running message tasks are cancelled. - future.cancel() if self.connection_context: self.connection_context.socket_closed = True - close_future = subscription_server.on_close(self.connection_context) - await asyncio.gather(close_future, *self.futures) + await subscription_server.on_close(self.connection_context) async def receive_json(self, content): - self.futures.append( - asyncio.ensure_future( - subscription_server.on_message(self.connection_context, content) - ) + self.connection_context.remember_task( + subscription_server.on_message(self.connection_context, content) ) - # Clean up any completed futures. - self.futures = [future for future in self.futures if not future.done()] @classmethod async def encode_json(cls, content): diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 7a4f4dc..0cca653 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,15 +1,11 @@ -import asyncio -from inspect import isawaitable from graphene_django.settings import graphene_settings -from graphql.execution.executors.asyncio import AsyncioExecutor -from ..base import BaseConnectionContext, BaseSubscriptionServer -from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE +from ..base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer from ..observable_aiter import setup_observable_extension setup_observable_extension() -class ChannelsConnectionContext(BaseConnectionContext): +class ChannelsConnectionContext(BaseAsyncConnectionContext): def __init__(self, *args, **kwargs): super(ChannelsConnectionContext, self).__init__(*args, **kwargs) self.socket_closed = False @@ -27,88 +23,11 @@ async def close(self, code): await self.ws.close(code=code) -class ChannelsSubscriptionServer(BaseSubscriptionServer): - def get_graphql_params(self, connection_context, payload): - payload["context"] = connection_context.request_context - params = super(ChannelsSubscriptionServer, self).get_graphql_params( - connection_context, payload - ) - return dict(params, return_promise=True, executor=AsyncioExecutor()) - +class ChannelsSubscriptionServer(BaseAsyncSubscriptionServer): async def handle(self, ws, request_context=None): connection_context = ChannelsConnectionContext(ws, request_context) await self.on_open(connection_context) return connection_context - async def send_message( - self, connection_context, op_id=None, op_type=None, payload=None - ): - message = {} - if op_id is not None: - message["id"] = op_id - if op_type is not None: - message["type"] = op_type - if payload is not None: - message["payload"] = payload - - assert message, "You need to send at least one thing" - return await connection_context.send(message) - - async def on_open(self, connection_context): - pass - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute(connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if hasattr(execution_result, "__aiter__"): - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result - ) - else: - await self.send_execution_result( - connection_context, op_id, execution_result - ) - await self.on_operation_complete(connection_context, op_id) - - async def on_close(self, connection_context): - unsubscribes = [ - self.unsubscribe(connection_context, op_id) - for op_id in connection_context.operations - ] - if unsubscribes: - await asyncio.wait(unsubscribes) - - async def on_stop(self, connection_context, op_id): - await self.unsubscribe(connection_context, op_id) - - async def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - op = connection_context.get_operation(op_id) - op.dispose() - connection_context.remove_operation(op_id) - await self.on_operation_complete(connection_context, op_id) - - async def on_operation_complete(self, connection_context, op_id): - await self.send_message(connection_context, op_id, GQL_COMPLETE) - subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) From 5ed4f1d5f3a947ea5f707a0635677a6334865446 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 1 Jul 2020 09:28:16 +1200 Subject: [PATCH 48/72] Fix a readme typo --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 0a871f0..fb968b6 100644 --- a/README.rst +++ b/README.rst @@ -46,7 +46,7 @@ Create a subscribable schema like this: class Query(graphene.ObjectType): hello = graphene.String() - @static_method + @staticmethod def resolve_hello(obj, info, **kwargs): return "world" @@ -132,7 +132,7 @@ Create a subscribable schema like this: class Query(graphene.ObjectType): hello = graphene.String() - @static_method + @staticmethod def resolve_hello(obj, info, **kwargs): return "world" From a3197d0b2bc0140ebd4c17469aaa730a72e6872b Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 1 Jul 2020 16:34:33 +1200 Subject: [PATCH 49/72] Recursively resolve Promises, fix async tests --- graphql_ws/base_async.py | 55 ++++++++++++++++++++- tests/test_base_async.py | 102 +++++++++++++++++++++++++++------------ 2 files changed, 124 insertions(+), 33 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 8cdf31d..af9e4e4 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -1,9 +1,12 @@ import asyncio +import inspect from abc import ABC, abstractmethod -from inspect import isawaitable +from types import CoroutineType, GeneratorType +from typing import Any, Union, List, Dict from weakref import WeakSet from graphql.execution.executors.asyncio import AsyncioExecutor +from promise import Promise from graphql_ws import base @@ -11,6 +14,49 @@ from .observable_aiter import setup_observable_extension setup_observable_extension() +CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE + + +# Copied from graphql-core v3.1.0 (graphql/pyutils/is_awaitable.py) +def is_awaitable(value: Any) -> bool: + """Return true if object can be passed to an ``await`` expression. + Instead of testing if the object is an instance of abc.Awaitable, it checks + the existence of an `__await__` attribute. This is much faster. + """ + return ( + # check for coroutine objects + isinstance(value, CoroutineType) + # check for old-style generator based coroutine objects + or isinstance(value, GeneratorType) + and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE) + # check for other awaitables (e.g. futures) + or hasattr(value, "__await__") + ) + + +async def resolve( + data: Any, _container: Union[List, Dict] = None, _key: Union[str, int] = None +) -> None: + """ + Recursively wait on any awaitable children of a data element and resolve any + Promises. + """ + if is_awaitable(data): + data = await data + if isinstance(data, Promise): + data = data.value # type: Any + if _container is not None: + _container[_key] = data + if isinstance(data, dict): + items = data.items() + elif isinstance(data, list): + items = enumerate(data) + else: + items = None + if items is not None: + children = [resolve(child, _container=data, _key=key) for key, child in items] + if children: + await asyncio.wait(children) class BaseAsyncConnectionContext(base.BaseConnectionContext, ABC): @@ -81,7 +127,7 @@ async def on_connection_init(self, connection_context, op_id, payload): async def on_start(self, connection_context, op_id, params): execution_result = self.execute(params) - if isawaitable(execution_result): + if is_awaitable(execution_result): execution_result = await execution_result if hasattr(execution_result, "__aiter__"): @@ -120,3 +166,8 @@ async def on_stop(self, connection_context, op_id): async def on_operation_complete(self, connection_context, op_id): pass + + async def send_execution_result(self, connection_context, op_id, execution_result): + # Resolve any pending promises + await resolve(execution_result.data) + await super().send_execution_result(connection_context, op_id, execution_result) diff --git a/tests/test_base_async.py b/tests/test_base_async.py index 902acc7..d341c18 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -1,59 +1,99 @@ from unittest import mock import json +import promise import pytest -from graphql_ws import base +from graphql_ws import base, base_async +pytestmark = pytest.mark.asyncio -def test_not_implemented(): - server = base.BaseSubscriptionServer(schema=None) - with pytest.raises(NotImplementedError): - server.on_connection_init(connection_context=None, op_id=1, payload={}) - with pytest.raises(NotImplementedError): - server.on_open(connection_context=None) - with pytest.raises(NotImplementedError): - server.on_stop(connection_context=None, op_id=1) +class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) -def test_terminate(): - server = base.BaseSubscriptionServer(schema=None) - context = mock.Mock() - server.on_connection_terminate(connection_context=context, op_id=1) +class TestServer(base_async.BaseAsyncSubscriptionServer): + def handle(self, *args, **kwargs): + pass + + +@pytest.fixture +def server(): + return TestServer(schema=None) + + +async def test_terminate(server: TestServer): + context = AsyncMock() + await server.on_connection_terminate(connection_context=context, op_id=1) context.close.assert_called_with(1011) -def test_send_error(): - server = base.BaseSubscriptionServer(schema=None) - context = mock.Mock() - server.send_error(connection_context=context, op_id=1, error="test error") +async def test_send_error(server: TestServer): + context = AsyncMock() + await server.send_error(connection_context=context, op_id=1, error="test error") context.send.assert_called_with( {"id": 1, "type": "error", "payload": {"message": "test error"}} ) -def test_message(): - server = base.BaseSubscriptionServer(schema=None) - server.process_message = mock.Mock() - context = mock.Mock() +async def test_message(server): + server.process_message = AsyncMock() + context = AsyncMock() msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} - server.on_message(context, msg) + await server.on_message(context, msg) server.process_message.assert_called_with(context, msg) -def test_message_str(): - server = base.BaseSubscriptionServer(schema=None) - server.process_message = mock.Mock() - context = mock.Mock() +async def test_message_str(server): + server.process_message = AsyncMock() + context = AsyncMock() msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} - server.on_message(context, json.dumps(msg)) + await server.on_message(context, json.dumps(msg)) server.process_message.assert_called_with(context, msg) -def test_message_invalid(): - server = base.BaseSubscriptionServer(schema=None) - server.send_error = mock.Mock() - server.on_message(connection_context=None, message="'not-json") +async def test_message_invalid(server): + server.send_error = AsyncMock() + await server.on_message(connection_context=None, message="'not-json") assert server.send_error.called + + +async def test_resolver(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, 2]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + + +@pytest.mark.asyncio +async def test_resolver_with_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, 2]} + + +async def test_resolver_with_nested_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + inner = promise.Promise(lambda resolve, reject: resolve(2)) + outer = promise.Promise(lambda resolve, reject: resolve({'in': inner})) + result.data = {"test": [1, outer]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, {'in': 2}]} From b6da149ac3e4a346ccf52b66c80923583a65bac5 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 1 Jul 2020 16:35:41 +1200 Subject: [PATCH 50/72] Simplify django consumer now promises are resolved in the base --- graphql_ws/django/consumers.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 3373576..11c7d68 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -1,26 +1,11 @@ -import asyncio import json from channels.generic.websocket import AsyncJsonWebsocketConsumer -from promise import Promise from ..constants import WS_PROTOCOL from .subscriptions import subscription_server -class JSONPromiseEncoder(json.JSONEncoder): - def encode(self, *args, **kwargs): - self.pending_promises = [] - return super(JSONPromiseEncoder, self).encode(*args, **kwargs) - - def default(self, o): - if isinstance(o, Promise): - if o.is_pending: - self.pending_promises.append(o) - return o.value - return super(JSONPromiseEncoder, self).default(o) - - class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): async def connect(self): @@ -45,10 +30,4 @@ async def receive_json(self, content): @classmethod async def encode_json(cls, content): - json_promise_encoder = JSONPromiseEncoder() - e = json_promise_encoder.encode(content) - while json_promise_encoder.pending_promises: - # Wait for pending promises to complete, then try encoding again. - await asyncio.wait(json_promise_encoder.pending_promises) - e = json_promise_encoder.encode(content) - return e + return json.dumps(content) From 84d5d1749ba69b9bc72a0d9100e697a5839e817c Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 12:18:22 +1200 Subject: [PATCH 51/72] Ignore cancellederror when closing connections --- graphql_ws/base_async.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index af9e4e4..c4353d7 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -159,7 +159,10 @@ async def on_close(self, connection_context): for op_id in connection_context.operations ) + tuple(task.cancel() for task in connection_context.pending_tasks) if awaitables: - await asyncio.gather(*awaitables, loop=self.loop) + try: + await asyncio.gather(*awaitables, loop=self.loop) + except asyncio.CancelledError: + pass async def on_stop(self, connection_context, op_id): await self.unsubscribe(connection_context, op_id) From a9e63beaae25aab9840c7d90c36669c0412c8d53 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 15:04:55 +1200 Subject: [PATCH 52/72] Add the required receive method to ChannelsConnectionContext --- graphql_ws/django/subscriptions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 0cca653..086445f 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -22,6 +22,12 @@ def closed(self): async def close(self, code): await self.ws.close(code=code) + async def receive(self, code): + """ + Unused, as the django consumer handles receiving messages and passes + them straight to ChannelsSubscriptionServer.on_message. + """ + class ChannelsSubscriptionServer(BaseAsyncSubscriptionServer): async def handle(self, ws, request_context=None): From 7bfc59094f94dc428752600402ef6d31aaa838bd Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 15:41:46 +1200 Subject: [PATCH 53/72] Fix async processing messages --- graphql_ws/aiohttp.py | 4 +--- graphql_ws/base_async.py | 8 ++++---- graphql_ws/websockets_lib.py | 4 +--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index d2162f2..baf8837 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -44,9 +44,7 @@ async def _handle(self, ws, request_context=None): except ConnectionClosedException: break - connection_context.remember_task( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message) await self.on_close(connection_context) async def handle(self, ws, request_context=None): diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index c4353d7..d02cc29 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -81,8 +81,8 @@ def closed(self): async def close(self, code): ... - def remember_task(self, task, loop=None): - self.pending_tasks.add(asyncio.ensure_future(task, loop=loop)) + def remember_task(self, task): + self.pending_tasks.add(task) # Clear completed tasks self.pending_tasks -= WeakSet( task for task in self.pending_tasks if task.done() @@ -102,9 +102,9 @@ async def handle(self, ws, request_context=None): def process_message(self, connection_context, parsed_message): task = asyncio.ensure_future( - super().process_message(connection_context, parsed_message) + super().process_message(connection_context, parsed_message), loop=self.loop ) - connection_context.pending.add(task) + connection_context.remember_task(task) return task async def send_message(self, *args, **kwargs): diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 4d753a5..c0adc67 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -40,9 +40,7 @@ async def _handle(self, ws, request_context): except ConnectionClosedException: break - connection_context.remember_task( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message) await self.on_close(connection_context) async def handle(self, ws, request_context=None): From 0cbcdef0b022dc5094c3776155bfb3f5d1bde6a6 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 15:38:22 +1200 Subject: [PATCH 54/72] Simpler receive_json --- graphql_ws/django/consumers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index 11c7d68..b1c64d1 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -24,9 +24,7 @@ async def disconnect(self, code): await subscription_server.on_close(self.connection_context) async def receive_json(self, content): - self.connection_context.remember_task( - subscription_server.on_message(self.connection_context, content) - ) + subscription_server.on_message(self.connection_context, content) @classmethod async def encode_json(cls, content): From 583f3f0bced9edc3257904fe66f8d2609171fa44 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 16:54:07 +1200 Subject: [PATCH 55/72] Fix async unsubscribe --- graphql_ws/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index db4f675..798d19d 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -178,4 +178,4 @@ def unsubscribe(self, connection_context, op_id): connection_context.get_operation(op_id).dispose() # Close operation connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) + return self.on_operation_complete(connection_context, op_id) From de8ced3ab190d89237e08402193ff3b9baee63e2 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 10:30:56 +1200 Subject: [PATCH 56/72] Move unsubscribe logic to the connection context --- graphql_ws/base.py | 32 +++++++++++++++----------------- graphql_ws/base_async.py | 32 ++++++++++++++++++-------------- graphql_ws/base_sync.py | 11 +++-------- tests/test_base.py | 9 +++++++-- tests/test_graphql_ws.py | 7 ++++--- 5 files changed, 47 insertions(+), 44 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 798d19d..35ee2fe 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -35,9 +35,18 @@ def get_operation(self, op_id): def remove_operation(self, op_id): try: - del self.operations[op_id] + return self.operations.pop(op_id) except KeyError: - pass + return + + def unsubscribe(self, op_id): + async_iterator = self.remove_operation(op_id) + if hasattr(async_iterator, 'dispose'): + async_iterator.dispose() + + def unsubscribe_all(self): + for op_id in list(self.operations): + self.unsubscribe(op_id) def receive(self): raise NotImplementedError("receive method not implemented") @@ -76,12 +85,6 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_START: assert isinstance(payload, dict), "The payload must be a dict" - - # If we already have a subscription with this id, unsubscribe from - # it first - if connection_context.has_operation(op_id): - self.unsubscribe(connection_context, op_id) - params = self.get_graphql_params(connection_context, payload) return self.on_start(connection_context, op_id, params) @@ -116,7 +119,10 @@ def on_open(self, connection_context): raise NotImplementedError("on_open method not implemented") def on_stop(self, connection_context, op_id): - raise NotImplementedError("on_stop method not implemented") + return connection_context.unsubscribe(op_id) + + def on_close(self, connection_context): + return connection_context.unsubscribe_all() def send_message(self, connection_context, op_id=None, op_type=None, payload=None): message = self.build_message(op_id, op_type, payload) @@ -171,11 +177,3 @@ def on_message(self, connection_context, message): return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) - - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - return self.on_operation_complete(connection_context, op_id) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index d02cc29..6cedc67 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -88,6 +88,20 @@ def remember_task(self, task): task for task in self.pending_tasks if task.done() ) + async def unsubscribe(self, op_id): + super().unsubscribe(op_id) + + async def unsubscribe_all(self): + awaitables = [self.unsubscribe(op_id) for op_id in list(self.operations)] + for task in self.pending_tasks: + task.cancel() + awaitables.append(task) + if awaitables: + try: + await asyncio.gather(*awaitables) + except asyncio.CancelledError: + pass + class BaseAsyncSubscriptionServer(base.BaseSubscriptionServer, ABC): graphql_executor = AsyncioExecutor @@ -125,6 +139,10 @@ async def on_connection_init(self, connection_context, op_id, payload): await connection_context.close(1011) async def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + await connection_context.unsubscribe(op_id) + execution_result = self.execute(params) if is_awaitable(execution_result): @@ -153,20 +171,6 @@ async def on_start(self, connection_context, op_id, params): await self.send_message(connection_context, op_id, GQL_COMPLETE) await self.on_operation_complete(connection_context, op_id) - async def on_close(self, connection_context): - awaitables = tuple( - self.unsubscribe(connection_context, op_id) - for op_id in connection_context.operations - ) + tuple(task.cancel() for task in connection_context.pending_tasks) - if awaitables: - try: - await asyncio.gather(*awaitables, loop=self.loop) - except asyncio.CancelledError: - pass - - async def on_stop(self, connection_context, op_id): - await self.unsubscribe(connection_context, op_id) - async def on_operation_complete(self, connection_context, op_id): pass diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index a6d2efb..06db900 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -20,11 +20,6 @@ def on_open(self, connection_context): def on_connect(self, connection_context, payload): pass - def on_close(self, connection_context): - remove_operations = list(connection_context.operations) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - def on_connection_init(self, connection_context, op_id, payload): try: self.on_connect(connection_context, payload) @@ -34,10 +29,10 @@ def on_connection_init(self, connection_context, op_id, payload): self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) connection_context.close(1011) - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + connection_context.unsubscribe(op_id) try: execution_result = self.execute(params) assert isinstance( diff --git a/tests/test_base.py b/tests/test_base.py index 80de021..5b40ac5 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -16,8 +16,13 @@ def test_not_implemented(): server.on_connection_init(connection_context=None, op_id=1, payload={}) with pytest.raises(NotImplementedError): server.on_open(connection_context=None) - with pytest.raises(NotImplementedError): - server.on_stop(connection_context=None, op_id=1) + + +def test_on_stop(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.on_stop(connection_context=context, op_id=1) + context.unsubscribe.assert_called_with(1) def test_terminate(): diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 4a7b845..e29e2a2 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -94,12 +94,12 @@ def test_start_existing_op(self, ss, cc): ss.get_graphql_params.return_value = {"params": True} cc.has_operation = mock.Mock() cc.has_operation.return_value = True - ss.unsubscribe = mock.Mock() + cc.unsubscribe = mock.Mock() ss.on_start = mock.Mock() ss.process_message( cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} ) - assert ss.unsubscribe.called + assert cc.unsubscribe.called ss.on_start.assert_called_with(cc, "1", {"params": True}) def test_start_bad_graphql_params(self, ss, cc): @@ -162,7 +162,8 @@ def test_build_message_partial(ss): assert ss.build_message(id=None, op_type=None, payload="PAYLOAD") == { "payload": "PAYLOAD" } - assert ss.build_message(id=None, op_type=None, payload=None) == {} + with pytest.raises(AssertionError): + ss.build_message(id=None, op_type=None, payload=None) def test_send_execution_result(ss): From 9bec86e8a6d016d27ade83d7601ae75722005829 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 10:31:51 +1200 Subject: [PATCH 57/72] Remove a redundant async method --- graphql_ws/base_async.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 6cedc67..0d57c42 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -121,9 +121,6 @@ def process_message(self, connection_context, parsed_message): connection_context.remember_task(task) return task - async def send_message(self, *args, **kwargs): - await super().send_message(*args, **kwargs) - async def on_open(self, connection_context): pass From a1d2ebc203f15e98bad234137c324b3b0f5d646c Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 14:21:08 +1200 Subject: [PATCH 58/72] Only send messages for operations that exist --- graphql_ws/base.py | 7 ++++--- graphql_ws/base_async.py | 10 +++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 35ee2fe..1ed2da1 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -125,9 +125,9 @@ def on_close(self, connection_context): return connection_context.unsubscribe_all() def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = self.build_message(op_id, op_type, payload) - assert message, "You need to send at least one thing" - return connection_context.send(message) + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return connection_context.send(message) def build_message(self, id, op_type, payload): message = {} @@ -137,6 +137,7 @@ def build_message(self, id, op_type, payload): message["type"] = op_type if payload is not None: message["payload"] = payload + assert message, "You need to send at least one thing" return message def send_execution_result(self, connection_context, op_id, execution_result): diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 0d57c42..735818d 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -142,6 +142,7 @@ async def on_start(self, connection_context, op_id, params): execution_result = self.execute(params) + connection_context.register_operation(op_id, execution_result) if is_awaitable(execution_result): execution_result = await execution_result @@ -157,7 +158,6 @@ async def on_start(self, connection_context, op_id, params): ) except Exception as e: await self.send_error(connection_context, op_id, e) - connection_context.remove_operation(op_id) else: try: await self.send_execution_result( @@ -166,8 +166,16 @@ async def on_start(self, connection_context, op_id, params): except Exception as e: await self.send_error(connection_context, op_id, e) await self.send_message(connection_context, op_id, GQL_COMPLETE) + connection_context.remove_operation(op_id) await self.on_operation_complete(connection_context, op_id) + async def send_message( + self, connection_context, op_id=None, op_type=None, payload=None + ): + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return await connection_context.send(message) + async def on_operation_complete(self, connection_context, op_id): pass From 94d874027edceb8ae56b80907db07d25db705cca Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 16:52:27 +1200 Subject: [PATCH 59/72] Iterators are considered awaitable with the new method, so check only not aiter --- graphql_ws/base_async.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 735818d..7f7e74f 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -143,9 +143,6 @@ async def on_start(self, connection_context, op_id, params): execution_result = self.execute(params) connection_context.register_operation(op_id, execution_result) - if is_awaitable(execution_result): - execution_result = await execution_result - if hasattr(execution_result, "__aiter__"): iterator = await execution_result.__aiter__() connection_context.register_operation(op_id, iterator) @@ -160,6 +157,8 @@ async def on_start(self, connection_context, op_id, params): await self.send_error(connection_context, op_id, e) else: try: + if is_awaitable(execution_result): + execution_result = await execution_result await self.send_execution_result( connection_context, op_id, execution_result ) From 218c7fc5e26ed671f1f56f9aed0548d9026ae637 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 17:35:56 +1200 Subject: [PATCH 60/72] Add request context directly to the payload rather than a request_context key --- graphql_ws/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 1ed2da1..4df2fab 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -105,8 +105,7 @@ def on_connection_terminate(self, connection_context, op_id): return connection_context.close(1011) def get_graphql_params(self, connection_context, payload): - context = payload.get("context") or {} - context.setdefault("request_context", connection_context.request_context) + context = payload.get("context", connection_context.request_context) return { "request_string": payload.get("query"), "variable_values": payload.get("variables"), From ae0b0c7c9124a550c50db23781b0dc590beaec74 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 24 Nov 2020 16:54:44 +1300 Subject: [PATCH 61/72] Correctly unsubscribe after on_start operation is complete --- graphql_ws/base_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 7f7e74f..6954341 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -165,7 +165,7 @@ async def on_start(self, connection_context, op_id, params): except Exception as e: await self.send_error(connection_context, op_id, e) await self.send_message(connection_context, op_id, GQL_COMPLETE) - connection_context.remove_operation(op_id) + await connection_context.unsubscribe(op_id) await self.on_operation_complete(connection_context, op_id) async def send_message( From a964800472035943f1f965b1ca341d1cda147879 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 25 Nov 2020 00:23:24 +1300 Subject: [PATCH 62/72] Fix tests --- tests/test_graphql_ws.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index e29e2a2..3b85c49 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -1,4 +1,5 @@ from collections import OrderedDict + try: from unittest import mock except ImportError: @@ -95,12 +96,12 @@ def test_start_existing_op(self, ss, cc): cc.has_operation = mock.Mock() cc.has_operation.return_value = True cc.unsubscribe = mock.Mock() - ss.on_start = mock.Mock() + ss.execute = mock.Mock() + ss.send_message = mock.Mock() ss.process_message( cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} ) assert cc.unsubscribe.called - ss.on_start.assert_called_with(cc, "1", {"params": True}) def test_start_bad_graphql_params(self, ss, cc): ss.get_graphql_params = mock.Mock() @@ -110,9 +111,7 @@ def test_start_bad_graphql_params(self, ss, cc): ss.send_error = mock.Mock() ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() - ss.process_message( - cc, {"id": "1", "type": None, "payload": {"a": "b"}} - ) + ss.process_message(cc, {"id": "1", "type": None, "payload": {"a": "b"}}) assert ss.send_error.called assert ss.send_error.call_args[0][:2] == (cc, "1") assert isinstance(ss.send_error.call_args[0][2], Exception) @@ -144,7 +143,7 @@ def test_get_graphql_params(ss, cc): "request_string": "req", "variable_values": "vars", "operation_name": "query", - "context_value": {'request_context': None}, + "context_value": {}, } From cdbdda1744f949a656333dcbaa5fb8d9ed3442a6 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Mar 2021 17:25:59 +1300 Subject: [PATCH 63/72] Async unsubscription needs to wait around for the future to cancel --- graphql_ws/base.py | 1 + graphql_ws/base_async.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 4df2fab..31ad657 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -43,6 +43,7 @@ def unsubscribe(self, op_id): async_iterator = self.remove_operation(op_id) if hasattr(async_iterator, 'dispose'): async_iterator.dispose() + return async_iterator def unsubscribe_all(self): for op_id in list(self.operations): diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 6954341..0c62481 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -2,7 +2,7 @@ import inspect from abc import ABC, abstractmethod from types import CoroutineType, GeneratorType -from typing import Any, Union, List, Dict +from typing import Any, Dict, List, Union from weakref import WeakSet from graphql.execution.executors.asyncio import AsyncioExecutor @@ -89,7 +89,12 @@ def remember_task(self, task): ) async def unsubscribe(self, op_id): - super().unsubscribe(op_id) + async_iterator = super().unsubscribe(op_id) + if ( + getattr(async_iterator, "future", None) + and async_iterator.future.cancel() + ): + await async_iterator.future async def unsubscribe_all(self): awaitables = [self.unsubscribe(op_id) for op_id in list(self.operations)] From 45546366581b6c31b33b917cd5d805a758fe54a4 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:12:12 +1300 Subject: [PATCH 64/72] Allow collection of tests even if aiohttp isn't installed --- tests/test_aiohttp.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 88a48d1..40c43fd 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,15 +1,22 @@ +try: + from aiohttp import WSMsgType + from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer +except ImportError: # pragma: no cover + WSMsgType = None + from unittest import mock import pytest -from aiohttp import WSMsgType -from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer from graphql_ws.base import ConnectionClosedException +if_aiohttp_installed = pytest.mark.skipif( + WSMsgType is None, reason="aiohttp is not installed" +) + class AsyncMock(mock.Mock): def __call__(self, *args, **kwargs): - async def coro(): return super(AsyncMock, self).__call__(*args, **kwargs) @@ -24,6 +31,7 @@ def mock_ws(): return ws +@if_aiohttp_installed @pytest.mark.asyncio class TestConnectionContext: async def test_receive_good_data(self, mock_ws): @@ -69,5 +77,6 @@ async def test_close(self, mock_ws): mock_ws.close.assert_called_with(code=123) +@if_aiohttp_installed def test_subscription_server_smoke(): AiohttpSubscriptionServer(schema=None) From 56f46a1f1c58050af9844c3db865e03c01229745 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:13:39 +1300 Subject: [PATCH 65/72] Make the python 2 async observer send graphql error for exceptions explicitly returned --- graphql_ws/base_sync.py | 6 +++++- tests/test_base.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 06db900..f6b6c68 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -65,7 +65,11 @@ def __init__( self.send_message = send_message def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) + if isinstance(value, Exception): + send_method = self.send_error + else: + send_method = self.send_execution_result + send_method(self.connection_context, self.op_id, value) def on_completed(self): self.send_message(self.connection_context, self.op_id, GQL_COMPLETE) diff --git a/tests/test_base.py b/tests/test_base.py index 5b40ac5..1ce6300 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -8,6 +8,7 @@ import pytest from graphql_ws import base +from graphql_ws.base_sync import SubscriptionObserver def test_not_implemented(): @@ -77,3 +78,35 @@ def test_context_operations(): assert not context.has_operation(1) # Removing a non-existant operation fails silently. context.remove_operation(999) + + +def test_observer_data(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next('data') + assert send_result.called + assert not send_error.called + + +def test_observer_exception(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next(TypeError('some bad message')) + assert send_error.called + assert not send_result.called From 5abd858a813ecb09834b92c22184e3e466aa988d Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:15:27 +1300 Subject: [PATCH 66/72] asyncio.wait receiving coroutines is deprecated, create tasks explicitly --- graphql_ws/base_async.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 0c62481..bc98dc5 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -54,7 +54,10 @@ async def resolve( else: items = None if items is not None: - children = [resolve(child, _container=data, _key=key) for key, child in items] + children = [ + asyncio.create_task(resolve(child, _container=data, _key=key)) + for key, child in items + ] if children: await asyncio.wait(children) @@ -90,10 +93,7 @@ def remember_task(self, task): async def unsubscribe(self, op_id): async_iterator = super().unsubscribe(op_id) - if ( - getattr(async_iterator, "future", None) - and async_iterator.future.cancel() - ): + if getattr(async_iterator, "future", None) and async_iterator.future.cancel(): await async_iterator.future async def unsubscribe_all(self): From d2d55a12382e0fc3c938c77fc9fcbd8e72c8ddb7 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:19:52 +1300 Subject: [PATCH 67/72] Tidy up a test warning --- tests/test_base_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_base_async.py b/tests/test_base_async.py index d341c18..d1a952b 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -33,6 +33,7 @@ async def test_terminate(server: TestServer): async def test_send_error(server: TestServer): context = AsyncMock() + context.has_operation = mock.Mock() await server.send_error(connection_context=context, op_id=1, error="test error") context.send.assert_called_with( {"id": 1, "type": "error", "payload": {"message": "test error"}} From f7cb773fdb03b47c19172aa2b5f38ac62a4e5b76 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:20:45 +1300 Subject: [PATCH 68/72] Rename TestServer to avoid it being collected by pytest --- tests/test_base_async.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_base_async.py b/tests/test_base_async.py index d1a952b..d62eda5 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -15,23 +15,23 @@ async def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) -class TestServer(base_async.BaseAsyncSubscriptionServer): +class TstServer(base_async.BaseAsyncSubscriptionServer): def handle(self, *args, **kwargs): - pass + pass # pragma: no cover @pytest.fixture def server(): - return TestServer(schema=None) + return TstServer(schema=None) -async def test_terminate(server: TestServer): +async def test_terminate(server: TstServer): context = AsyncMock() await server.on_connection_terminate(connection_context=context, op_id=1) context.close.assert_called_with(1011) -async def test_send_error(server: TestServer): +async def test_send_error(server: TstServer): context = AsyncMock() context.has_operation = mock.Mock() await server.send_error(connection_context=context, op_id=1, error="test error") From 80890c32124037b533e4a043a899dd60fbae419d Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:21:46 +1300 Subject: [PATCH 69/72] Update test matrix --- setup.cfg | 7 ++----- tests/conftest.py | 2 -- tox.ini | 10 +++++----- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/setup.cfg b/setup.cfg index 1e7ea2a..1e85964 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,15 +15,12 @@ classifiers = License :: OSI Approved :: MIT License Natural Language :: English Programming Language :: Python :: 2 - Programming Language :: Python :: 2.6 Programming Language :: Python :: 2.7 Programming Language :: Python :: 3 - Programming Language :: Python :: 3.3 - Programming Language :: Python :: 3.4 - Programming Language :: Python :: 3.5 Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 [options] zip_safe = False @@ -94,4 +91,4 @@ omit = [coverage:report] exclude_lines = pragma: no cover - @abstract \ No newline at end of file + @abstract diff --git a/tests/conftest.py b/tests/conftest.py index fa905b4..595968a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,5 @@ if sys.version_info > (3,): collect_ignore = ["test_django_channels.py"] - if sys.version_info < (3, 6): - collect_ignore.append("test_gevent.py") else: collect_ignore = ["test_aiohttp.py", "test_base_async.py"] diff --git a/tox.ini b/tox.ini index 42d13b4..62e2f8b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,15 +1,15 @@ [tox] -envlist = +envlist = coverage_setup - py27, py35, py36, py37, py38, flake8 + py27, py36, py37, py38, py39, flake8 coverage_report [travis] python = - 3.8: py38, flake8 + 3.9: py39, flake8 + 3.8: py38 3.7: py37 3.6: py36 - 3.5: py35 2.7: py27 [testenv] @@ -33,4 +33,4 @@ commands = coverage html coverage xml coverage report --include="tests/*" --fail-under=100 -m - coverage report --omit="tests/*" # --fail-under=90 -m \ No newline at end of file + coverage report --omit="tests/*" # --fail-under=90 -m From 9ced6094a3b2af37695711d2f25316b4e1926bb1 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:35:30 +1300 Subject: [PATCH 70/72] Update travis python versions --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index a3ef963..5104cdc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,9 +11,9 @@ deploy: install: pip install -U tox-travis language: python python: +- 3.9 - 3.8 - 3.7 - 3.6 -- 3.5 - 2.7 script: tox From 3adfaa9ce13052c4c96ed5d689ec8e82dec76cb1 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 10:56:14 +1300 Subject: [PATCH 71/72] Try using a newer travis dist to fix cryptography building issues --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 5104cdc..67a356c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ deploy: tags: true install: pip install -U tox-travis language: python +dist: focal python: - 3.9 - 3.8 From 703e4074573b2dca068f2eb36a25a06808ec8698 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 11:18:45 +1300 Subject: [PATCH 72/72] Use python 3.6 friendly asyncio method --- graphql_ws/base_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index bc98dc5..a21ca5e 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -55,7 +55,7 @@ async def resolve( items = None if items is not None: children = [ - asyncio.create_task(resolve(child, _container=data, _key=key)) + asyncio.ensure_future(resolve(child, _container=data, _key=key)) for key, child in items ] if children: