diff --git a/.travis.yml b/.travis.yml index e7fb1a63..e8b55782 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,10 +2,7 @@ language: python sudo: false python: - 2.7 -- 3.4 -- 3.5 -- 3.6 -- "pypy-5.3.1" +# - "pypy-5.3.1" before_install: - | if [ "$TRAVIS_PYTHON_VERSION" = "pypy" ]; then @@ -22,7 +19,9 @@ before_install: fi install: - pip install -e .[test] +- pip install flake8 script: +- flake8 - py.test --cov=graphql graphql tests after_success: - coveralls @@ -33,10 +32,13 @@ matrix: - pip install pytest-asyncio script: - py.test --cov=graphql graphql tests tests_py35 - - python: '2.7' - install: pip install flake8 + - python: '3.6' + after_install: + - pip install pytest-asyncio script: - - flake8 + - py.test --cov=graphql graphql tests tests_py35 + - python: '2.7' + deploy: provider: pypi user: syrusakbary diff --git a/graphql/__init__.py b/graphql/__init__.py index 6d082ac1..7e7aae8e 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -120,6 +120,7 @@ # Execute GraphQL queries. from .execution import ( # no import order execute, + subscribe, ResolveInfo, MiddlewareManager, middlewares @@ -254,6 +255,7 @@ 'print_ast', 'visit', 'execute', + 'subscribe', 'ResolveInfo', 'MiddlewareManager', 'middlewares', diff --git a/graphql/execution/__init__.py b/graphql/execution/__init__.py index 99fde15d..a2de1f33 100644 --- a/graphql/execution/__init__.py +++ b/graphql/execution/__init__.py @@ -18,13 +18,14 @@ 2) fragment "spreads" e.g. "...c" 3) inline fragment "spreads" e.g. "...on Type { a }" """ -from .executor import execute +from .executor import execute, subscribe from .base import ExecutionResult, ResolveInfo from .middleware import middlewares, MiddlewareManager __all__ = [ 'execute', + 'subscribe', 'ExecutionResult', 'ResolveInfo', 'MiddlewareManager', diff --git a/graphql/execution/base.py b/graphql/execution/base.py index 055f438b..f088a16b 100644 --- a/graphql/execution/base.py +++ b/graphql/execution/base.py @@ -19,9 +19,9 @@ class ExecutionContext(object): and the fragments defined in the query document""" __slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \ - 'argument_values_cache', 'executor', 'middleware', '_subfields_cache' + 'argument_values_cache', 'executor', 'middleware', 'allow_subscriptions', '_subfields_cache' - def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware): + def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware, allow_subscriptions): """Constructs a ExecutionContext object from the arguments passed to execute, which we will pass throughout the other execution methods.""" @@ -32,7 +32,8 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val for definition in document_ast.definitions: if isinstance(definition, ast.OperationDefinition): if not operation_name and operation: - raise GraphQLError('Must provide operation name if query contains multiple operations.') + raise GraphQLError( + 'Must provide operation name if query contains multiple operations.') if not operation_name or definition.name and definition.name.value == operation_name: operation = definition @@ -42,18 +43,21 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val else: raise GraphQLError( - u'GraphQL cannot execute a request containing a {}.'.format(definition.__class__.__name__), + u'GraphQL cannot execute a request containing a {}.'.format( + definition.__class__.__name__), definition ) if not operation: if operation_name: - raise GraphQLError(u'Unknown operation named "{}".'.format(operation_name)) + raise GraphQLError( + u'Unknown operation named "{}".'.format(operation_name)) else: raise GraphQLError('Must provide an operation.') - variable_values = get_variable_values(schema, operation.variable_definitions or [], variable_values) + variable_values = get_variable_values( + schema, operation.variable_definitions or [], variable_values) self.schema = schema self.fragments = fragments @@ -65,6 +69,7 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val self.argument_values_cache = {} self.executor = executor self.middleware = middleware + self.allow_subscriptions = allow_subscriptions self._subfields_cache = {} def get_field_resolver(self, field_resolver): @@ -82,7 +87,8 @@ def get_argument_values(self, field_def, field_ast): return result def report_error(self, error, traceback=None): - sys.excepthook(type(error), error, getattr(error, 'stack', None) or traceback) + sys.excepthook(type(error), error, getattr( + error, 'stack', None) or traceback) self.errors.append(error) def get_sub_fields(self, return_type, field_asts): @@ -101,6 +107,20 @@ def get_sub_fields(self, return_type, field_asts): return self._subfields_cache[k] +class SubscriberExecutionContext(object): + __slots__ = 'exe_context', 'errors' + + def __init__(self, exe_context): + self.exe_context = exe_context + self.errors = [] + + def reset(self): + self.errors = [] + + def __getattr__(self, name): + return getattr(self.exe_context, name) + + class ExecutionResult(object): """The result of execution. `data` is the result of executing the query, `errors` is null if no errors occurred, and is a @@ -186,7 +206,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names ctx, selection, runtime_type): continue - collect_fields(ctx, runtime_type, selection.selection_set, fields, prev_fragment_names) + collect_fields(ctx, runtime_type, + selection.selection_set, fields, prev_fragment_names) elif isinstance(selection, ast.FragmentSpread): frag_name = selection.name.value @@ -202,7 +223,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names does_fragment_condition_match(ctx, fragment, runtime_type): continue - collect_fields(ctx, runtime_type, fragment.selection_set, fields, prev_fragment_names) + collect_fields(ctx, runtime_type, + fragment.selection_set, fields, prev_fragment_names) return fields diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index e8b00c1a..7ae0577a 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -2,6 +2,7 @@ import functools import logging import sys +from rx import Observable from six import string_types from promise import Promise, promise_for_dict, is_thenable @@ -15,16 +16,21 @@ GraphQLSchema, GraphQLUnionType) from .base import (ExecutionContext, ExecutionResult, ResolveInfo, collect_fields, default_resolve_fn, get_field_def, - get_operation_root_type) + get_operation_root_type, SubscriberExecutionContext) from .executors.sync import SyncExecutor from .middleware import MiddlewareManager logger = logging.getLogger(__name__) +def subscribe(*args, **kwargs): + allow_subscriptions = kwargs.pop('allow_subscriptions', True) + return execute(*args, allow_subscriptions=allow_subscriptions, **kwargs) + + def execute(schema, document_ast, root_value=None, context_value=None, variable_values=None, operation_name=None, executor=None, - return_promise=False, middleware=None): + return_promise=False, middleware=None, allow_subscriptions=False): assert schema, 'Must provide schema' assert isinstance(schema, GraphQLSchema), ( 'Schema must be an instance of GraphQLSchema. Also ensure that there are ' + @@ -50,7 +56,8 @@ def execute(schema, document_ast, root_value=None, context_value=None, variable_values, operation_name, executor, - middleware + middleware, + allow_subscriptions ) def executor(v): @@ -61,6 +68,9 @@ def on_rejected(error): return None def on_resolve(data): + if isinstance(data, Observable): + return data + if not context.errors: return ExecutionResult(data=data) return ExecutionResult(data=data, errors=context.errors) @@ -88,6 +98,15 @@ def execute_operation(exe_context, operation, root_value): if operation.operation == 'mutation': return execute_fields_serially(exe_context, type, root_value, fields) + if operation.operation == 'subscription': + if not exe_context.allow_subscriptions: + raise Exception( + "Subscriptions are not allowed. " + "You will need to either use the subscribe function " + "or pass allow_subscriptions=True" + ) + return subscribe_fields(exe_context, type, root_value, fields) + return execute_fields(exe_context, type, root_value, fields) @@ -140,6 +159,44 @@ def execute_fields(exe_context, parent_type, source_value, fields): return promise_for_dict(final_results) +def subscribe_fields(exe_context, parent_type, source_value, fields): + exe_context = SubscriberExecutionContext(exe_context) + + def on_error(error): + exe_context.report_error(error) + + def map_result(data): + if exe_context.errors: + result = ExecutionResult(data=data, errors=exe_context.errors) + else: + result = ExecutionResult(data=data) + exe_context.reset() + return result + + observables = [] + + # assert len(fields) == 1, "Can only subscribe one element at a time." + + for response_name, field_asts in fields.items(): + + result = subscribe_field(exe_context, parent_type, + source_value, field_asts) + if result is Undefined: + continue + + def catch_error(error): + exe_context.errors.append(error) + return Observable.just(None) + + # Map observable results + observable = result.catch_exception(catch_error).map( + lambda data: map_result({response_name: data})) + return observable + observables.append(observable) + + return Observable.merge(observables) + + def resolve_field(exe_context, parent_type, source, field_asts): field_ast = field_asts[0] field_name = field_ast.name.value @@ -191,6 +248,64 @@ def resolve_field(exe_context, parent_type, source, field_asts): ) +def subscribe_field(exe_context, parent_type, source, field_asts): + field_ast = field_asts[0] + field_name = field_ast.name.value + + field_def = get_field_def(exe_context.schema, parent_type, field_name) + if not field_def: + return Undefined + + return_type = field_def.type + resolve_fn = field_def.resolver or default_resolve_fn + + # We wrap the resolve_fn from the middleware + resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn) + + # Build a dict of arguments from the field.arguments AST, using the variables scope to + # fulfill any variable references. + args = exe_context.get_argument_values(field_def, field_ast) + + # The resolve function's optional third argument is a context value that + # is provided to every resolve function within an execution. It is commonly + # used to represent an authenticated user, or request-specific caches. + context = exe_context.context_value + + # The resolve function's optional third argument is a collection of + # information about the current execution state. + info = ResolveInfo( + field_name, + field_asts, + return_type, + parent_type, + schema=exe_context.schema, + fragments=exe_context.fragments, + root_value=exe_context.root_value, + operation=exe_context.operation, + variable_values=exe_context.variable_values, + context=context + ) + + executor = exe_context.executor + result = resolve_or_error(resolve_fn_middleware, + source, info, args, executor) + + if isinstance(result, Exception): + raise result + + if not isinstance(result, Observable): + raise GraphQLError( + 'Subscription must return Async Iterable or Observable. Received: {}'.format(repr(result))) + + return result.map(functools.partial( + complete_value_catching_error, + exe_context, + return_type, + field_asts, + info, + )) + + def resolve_or_error(resolve_fn, source, info, args, executor): try: return executor.execute(resolve_fn, source, info, **args) diff --git a/graphql/execution/executors/asyncio.py b/graphql/execution/executors/asyncio.py index 0aec27c2..f20ddbf7 100644 --- a/graphql/execution/executors/asyncio.py +++ b/graphql/execution/executors/asyncio.py @@ -25,7 +25,15 @@ def ensure_future(coro_or_future, loop=None): del task._source_traceback[-1] return task else: - raise TypeError('A Future, a coroutine or an awaitable is required') + raise TypeError( + 'A Future, a coroutine or an awaitable is required') + +try: + from .asyncio_utils import asyncgen_to_observable, isasyncgen +except Exception: + def isasyncgen(obj): False + + def asyncgen_to_observable(asyncgen): pass class AsyncioExecutor(object): @@ -50,4 +58,6 @@ def execute(self, fn, *args, **kwargs): future = ensure_future(result, loop=self.loop) self.futures.append(future) return Promise.resolve(future) + elif isasyncgen(result): + return asyncgen_to_observable(result) return result diff --git a/graphql/execution/executors/asyncio_utils.py b/graphql/execution/executors/asyncio_utils.py new file mode 100644 index 00000000..836d90d7 --- /dev/null +++ b/graphql/execution/executors/asyncio_utils.py @@ -0,0 +1,135 @@ +from inspect import isasyncgen +from asyncio import ensure_future +from rx import Observable, AnonymousObserver +from rx.core import ObservableBase, Disposable, ObserverBase + +from rx.concurrency import current_thread_scheduler + +from rx.core import Observer, Observable, Disposable +from rx.core.anonymousobserver import AnonymousObserver +from rx.core.autodetachobserver import AutoDetachObserver + + +# class AsyncgenDisposable(Disposable): +# """Represents a Disposable that disposes the asyncgen automatically.""" + +# def __init__(self, asyncgen): +# """Initializes a new instance of the AsyncgenDisposable class.""" + +# self.asyncgen = asyncgen +# self.is_disposed = False + +# super(AsyncgenDisposable, self).__init__() + +# def dispose(self): +# """Sets the status to disposed""" +# self.asyncgen.aclose() +# self.is_disposed = True + + +class AsyncgenObserver(AutoDetachObserver): + def __init__(self, asyncgen, *args, **kwargs): + self._asyncgen = asyncgen + self.is_disposed = False + super(AsyncgenObserver, self).__init__(*args, **kwargs) + + async def dispose_asyncgen(self): + if self.is_disposed: + return + + try: + # await self._asyncgen.aclose() + await self._asyncgen.athrow(StopAsyncIteration) + self.is_disposed = True + except: + pass + + def dispose(self): + if self.is_disposed: + return + disposed = super(AsyncgenObserver, self).dispose() + # print("DISPOSE observer!", disposed) + ensure_future(self.dispose_asyncgen()) + + +class AsyncgenObservable(ObservableBase): + """Class to create an Observable instance from a delegate-based + implementation of the Subscribe method.""" + + def __init__(self, subscribe, asyncgen): + """Creates an observable sequence object from the specified + subscription function. + + Keyword arguments: + :param types.FunctionType subscribe: Subscribe method implementation. + """ + + self._subscribe = subscribe + self._asyncgen = asyncgen + super(AsyncgenObservable, self).__init__() + + def _subscribe_core(self, observer): + # print("GET SUBSCRIBER", observer) + return self._subscribe(observer) + # print("SUBSCRIBER RESULT", subscriber) + # return subscriber + + def subscribe(self, on_next=None, on_error=None, on_completed=None, observer=None): + + if isinstance(on_next, Observer): + observer = on_next + elif hasattr(on_next, "on_next") and callable(on_next.on_next): + observer = on_next + elif not observer: + observer = AnonymousObserver(on_next, on_error, on_completed) + + auto_detach_observer = AsyncgenObserver(self._asyncgen, observer) + + def fix_subscriber(subscriber): + """Fixes subscriber to make sure it returns a Disposable instead + of None or a dispose function""" + + if not hasattr(subscriber, "dispose"): + subscriber = Disposable.create(subscriber) + + return subscriber + + def set_disposable(scheduler=None, value=None): + try: + subscriber = self._subscribe_core(auto_detach_observer) + except Exception as ex: + if not auto_detach_observer.fail(ex): + raise + else: + auto_detach_observer.disposable = fix_subscriber(subscriber) + + # Subscribe needs to set up the trampoline before for subscribing. + # Actually, the first call to Subscribe creates the trampoline so + # that it may assign its disposable before any observer executes + # OnNext over the CurrentThreadScheduler. This enables single- + # threaded cancellation + # https://social.msdn.microsoft.com/Forums/en-US/eb82f593-9684-4e27- + # 97b9-8b8886da5c33/whats-the-rationale-behind-how-currentthreadsche + # dulerschedulerequired-behaves?forum=rx + if current_thread_scheduler.schedule_required(): + current_thread_scheduler.schedule(set_disposable) + else: + set_disposable() + + # Hide the identity of the auto detach observer + return Disposable.create(auto_detach_observer.dispose) + + +def asyncgen_to_observable(asyncgen): + def emit(observer): + ensure_future(iterate_asyncgen(asyncgen, observer)) + return AsyncgenObservable(emit, asyncgen) + + +async def iterate_asyncgen(asyncgen, observer): + try: + async for item in asyncgen: + observer.on_next(item) + observer.on_completed() + except Exception as e: + observer.on_error(e) diff --git a/graphql/execution/tests/test_executor.py b/graphql/execution/tests/test_executor.py index 1ac90db5..e36c596f 100644 --- a/graphql/execution/tests/test_executor.py +++ b/graphql/execution/tests/test_executor.py @@ -112,7 +112,8 @@ def deeper(self): schema = GraphQLSchema(query=DataType) - result = execute(schema, ast, Data(), operation_name='Example', variable_values={'size': 100}) + result = execute(schema, ast, Data(), + operation_name='Example', variable_values={'size': 100}) assert not result.errors assert result.data == expected @@ -178,7 +179,8 @@ def resolver(root_value, *_): 'a': GraphQLField(GraphQLString, resolver=resolver) }) - result = execute(GraphQLSchema(Type), ast, Data(), operation_name='Example') + result = execute(GraphQLSchema(Type), ast, + Data(), operation_name='Example') assert not result.errors assert resolver.got_here @@ -209,7 +211,8 @@ def resolver(source, info, numArg, stringArg): resolver=resolver), }) - result = execute(GraphQLSchema(Type), doc_ast, None, operation_name='Example') + result = execute(GraphQLSchema(Type), doc_ast, + None, operation_name='Example') assert not result.errors assert resolver.got_here @@ -282,7 +285,8 @@ class Data(object): Type = GraphQLObjectType('Type', { 'a': GraphQLField(GraphQLString) }) - result = execute(GraphQLSchema(Type), ast, Data(), operation_name='OtherExample') + result = execute(GraphQLSchema(Type), ast, Data(), + operation_name='OtherExample') assert not result.errors assert result.data == {'second': 'b'} @@ -313,7 +317,8 @@ class Data(object): 'a': GraphQLField(GraphQLString) }) with raises(GraphQLError) as excinfo: - execute(GraphQLSchema(Type), ast, Data(), operation_name="UnknownExample") + execute(GraphQLSchema(Type), ast, Data(), + operation_name="UnknownExample") assert 'Unknown operation named "UnknownExample".' == str(excinfo.value) @@ -329,7 +334,8 @@ class Data(object): }) with raises(GraphQLError) as excinfo: execute(GraphQLSchema(Type), ast, Data()) - assert 'Must provide operation name if query contains multiple operations.' == str(excinfo.value) + assert 'Must provide operation name if query contains multiple operations.' == str( + excinfo.value) def test_uses_the_query_schema_for_queries(): @@ -374,6 +380,7 @@ class Data(object): def test_uses_the_subscription_schema_for_subscriptions(): + from rx import Observable doc = 'query Q { a } subscription S { a }' class Data(object): @@ -385,9 +392,14 @@ class Data(object): 'a': GraphQLField(GraphQLString) }) S = GraphQLObjectType('S', { - 'a': GraphQLField(GraphQLString) + 'a': GraphQLField(GraphQLString, resolver=lambda root, info: Observable.from_(['b'])) }) - result = execute(GraphQLSchema(Q, subscription=S), ast, Data(), operation_name='S') + result = execute(GraphQLSchema(Q, subscription=S), + ast, Data(), operation_name='S', allow_subscriptions=True) + assert isinstance(result, Observable) + l = [] + result.subscribe(l.append) + result = l[0] assert not result.errors assert result.data == {'a': 'b'} @@ -437,7 +449,8 @@ def test_does_not_include_arguments_that_were_not_set(): { 'field': GraphQLField( GraphQLString, - resolver=lambda source, info, **args: args and json.dumps(args, sort_keys=True, separators=(',', ':')), + resolver=lambda source, info, **args: args and json.dumps( + args, sort_keys=True, separators=(',', ':')), args={ 'a': GraphQLArgument(GraphQLBoolean), 'b': GraphQLArgument(GraphQLBoolean), @@ -501,7 +514,8 @@ def __init__(self, value): ] } - assert 'Expected value of type "SpecialType" but got: NotSpecial.' in [str(e) for e in result.errors] + assert 'Expected value of type "SpecialType" but got: NotSpecial.' in [ + str(e) for e in result.errors] def test_fails_to_execute_a_query_containing_a_type_definition(): @@ -547,7 +561,8 @@ def resolver(*_): ) execute(schema, query) - logger.exception.assert_called_with("An error occurred while resolving field Query.foo") + logger.exception.assert_called_with( + "An error occurred while resolving field Query.foo") def test_middleware(): @@ -576,7 +591,8 @@ def reversed_middleware(next, *args, **kwargs): return p.then(lambda x: x[::-1]) middlewares = MiddlewareManager(reversed_middleware) - result = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares) + result = execute(GraphQLSchema(Type), doc_ast, + Data(), middleware=middlewares) assert result.data == {'ok': 'ko', 'not_ok': 'ko_ton'} @@ -607,7 +623,8 @@ def resolve(self, next, *args, **kwargs): return p.then(lambda x: x[::-1]) middlewares = MiddlewareManager(MyMiddleware()) - result = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares) + result = execute(GraphQLSchema(Type), doc_ast, + Data(), middleware=middlewares) assert result.data == {'ok': 'ko', 'not_ok': 'ko_ton'} @@ -640,9 +657,14 @@ class MyEmptyMiddleware(object): def resolve(self, next, *args, **kwargs): return next(*args, **kwargs) - middlewares_with_promise = MiddlewareManager(MyPromiseMiddleware(), wrap_in_promise=False) - middlewares_without_promise = MiddlewareManager(MyEmptyMiddleware(), wrap_in_promise=False) - - result1 = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares_with_promise) - result2 = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares_without_promise) - assert result1.data == result2.data and result1.data == {'ok': 'ok', 'not_ok': 'not_ok'} + middlewares_with_promise = MiddlewareManager( + MyPromiseMiddleware(), wrap_in_promise=False) + middlewares_without_promise = MiddlewareManager( + MyEmptyMiddleware(), wrap_in_promise=False) + + result1 = execute(GraphQLSchema(Type), doc_ast, Data(), + middleware=middlewares_with_promise) + result2 = execute(GraphQLSchema(Type), doc_ast, Data(), + middleware=middlewares_without_promise) + assert result1.data == result2.data and result1.data == { + 'ok': 'ok', 'not_ok': 'not_ok'} diff --git a/graphql/execution/tests/test_subscribe.py b/graphql/execution/tests/test_subscribe.py new file mode 100644 index 00000000..2acb56a9 --- /dev/null +++ b/graphql/execution/tests/test_subscribe.py @@ -0,0 +1,395 @@ +from collections import OrderedDict, namedtuple +from rx import Observable, Observer +from rx.subjects import Subject +from graphql import parse, GraphQLObjectType, GraphQLString, GraphQLBoolean, GraphQLInt, GraphQLField, GraphQLList, GraphQLSchema, graphql, subscribe + +Email = namedtuple('Email', 'from_,subject,message,unread') + +EmailType = GraphQLObjectType( + name='Email', + fields=OrderedDict([ + ('from', GraphQLField(GraphQLString, resolver=lambda x, info: x.from_)), + ('subject', GraphQLField(GraphQLString)), + ('message', GraphQLField(GraphQLString)), + ('unread', GraphQLField(GraphQLBoolean)), + ]) +) + +InboxType = GraphQLObjectType( + name='Inbox', + fields=OrderedDict([ + ('total', GraphQLField(GraphQLInt, + resolver=lambda inbox, context: len(inbox.emails))), + ('unread', GraphQLField(GraphQLInt, + resolver=lambda inbox, context: len([e for e in inbox.emails if e.unread]))), + ('emails', GraphQLField(GraphQLList(EmailType))), + ]) +) + +QueryType = GraphQLObjectType( + name='Query', + fields=OrderedDict([ + ('inbox', GraphQLField(InboxType)), + ]) +) + +EmailEventType = GraphQLObjectType( + name='EmailEvent', + fields=OrderedDict([ + ('email', GraphQLField(EmailType, + resolver=lambda root, info: root[0])), + ('inbox', GraphQLField(InboxType, + resolver=lambda root, info: root[1])), + ]) +) + + +def get_unbound_function(func): + if not getattr(func, '__self__', True): + return func.__func__ + return func + + +def email_schema_with_resolvers(resolve_fn=None): + def default_resolver(root, info): + func = getattr(root, 'importantEmail', None) + if func: + func = get_unbound_function(func) + return func() + return Observable.empty() + + return GraphQLSchema( + query=QueryType, + subscription=GraphQLObjectType( + name='Subscription', + fields=OrderedDict([ + ('importantEmail', GraphQLField( + EmailEventType, + resolver=resolve_fn or default_resolver, + )) + ]) + ) + ) + + +email_schema = email_schema_with_resolvers() + + +class MyObserver(Observer): + def on_next(self, value): + self.has_on_next = value + + def on_error(self, err): + self.has_on_error = err + + def on_completed(self): + self.has_on_completed = True + + +def create_subscription(stream, schema=email_schema, ast=None, vars=None): + class Root(object): + class inbox(object): + emails = [ + Email( + from_='joe@graphql.org', + subject='Hello', + message='Hello World', + unread=False, + ) + ] + + def importantEmail(): + return stream + + def send_important_email(new_email): + Root.inbox.emails.append(new_email) + stream.on_next((new_email, Root.inbox)) + # stream.on_completed() + + default_ast = parse(''' + subscription { + importantEmail { + email { + from + subject + } + inbox { + unread + total + } + } + } + ''') + + return send_important_email, graphql( + schema, + ast or default_ast, + Root, + None, + vars, + allow_subscriptions=True, + ) + + +def test_accepts_an_object_with_named_properties_as_arguments(): + document = parse(''' + subscription { + importantEmail + } + ''') + result = subscribe( + email_schema, + document, + root_value=None + ) + assert isinstance(result, Observable) + + +def test_accepts_multiple_subscription_fields_defined_in_schema(): + SubscriptionTypeMultiple = GraphQLObjectType( + name='Subscription', + fields=OrderedDict([ + ('importantEmail', GraphQLField(EmailEventType)), + ('nonImportantEmail', GraphQLField(EmailEventType)), + ]) + ) + test_schema = GraphQLSchema( + query=QueryType, + subscription=SubscriptionTypeMultiple + ) + + stream = Subject() + send_important_email, subscription = create_subscription( + stream, test_schema) + + email = Email( + from_='yuzhi@graphql.org', + subject='Alright', + message='Tests are good', + unread=True, + ) + l = [] + stream.subscribe(l.append) + send_important_email(email) + assert l[0][0] == email + + +def test_accepts_type_definition_with_sync_subscribe_function(): + SubscriptionType = GraphQLObjectType( + name='Subscription', + fields=OrderedDict([ + ('importantEmail', GraphQLField( + EmailEventType, resolver=lambda *_: Observable.from_([None]))), + ]) + ) + test_schema = GraphQLSchema( + query=QueryType, + subscription=SubscriptionType + ) + + stream = Subject() + send_important_email, subscription = create_subscription( + stream, test_schema) + + email = Email( + from_='yuzhi@graphql.org', + subject='Alright', + message='Tests are good', + unread=True, + ) + l = [] + subscription.subscribe(l.append) + send_important_email(email) + + assert l # [0].data == {'importantEmail': None} + + +def test_throws_an_error_if_subscribe_does_not_return_an_iterator(): + SubscriptionType = GraphQLObjectType( + name='Subscription', + fields=OrderedDict([ + ('importantEmail', GraphQLField( + EmailEventType, resolver=lambda *_: None)), + ]) + ) + test_schema = GraphQLSchema( + query=QueryType, + subscription=SubscriptionType + ) + + stream = Subject() + _, subscription = create_subscription( + stream, test_schema) + + assert str( + subscription.errors[0]) == 'Subscription must return Async Iterable or Observable. Received: None' + + +def test_returns_an_error_if_subscribe_function_returns_error(): + exc = Exception("Throw!") + + def thrower(root, info): + raise exc + + erroring_email_schema = email_schema_with_resolvers(thrower) + result = subscribe(erroring_email_schema, parse(''' + subscription { + importantEmail + } + ''')) + + assert result.errors == [exc] + + +# Subscription Publish Phase +def test_produces_a_payload_for_multiple_subscribe_in_same_subscription(): + stream = Subject() + send_important_email, subscription1 = create_subscription(stream) + subscription2 = create_subscription(stream)[1] + + payload1 = [] + payload2 = [] + + subscription1.subscribe(payload1.append) + subscription2.subscribe(payload2.append) + + email = Email( + from_='yuzhi@graphql.org', + subject='Alright', + message='Tests are good', + unread=True, + ) + + send_important_email(email) + expected_payload = { + 'importantEmail': { + 'email': { + 'from': 'yuzhi@graphql.org', + 'subject': 'Alright', + }, + 'inbox': { + 'unread': 1, + 'total': 2, + }, + } + } + + assert payload1[0].data == expected_payload + assert payload2[0].data == expected_payload + + +# Subscription Publish Phase +def test_produces_a_payload_per_subscription_event(): + stream = Subject() + send_important_email, subscription = create_subscription(stream) + + payload = [] + + subscription.subscribe(payload.append) + send_important_email(Email( + from_='yuzhi@graphql.org', + subject='Alright', + message='Tests are good', + unread=True, + )) + expected_payload = { + 'importantEmail': { + 'email': { + 'from': 'yuzhi@graphql.org', + 'subject': 'Alright', + }, + 'inbox': { + 'unread': 1, + 'total': 2, + }, + } + } + + assert len(payload) == 1 + assert payload[0].data == expected_payload + + send_important_email(Email( + from_='hyo@graphql.org', + subject='Tools', + message='I <3 making things', + unread=True, + )) + expected_payload = { + 'importantEmail': { + 'email': { + 'from': 'hyo@graphql.org', + 'subject': 'Tools', + }, + 'inbox': { + 'unread': 2, + 'total': 3, + }, + } + } + + assert len(payload) == 2 + assert payload[-1].data == expected_payload + + # The client decides to disconnect + stream.on_completed() + + send_important_email(Email( + from_='adam@graphql.org', + subject='Important', + message='Read me please', + unread=True, + )) + + assert len(payload) == 2 + + +def test_event_order_is_correct_for_multiple_publishes(): + stream = Subject() + send_important_email, subscription = create_subscription(stream) + + payload = [] + + subscription.subscribe(payload.append) + send_important_email(Email( + from_='yuzhi@graphql.org', + subject='Message', + message='Tests are good', + unread=True, + )) + send_important_email(Email( + from_='yuzhi@graphql.org', + subject='Message 2', + message='Tests are good 2', + unread=True, + )) + + expected_payload1 = { + 'importantEmail': { + 'email': { + 'from': 'yuzhi@graphql.org', + 'subject': 'Message', + }, + 'inbox': { + 'unread': 1, + 'total': 2, + }, + } + } + + expected_payload2 = { + 'importantEmail': { + 'email': { + 'from': 'yuzhi@graphql.org', + 'subject': 'Message 2', + }, + 'inbox': { + 'unread': 2, + 'total': 3, + }, + } + } + + assert len(payload) == 2 + print(payload) + assert payload[0].data == expected_payload1 + assert payload[1].data == expected_payload2 diff --git a/graphql/graphql.py b/graphql/graphql.py index dc053f01..401a75a2 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -40,7 +40,7 @@ def graphql(*args, **kwargs): def execute_graphql(schema, request_string='', root_value=None, context_value=None, variable_values=None, operation_name=None, executor=None, - return_promise=False, middleware=None): + return_promise=False, middleware=None, allow_subscriptions=False): try: if isinstance(request_string, Document): ast = request_string @@ -62,7 +62,8 @@ def execute_graphql(schema, request_string='', root_value=None, context_value=No variable_values=variable_values or {}, executor=executor, middleware=middleware, - return_promise=return_promise + return_promise=return_promise, + allow_subscriptions=allow_subscriptions, ) except Exception as e: return ExecutionResult( diff --git a/graphql/pyutils/version.py b/graphql/pyutils/version.py index f7339119..a8dd7354 100644 --- a/graphql/pyutils/version.py +++ b/graphql/pyutils/version.py @@ -73,6 +73,6 @@ def get_git_changeset(): ) timestamp = git_log.communicate()[0] timestamp = datetime.datetime.utcfromtimestamp(int(timestamp)) - except: + except Exception: return None return timestamp.strftime('%Y%m%d%H%M%S') diff --git a/graphql/type/tests/test_enum_type.py b/graphql/type/tests/test_enum_type.py index 2381a8c9..c68ee131 100644 --- a/graphql/type/tests/test_enum_type.py +++ b/graphql/type/tests/test_enum_type.py @@ -1,5 +1,6 @@ from collections import OrderedDict +from rx import Observable from pytest import raises from graphql import graphql @@ -35,7 +36,8 @@ def get_first(args, *keys): 'fromInt': GraphQLArgument(GraphQLInt), 'fromString': GraphQLArgument(GraphQLString) }, - resolver=lambda value, info, **args: get_first(args, 'fromInt', 'fromString', 'fromEnum') + resolver=lambda value, info, **args: get_first( + args, 'fromInt', 'fromString', 'fromEnum') ), 'colorInt': GraphQLField( type=GraphQLInt, @@ -43,7 +45,8 @@ def get_first(args, *keys): 'fromEnum': GraphQLArgument(ColorType), 'fromInt': GraphQLArgument(GraphQLInt), }, - resolver=lambda value, info, **args: get_first(args, 'fromInt', 'fromEnum') + resolver=lambda value, info, **args: get_first( + args, 'fromInt', 'fromEnum') ) } ) @@ -69,12 +72,14 @@ def get_first(args, *keys): args={ 'color': GraphQLArgument(ColorType) }, - resolver=lambda value, info, **args: args.get('color') + resolver=lambda value, info, **args: Observable.from_( + [args.get('color')]) ) } ) -Schema = GraphQLSchema(query=QueryType, mutation=MutationType, subscription=SubscriptionType) +Schema = GraphQLSchema( + query=QueryType, mutation=MutationType, subscription=SubscriptionType) def test_accepts_enum_literals_as_input(): @@ -130,13 +135,15 @@ def test_does_not_accept_enum_literal_in_place_of_int(): def test_accepts_json_string_as_enum_variable(): - result = graphql(Schema, 'query test($color: Color!) { colorEnum(fromEnum: $color) }', variable_values={'color': 'BLUE'}) + result = graphql(Schema, 'query test($color: Color!) { colorEnum(fromEnum: $color) }', variable_values={ + 'color': 'BLUE'}) assert not result.errors assert result.data == {'colorEnum': 'BLUE'} def test_accepts_enum_literals_as_input_arguments_to_mutations(): - result = graphql(Schema, 'mutation x($color: Color!) { favoriteEnum(color: $color) }', variable_values={'color': 'GREEN'}) + result = graphql(Schema, 'mutation x($color: Color!) { favoriteEnum(color: $color) }', variable_values={ + 'color': 'GREEN'}) assert not result.errors assert result.data == {'favoriteEnum': 'GREEN'} @@ -144,32 +151,40 @@ def test_accepts_enum_literals_as_input_arguments_to_mutations(): def test_accepts_enum_literals_as_input_arguments_to_subscriptions(): result = graphql( Schema, 'subscription x($color: Color!) { subscribeToEnum(color: $color) }', variable_values={ - 'color': 'GREEN'}) + 'color': 'GREEN'}, allow_subscriptions=True) + assert isinstance(result, Observable) + l = [] + result.subscribe(l.append) + result = l[0] assert not result.errors assert result.data == {'subscribeToEnum': 'GREEN'} def test_does_not_accept_internal_value_as_enum_variable(): - result = graphql(Schema, 'query test($color: Color!) { colorEnum(fromEnum: $color) }', variable_values={'color': 2}) + result = graphql( + Schema, 'query test($color: Color!) { colorEnum(fromEnum: $color) }', variable_values={'color': 2}) assert not result.data assert result.errors[0].message == 'Variable "$color" got invalid value 2.\n' \ 'Expected type "Color", found 2.' def test_does_not_accept_string_variables_as_enum_input(): - result = graphql(Schema, 'query test($color: String!) { colorEnum(fromEnum: $color) }', variable_values={'color': 'BLUE'}) + result = graphql(Schema, 'query test($color: String!) { colorEnum(fromEnum: $color) }', variable_values={ + 'color': 'BLUE'}) assert not result.data assert result.errors[0].message == 'Variable "color" of type "String!" used in position expecting type "Color".' def test_does_not_accept_internal_value_as_enum_input(): - result = graphql(Schema, 'query test($color: Int!) { colorEnum(fromEnum: $color) }', variable_values={'color': 2}) + result = graphql( + Schema, 'query test($color: Int!) { colorEnum(fromEnum: $color) }', variable_values={'color': 2}) assert not result.data assert result.errors[0].message == 'Variable "color" of type "Int!" used in position expecting type "Color".' def test_enum_value_may_have_an_internal_value_of_0(): - result = graphql(Schema, '{ colorEnum(fromEnum: RED) colorInt(fromEnum: RED) }') + result = graphql( + Schema, '{ colorEnum(fromEnum: RED) colorInt(fromEnum: RED) }') assert not result.errors assert result.data == {'colorEnum': 'RED', 'colorInt': 0} diff --git a/setup.cfg b/setup.cfg index 94f77b15..724139ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [flake8] -exclude = tests,scripts,setup.py,docs +exclude = tests,scripts,setup.py,docs,graphql/execution/executors/asyncio_utils.py max-line-length = 160 [bdist_wheel] diff --git a/setup.py b/setup.py index 72e6178c..8d69808b 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,8 @@ install_requires = [ 'six>=1.10.0', - 'promise>=2.1.dev' + 'promise>=2.1', + 'rx>=1.6.0', ] tests_requires = [ @@ -37,6 +38,7 @@ 'pytest-mock==1.2', ] + class PyTest(TestCommand): def finalize_options(self): TestCommand.finalize_options(self) @@ -44,7 +46,7 @@ def finalize_options(self): self.test_suite = True def run_tests(self): - #import here, cause outside the eggs aren't loaded + # import here, cause outside the eggs aren't loaded import pytest errno = pytest.main(self.test_args) sys.exit(errno) @@ -75,10 +77,11 @@ def run_tests(self): 'Topic :: Internet :: WWW/HTTP', ], keywords='api graphql protocol rest', - packages=find_packages(exclude=['tests', 'tests_py35', 'tests.*', 'tests_py35.*']), + packages=find_packages( + exclude=['tests', 'tests_py35', 'tests.*', 'tests_py35.*']), install_requires=install_requires, tests_require=tests_requires, - cmdclass = {'test': PyTest}, + cmdclass={'test': PyTest}, extras_require={ 'gevent': [ 'gevent==1.1rc1'