Skip to content

Commit 4803d6c

Browse files
committed
Added subscription support
1 parent f8a6a8a commit 4803d6c

File tree

4 files changed

+400
-12
lines changed

4 files changed

+400
-12
lines changed

graphql/execution/base.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
3232
for definition in document_ast.definitions:
3333
if isinstance(definition, ast.OperationDefinition):
3434
if not operation_name and operation:
35-
raise GraphQLError('Must provide operation name if query contains multiple operations.')
35+
raise GraphQLError(
36+
'Must provide operation name if query contains multiple operations.')
3637

3738
if not operation_name or definition.name and definition.name.value == operation_name:
3839
operation = definition
@@ -42,18 +43,21 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
4243

4344
else:
4445
raise GraphQLError(
45-
u'GraphQL cannot execute a request containing a {}.'.format(definition.__class__.__name__),
46+
u'GraphQL cannot execute a request containing a {}.'.format(
47+
definition.__class__.__name__),
4648
definition
4749
)
4850

4951
if not operation:
5052
if operation_name:
51-
raise GraphQLError(u'Unknown operation named "{}".'.format(operation_name))
53+
raise GraphQLError(
54+
u'Unknown operation named "{}".'.format(operation_name))
5255

5356
else:
5457
raise GraphQLError('Must provide an operation.')
5558

56-
variable_values = get_variable_values(schema, operation.variable_definitions or [], variable_values)
59+
variable_values = get_variable_values(
60+
schema, operation.variable_definitions or [], variable_values)
5761

5862
self.schema = schema
5963
self.fragments = fragments
@@ -82,7 +86,8 @@ def get_argument_values(self, field_def, field_ast):
8286
return result
8387

8488
def report_error(self, error, traceback=None):
85-
sys.excepthook(type(error), error, getattr(error, 'stack', None) or traceback)
89+
sys.excepthook(type(error), error, getattr(
90+
error, 'stack', None) or traceback)
8691
self.errors.append(error)
8792

8893
def get_sub_fields(self, return_type, field_asts):
@@ -101,6 +106,20 @@ def get_sub_fields(self, return_type, field_asts):
101106
return self._subfields_cache[k]
102107

103108

109+
class SubscriberExecutionContext(object):
110+
__slots__ = 'exe_context', 'errors'
111+
112+
def __init__(self, exe_context):
113+
self.exe_context = exe_context
114+
self.errors = []
115+
116+
def reset(self):
117+
self.errors = []
118+
119+
def __getattr__(self, name):
120+
return getattr(self.exe_context, name)
121+
122+
104123
class ExecutionResult(object):
105124
"""The result of execution. `data` is the result of executing the
106125
query, `errors` is null if no errors occurred, and is a
@@ -186,7 +205,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names
186205
ctx, selection, runtime_type):
187206
continue
188207

189-
collect_fields(ctx, runtime_type, selection.selection_set, fields, prev_fragment_names)
208+
collect_fields(ctx, runtime_type,
209+
selection.selection_set, fields, prev_fragment_names)
190210

191211
elif isinstance(selection, ast.FragmentSpread):
192212
frag_name = selection.name.value
@@ -202,7 +222,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names
202222
does_fragment_condition_match(ctx, fragment, runtime_type):
203223
continue
204224

205-
collect_fields(ctx, runtime_type, fragment.selection_set, fields, prev_fragment_names)
225+
collect_fields(ctx, runtime_type,
226+
fragment.selection_set, fields, prev_fragment_names)
206227

207228
return fields
208229

graphql/execution/executor.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import functools
33
import logging
44
import sys
5+
from rx import Observable
56

67
from six import string_types
78
from promise import Promise, promise_for_dict, is_thenable
@@ -15,7 +16,7 @@
1516
GraphQLSchema, GraphQLUnionType)
1617
from .base import (ExecutionContext, ExecutionResult, ResolveInfo,
1718
collect_fields, default_resolve_fn, get_field_def,
18-
get_operation_root_type)
19+
get_operation_root_type, SubscriberExecutionContext)
1920
from .executors.sync import SyncExecutor
2021
from .middleware import MiddlewareManager
2122

@@ -61,6 +62,9 @@ def on_rejected(error):
6162
return None
6263

6364
def on_resolve(data):
65+
if isinstance(data, Observable):
66+
return data
67+
6468
if not context.errors:
6569
return ExecutionResult(data=data)
6670
return ExecutionResult(data=data, errors=context.errors)
@@ -88,6 +92,9 @@ def execute_operation(exe_context, operation, root_value):
8892
if operation.operation == 'mutation':
8993
return execute_fields_serially(exe_context, type, root_value, fields)
9094

95+
if operation.operation == 'subscription':
96+
return subscribe_fields(exe_context, type, root_value, fields)
97+
9198
return execute_fields(exe_context, type, root_value, fields)
9299

93100

@@ -140,6 +147,39 @@ def execute_fields(exe_context, parent_type, source_value, fields):
140147
return promise_for_dict(final_results)
141148

142149

150+
def subscribe_fields(exe_context, parent_type, source_value, fields):
151+
exe_context = SubscriberExecutionContext(exe_context)
152+
153+
def on_error(error):
154+
exe_context.report_error(error)
155+
156+
def map_result(data):
157+
if exe_context.errors:
158+
result = ExecutionResult(data=data, errors=exe_context.errors)
159+
else:
160+
result = ExecutionResult(data=data)
161+
exe_context.reset()
162+
return result
163+
164+
observables = []
165+
166+
# assert len(fields) == 1, "Can only subscribe one element at a time."
167+
168+
for response_name, field_asts in fields.items():
169+
170+
result = subscribe_field(exe_context, parent_type,
171+
source_value, field_asts)
172+
if result is Undefined:
173+
continue
174+
175+
# Map observable results
176+
observable = result.map(lambda data: map_result({response_name: data}))
177+
return observable
178+
observables.append(observable)
179+
180+
return Observable.merge(observables)
181+
182+
143183
def resolve_field(exe_context, parent_type, source, field_asts):
144184
field_ast = field_asts[0]
145185
field_name = field_ast.name.value
@@ -191,6 +231,64 @@ def resolve_field(exe_context, parent_type, source, field_asts):
191231
)
192232

193233

234+
def subscribe_field(exe_context, parent_type, source, field_asts):
235+
field_ast = field_asts[0]
236+
field_name = field_ast.name.value
237+
238+
field_def = get_field_def(exe_context.schema, parent_type, field_name)
239+
if not field_def:
240+
return Undefined
241+
242+
return_type = field_def.type
243+
resolve_fn = field_def.resolver or default_resolve_fn
244+
245+
# We wrap the resolve_fn from the middleware
246+
resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn)
247+
248+
# Build a dict of arguments from the field.arguments AST, using the variables scope to
249+
# fulfill any variable references.
250+
args = exe_context.get_argument_values(field_def, field_ast)
251+
252+
# The resolve function's optional third argument is a context value that
253+
# is provided to every resolve function within an execution. It is commonly
254+
# used to represent an authenticated user, or request-specific caches.
255+
context = exe_context.context_value
256+
257+
# The resolve function's optional third argument is a collection of
258+
# information about the current execution state.
259+
info = ResolveInfo(
260+
field_name,
261+
field_asts,
262+
return_type,
263+
parent_type,
264+
schema=exe_context.schema,
265+
fragments=exe_context.fragments,
266+
root_value=exe_context.root_value,
267+
operation=exe_context.operation,
268+
variable_values=exe_context.variable_values,
269+
context=context
270+
)
271+
272+
executor = exe_context.executor
273+
result = resolve_or_error(resolve_fn_middleware,
274+
source, info, args, executor)
275+
276+
if isinstance(result, Exception):
277+
raise result
278+
279+
if not isinstance(result, Observable):
280+
raise GraphQLError(
281+
'Subscription must return Async Iterable or Observable. Received: {}'.format(repr(result)))
282+
283+
return result.map(functools.partial(
284+
complete_value_catching_error,
285+
exe_context,
286+
return_type,
287+
field_asts,
288+
info,
289+
))
290+
291+
194292
def resolve_or_error(resolve_fn, source, info, args, executor):
195293
try:
196294
return executor.execute(resolve_fn, source, info, **args)

0 commit comments

Comments
 (0)