|
2 | 2 | import functools
|
3 | 3 | import logging
|
4 | 4 | import sys
|
| 5 | +from rx import Observable |
5 | 6 |
|
6 | 7 | from six import string_types
|
7 | 8 | from promise import Promise, promise_for_dict, is_thenable
|
|
15 | 16 | GraphQLSchema, GraphQLUnionType)
|
16 | 17 | from .base import (ExecutionContext, ExecutionResult, ResolveInfo,
|
17 | 18 | collect_fields, default_resolve_fn, get_field_def,
|
18 |
| - get_operation_root_type) |
| 19 | + get_operation_root_type, SubscriberExecutionContext) |
19 | 20 | from .executors.sync import SyncExecutor
|
20 | 21 | from .middleware import MiddlewareManager
|
21 | 22 |
|
@@ -61,6 +62,9 @@ def on_rejected(error):
|
61 | 62 | return None
|
62 | 63 |
|
63 | 64 | def on_resolve(data):
|
| 65 | + if isinstance(data, Observable): |
| 66 | + return data |
| 67 | + |
64 | 68 | if not context.errors:
|
65 | 69 | return ExecutionResult(data=data)
|
66 | 70 | return ExecutionResult(data=data, errors=context.errors)
|
@@ -88,6 +92,9 @@ def execute_operation(exe_context, operation, root_value):
|
88 | 92 | if operation.operation == 'mutation':
|
89 | 93 | return execute_fields_serially(exe_context, type, root_value, fields)
|
90 | 94 |
|
| 95 | + if operation.operation == 'subscription': |
| 96 | + return subscribe_fields(exe_context, type, root_value, fields) |
| 97 | + |
91 | 98 | return execute_fields(exe_context, type, root_value, fields)
|
92 | 99 |
|
93 | 100 |
|
@@ -140,6 +147,39 @@ def execute_fields(exe_context, parent_type, source_value, fields):
|
140 | 147 | return promise_for_dict(final_results)
|
141 | 148 |
|
142 | 149 |
|
| 150 | +def subscribe_fields(exe_context, parent_type, source_value, fields): |
| 151 | + exe_context = SubscriberExecutionContext(exe_context) |
| 152 | + |
| 153 | + def on_error(error): |
| 154 | + exe_context.report_error(error) |
| 155 | + |
| 156 | + def map_result(data): |
| 157 | + if exe_context.errors: |
| 158 | + result = ExecutionResult(data=data, errors=exe_context.errors) |
| 159 | + else: |
| 160 | + result = ExecutionResult(data=data) |
| 161 | + exe_context.reset() |
| 162 | + return result |
| 163 | + |
| 164 | + observables = [] |
| 165 | + |
| 166 | + # assert len(fields) == 1, "Can only subscribe one element at a time." |
| 167 | + |
| 168 | + for response_name, field_asts in fields.items(): |
| 169 | + |
| 170 | + result = subscribe_field(exe_context, parent_type, |
| 171 | + source_value, field_asts) |
| 172 | + if result is Undefined: |
| 173 | + continue |
| 174 | + |
| 175 | + # Map observable results |
| 176 | + observable = result.map(lambda data: map_result({response_name: data})) |
| 177 | + return observable |
| 178 | + observables.append(observable) |
| 179 | + |
| 180 | + return Observable.merge(observables) |
| 181 | + |
| 182 | + |
143 | 183 | def resolve_field(exe_context, parent_type, source, field_asts):
|
144 | 184 | field_ast = field_asts[0]
|
145 | 185 | field_name = field_ast.name.value
|
@@ -191,6 +231,64 @@ def resolve_field(exe_context, parent_type, source, field_asts):
|
191 | 231 | )
|
192 | 232 |
|
193 | 233 |
|
| 234 | +def subscribe_field(exe_context, parent_type, source, field_asts): |
| 235 | + field_ast = field_asts[0] |
| 236 | + field_name = field_ast.name.value |
| 237 | + |
| 238 | + field_def = get_field_def(exe_context.schema, parent_type, field_name) |
| 239 | + if not field_def: |
| 240 | + return Undefined |
| 241 | + |
| 242 | + return_type = field_def.type |
| 243 | + resolve_fn = field_def.resolver or default_resolve_fn |
| 244 | + |
| 245 | + # We wrap the resolve_fn from the middleware |
| 246 | + resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn) |
| 247 | + |
| 248 | + # Build a dict of arguments from the field.arguments AST, using the variables scope to |
| 249 | + # fulfill any variable references. |
| 250 | + args = exe_context.get_argument_values(field_def, field_ast) |
| 251 | + |
| 252 | + # The resolve function's optional third argument is a context value that |
| 253 | + # is provided to every resolve function within an execution. It is commonly |
| 254 | + # used to represent an authenticated user, or request-specific caches. |
| 255 | + context = exe_context.context_value |
| 256 | + |
| 257 | + # The resolve function's optional third argument is a collection of |
| 258 | + # information about the current execution state. |
| 259 | + info = ResolveInfo( |
| 260 | + field_name, |
| 261 | + field_asts, |
| 262 | + return_type, |
| 263 | + parent_type, |
| 264 | + schema=exe_context.schema, |
| 265 | + fragments=exe_context.fragments, |
| 266 | + root_value=exe_context.root_value, |
| 267 | + operation=exe_context.operation, |
| 268 | + variable_values=exe_context.variable_values, |
| 269 | + context=context |
| 270 | + ) |
| 271 | + |
| 272 | + executor = exe_context.executor |
| 273 | + result = resolve_or_error(resolve_fn_middleware, |
| 274 | + source, info, args, executor) |
| 275 | + |
| 276 | + if isinstance(result, Exception): |
| 277 | + raise result |
| 278 | + |
| 279 | + if not isinstance(result, Observable): |
| 280 | + raise GraphQLError( |
| 281 | + 'Subscription must return Async Iterable or Observable. Received: {}'.format(repr(result))) |
| 282 | + |
| 283 | + return result.map(functools.partial( |
| 284 | + complete_value_catching_error, |
| 285 | + exe_context, |
| 286 | + return_type, |
| 287 | + field_asts, |
| 288 | + info, |
| 289 | + )) |
| 290 | + |
| 291 | + |
194 | 292 | def resolve_or_error(resolve_fn, source, info, args, executor):
|
195 | 293 | try:
|
196 | 294 | return executor.execute(resolve_fn, source, info, **args)
|
|
0 commit comments