diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 9556192..4010847 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -59,22 +59,24 @@ def __init__(self, schema, keep_alive=True): self.keep_alive = keep_alive def get_graphql_params(self, connection_context, payload): - return { + params = { 'request_string': payload.get('query'), 'variable_values': payload.get('variables'), 'operation_name': payload.get('operationName'), 'context_value': payload.get('context'), } - - def build_message(self, id, op_type, payload): - message = {} - if id is not None: - message['id'] = id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - return message + if connection_context.request_context: + context = ( + params['context_value'].copy() + if params['context_value'] + else {} + ) + if isinstance(context, dict): + context.setdefault( + 'request', connection_context.request_context + ) + params['context_value'] = context + return params def process_message(self, connection_context, parsed_message): op_id = parsed_message.get('id')