From 33551eec0df9ea07d8145222788b2409869cf3cc Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Mon, 2 Oct 2023 22:46:11 -0700 Subject: [PATCH 1/2] Add Gremlin proxy host fix and Neptune HTTP query support --- src/graph_notebook/magics/graph_magic.py | 35 ++++++++-- src/graph_notebook/neptune/client.py | 42 +++++++---- .../network/gremlin/GremlinNetwork.py | 70 ++++++++++++------- 3 files changed, 104 insertions(+), 43 deletions(-) diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 1eb3d90f..e076992f 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -29,6 +29,7 @@ from SPARQLWrapper import SPARQLWrapper from botocore.session import get_session from gremlin_python.driver.protocol import GremlinServerError +from gremlin_python.structure.graph import Path from IPython.core.display import HTML, display_html, display from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic, line_cell_magic, needs_local_scope) from ipywidgets.widgets.widget_description import DescriptionStyle @@ -808,15 +809,15 @@ def gremlin(self, line, cell, local_ns: dict = None): parser.add_argument('--explain-type', type=str.lower, default='', help='Explain mode to use when using the explain query mode.') parser.add_argument('-p', '--path-pattern', default='', help='path pattern') - parser.add_argument('-g', '--group-by', type=str, default='T.label', + parser.add_argument('-g', '--group-by', type=str, default='', help='Property used to group nodes (e.g. code, T.region) default is T.label') parser.add_argument('-gd', '--group-by-depth', action='store_true', default=False, help="Group nodes based on path hierarchy") parser.add_argument('-gr', '--group-by-raw', action='store_true', default=False, help="Group nodes by the raw result") - parser.add_argument('-d', '--display-property', type=str, default='T.label', + parser.add_argument('-d', '--display-property', type=str, default='', help='Property to display the value of on each node, default is T.label') - parser.add_argument('-de', '--edge-display-property', type=str, default='T.label', + parser.add_argument('-de', '--edge-display-property', type=str, default='', help='Property to display the value of on each edge, default is T.label') parser.add_argument('-t', '--tooltip-property', type=str, default='', help='Property to display the value of on each node tooltip. If not specified, tooltip ' @@ -936,8 +937,16 @@ def gremlin(self, line, cell, local_ns: dict = None): else: first_tab_html = pre_container_template.render(content='No profile found') else: + using_http = False query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms - query_res = self.client.gremlin_query(cell, transport_args=transport_args) + if self.graph_notebook_config.proxy_host != '' and self.client.is_neptune_domain(): + using_http = True + query_res_http = self.client.gremlin_http_query(cell, headers={'Accept': 'application/vnd.gremlin-v1.0+json;types=false'}) + query_res_http.raise_for_status() + query_res_http_json = query_res_http.json() + query_res = query_res_http_json['result']['data'] + else: + query_res = self.client.gremlin_query(cell, transport_args=transport_args) query_time = time.time() * 1000 - query_start if not args.silent: gremlin_metadata = build_gremlin_metadata_from_query(query_type='query', results=query_res, @@ -951,7 +960,8 @@ def gremlin(self, line, cell, local_ns: dict = None): logger.debug(f'edge_display_property: {args.edge_display_property}') logger.debug(f'label_max_length: {args.label_max_length}') logger.debug(f'ignore_groups: {args.ignore_groups}') - gn = GremlinNetwork(group_by_property=args.group_by, display_property=args.display_property, + gn = GremlinNetwork(group_by_property=args.group_by, + display_property=args.display_property, group_by_raw=args.group_by_raw, group_by_depth=args.group_by_depth, edge_display_property=args.edge_display_property, @@ -959,10 +969,21 @@ def gremlin(self, line, cell, local_ns: dict = None): edge_tooltip_property=args.edge_tooltip_property, label_max_length=args.label_max_length, edge_label_max_length=args.edge_label_max_length, - ignore_groups=args.ignore_groups) + ignore_groups=args.ignore_groups, + using_http=using_http) + + if using_http and 'path()' in cell and query_res: + first_path = query_res[0] + if isinstance(first_path, dict) and first_path.keys() == {'labels', 'objects'}: + query_res_to_path_type = [] + for path in query_res: + new_path_list = path['objects'] + new_path = Path(labels=[], objects=new_path_list) + query_res_to_path_type.append(new_path) + query_res = query_res_to_path_type if args.path_pattern == '': - gn.add_results(query_res) + gn.add_results(query_res, is_http=using_http) else: pattern = parse_pattern_list_str(args.path_pattern) gn.add_results_with_pattern(query_res, pattern) diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py index a25fe341..f50586db 100644 --- a/src/graph_notebook/neptune/client.py +++ b/src/graph_notebook/neptune/client.py @@ -17,6 +17,7 @@ from botocore.awsrequest import AWSRequest from gremlin_python.driver import client, serializer from gremlin_python.driver.protocol import GremlinServerError +from gremlin_python.driver.aiohttp.transport import AiohttpTransport from neo4j import GraphDatabase, DEFAULT_DATABASE from neo4j.exceptions import AuthError from base64 import b64encode @@ -24,7 +25,6 @@ from graph_notebook.neptune.bolt_auth_token import NeptuneBoltAuthToken - # This patch is no longer needed when graph_notebook is using the a Gremlin Python # client >= 3.5.0 as the HashableDict is now part of that client driver. # import graph_notebook.neptune.gremlin.graphsonV3d0_MapType_objectify_patch # noqa F401 @@ -45,7 +45,7 @@ # TODO: add doc links to each command FORMAT_CSV = 'csv' -FORMAT_OPENCYPHER='opencypher' +FORMAT_OPENCYPHER = 'opencypher' FORMAT_NTRIPLE = 'ntriples' FORMAT_NQUADS = 'nquads' FORMAT_RDFXML = 'rdfxml' @@ -191,11 +191,19 @@ def is_neptune_domain(self): return is_allowed_neptune_host(hostname=self.target_host, host_allowlist=self.neptune_hosts) def get_uri_with_port(self, use_websocket=False, use_proxy=False): - protocol = self._http_protocol if use_websocket is True: protocol = self._ws_protocol + else: + protocol = self._http_protocol - uri = f'{protocol}://{self.host}:{self.port}' + if use_proxy is True: + uri_host = self.proxy_host + uri_port = self.proxy_port + else: + uri_host = self.target_host + uri_port = self.target_port + + uri = f'{protocol}://{uri_host}:{uri_port}' return uri def sparql_query(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: @@ -267,11 +275,20 @@ def sparql_cancel(self, query_id: str, silent: bool = False): def get_gremlin_connection(self, transport_kwargs) -> client.Client: nest_asyncio.apply() - ws_url = f'{self.get_uri_with_port(use_websocket=True)}/gremlin' - request = self._prepare_request('GET', ws_url) + ws_url = f'{self.get_uri_with_port(use_websocket=True, use_proxy=False)}/gremlin' + if self.proxy_host != '': + proxy_http_url = f'{self.get_uri_with_port(use_websocket=False, use_proxy=True)}/gremlin' + transport_factory_args = lambda: AiohttpTransport(call_from_event_loop=True, proxy=proxy_http_url, + **transport_kwargs) + request = self._prepare_request('GET', proxy_http_url) + else: + transport_factory_args = lambda: AiohttpTransport(**transport_kwargs) + request = self._prepare_request('GET', ws_url) + traversal_source = 'g' if self.is_neptune_domain() else self.gremlin_traversal_source - return client.Client(ws_url, traversal_source, username=self.gremlin_username, - password=self.gremlin_password, message_serializer=self.gremlin_serializer, + return client.Client(ws_url, traversal_source, transport_factory=transport_factory_args, + username=self.gremlin_username, password=self.gremlin_password, + message_serializer=self.gremlin_serializer, headers=dict(request.headers), **transport_kwargs) def gremlin_query(self, query, transport_args=None, bindings=None): @@ -298,7 +315,8 @@ def gremlin_http_query(self, query, headers=None) -> requests.Response: if headers is None: headers = {} - uri = f'{self.get_uri_with_port()}/gremlin' + use_proxy = True if self.proxy_host != '' else False + uri = f'{self.get_uri_with_port(use_websocket=False, use_proxy=use_proxy)}/gremlin' data = {'gremlin': query} req = self._prepare_request('POST', uri, data=json.dumps(data), headers=headers) res = self._http_session.send(req, verify=self.ssl_verify) @@ -431,7 +449,7 @@ def stream(self, url, **kwargs) -> requests.Response: params = {} for k, v in kwargs.items(): params[k] = v - req = self._prepare_request('GET', url, params=params,data='') + req = self._prepare_request('GET', url, params=params, data='') res = self._http_session.send(req, verify=self.ssl_verify) return res.json() @@ -850,7 +868,7 @@ def with_sparql_path(self, path: str): def with_gremlin_traversal_source(self, traversal_source: str): self.args['gremlin_traversal_source'] = traversal_source return ClientBuilder(self.args) - + def with_gremlin_login(self, username: str, password: str): self.args['gremlin_username'] = username self.args['gremlin_password'] = password @@ -859,7 +877,7 @@ def with_gremlin_login(self, username: str, password: str): def with_gremlin_serializer(self, message_serializer: str): self.args['gremlin_serializer'] = message_serializer return ClientBuilder(self.args) - + def with_neo4j_login(self, username: str, password: str, auth: bool, database: str): self.args['neo4j_username'] = username self.args['neo4j_password'] = password diff --git a/src/graph_notebook/network/gremlin/GremlinNetwork.py b/src/graph_notebook/network/gremlin/GremlinNetwork.py index b1978002..73a8d553 100644 --- a/src/graph_notebook/network/gremlin/GremlinNetwork.py +++ b/src/graph_notebook/network/gremlin/GremlinNetwork.py @@ -84,6 +84,8 @@ def get_id(element): elif isinstance(element, dict): if T.id in element: element_id = element[T.id] + elif 'id' in element: + element_id = element['id'] else: element_id = generate_id_from_dict(element) else: @@ -104,13 +106,19 @@ class GremlinNetwork(EventfulNetwork): """ def __init__(self, graph: MultiDiGraph = None, callbacks=None, label_max_length=DEFAULT_LABEL_MAX_LENGTH, - edge_label_max_length=DEFAULT_LABEL_MAX_LENGTH, group_by_property=T_LABEL, display_property=T_LABEL, - edge_display_property=T_LABEL, tooltip_property=None, edge_tooltip_property=None, ignore_groups=False, - group_by_depth=False, group_by_raw=False): + edge_label_max_length=DEFAULT_LABEL_MAX_LENGTH, group_by_property=None, display_property=None, + edge_display_property=None, tooltip_property=None, edge_tooltip_property=None, ignore_groups=False, + group_by_depth=False, group_by_raw=False, using_http=False): if graph is None: graph = MultiDiGraph() if group_by_depth: group_by_property = DEPTH_GRP_KEY + if not group_by_property: + group_by_property = 'label' if using_http else T_LABEL + if not display_property: + display_property = 'label' if using_http else T_LABEL + if not edge_display_property: + edge_display_property = 'label' if using_http else T_LABEL super().__init__(graph, callbacks, label_max_length, edge_label_max_length, group_by_property, display_property, edge_display_property, tooltip_property, edge_tooltip_property, ignore_groups, group_by_raw) @@ -164,7 +172,7 @@ def get_explicit_edge_property_value(self, data, edge, custom_property): except KeyError: return else: - if custom_property == T_LABEL: + if custom_property in [T_LABEL, 'label']: property_value = edge.label else: try: @@ -289,18 +297,26 @@ def add_results_with_pattern(self, results, pattern_list: list): return - def add_results(self, results): + def add_results(self, results, is_http=False): """ receives path results and traverses them to add nodes and edges to the network graph. We will look at sets of three in a path to form a vertex -> edge -> vertex pattern. All other patters will be considered invalid at this time. :param results: the data to be processed. Must be of type :type Path + :param is_http: A flag indicating the type of token keys returned :return: """ if not isinstance(results, list): raise ValueError("results must be a list of paths") + if is_http: + gremlin_id = 'id' + gremlin_label = 'label' + else: + gremlin_id = T.id + gremlin_label = T.label + for path_index, path in enumerate(results): if isinstance(path, Path): if type(path[0]) is Edge or type(path[-1]) is Edge: @@ -309,10 +325,10 @@ def add_results(self, results): for i in range(len(path)): if isinstance(path[i], dict): is_elementmap = False - if T.id in path[i] and T.label in path[i]: + if gremlin_id in path[i] and gremlin_label in path[i]: for prop, value in path[i].items(): - # T.id and/or T.label could be renamed by a project() step - if isinstance(value, str) and prop not in [T.id, T.label]: + # ID and/or Label property keys could be renamed by a project() step + if isinstance(value, str) and prop not in [gremlin_id, gremlin_label]: is_elementmap = True break elif isinstance(value, dict): @@ -330,7 +346,7 @@ def add_results(self, results): self.insert_path_element(path, i) else: self.insert_path_element(path, i) - elif isinstance(path, dict) and T.id in path.keys() and T.label in path.keys(): + elif isinstance(path, dict) and gremlin_id in path.keys() and gremlin_label in path.keys(): self.insert_elementmap(path, index=path_index) else: raise ValueError("all entries in results must be paths or elementMaps") @@ -431,8 +447,9 @@ def add_vertex(self, v, path_index: int = -1): # Before looping though properties, we first search for T.label in vertex dict, then set title = T.label # Otherwise, we will hit KeyError if we don't iterate through T.label first to set the title # Since it is needed for checking for the vertex label's desired grouping behavior in group_by_property - if T.label in v.keys(): - title_plc = str(v[T.label]) + if T.label in v.keys() or 'label' in v.keys(): + label_key = T.label if T.label in v.keys() else 'label' + title_plc = str(v[label_key]) title, label = self.strip_and_truncate_label_and_title(title_plc, self.label_max_length) else: title_plc = '' @@ -442,7 +459,7 @@ def add_vertex(self, v, path_index: int = -1): group = str(v) group_is_set = True for k in v: - if str(k) == T_ID: + if str(k) in [T_ID, 'id']: node_id = str(v[k]) if isinstance(v[k], dict): @@ -593,14 +610,15 @@ def add_path_edge(self, edge, from_id='', to_id='', data=None): if self.edge_tooltip_property and self.edge_tooltip_property != self.edge_display_property: using_custom_tooltip = True tooltip_display_is_set = False - if T.label in edge.keys(): - edge_title_plc = str(edge[T.label]) + if T.label in edge.keys() or 'label' in edge.keys(): + label_key = T.label if T.label in edge.keys() else 'label' + edge_title_plc = str(edge[label_key]) edge_title, edge_label = self.strip_and_truncate_label_and_title(edge_title_plc, self.edge_label_max_length) else: edge_title_plc = '' for k in edge: - if str(k) == T_ID: + if str(k) in [T_ID, 'id']: edge_id = str(edge[k]) if isinstance(edge[k], dict): # Handle Direction properties, where the value is a map @@ -612,7 +630,7 @@ def add_path_edge(self, edge, from_id='', to_id='', data=None): else: properties[k] = edge[k] - if self.edge_display_property is not T_LABEL and not display_is_set: + if self.edge_display_property not in [T_LABEL, 'label'] and not display_is_set: label_property_raw_value = self.get_dict_element_property_value(edge, k, edge_title_plc, self.edge_display_property) if label_property_raw_value: @@ -688,9 +706,10 @@ def insert_path_element(self, path, i): from_id = get_id(path[i - 1]) self.add_vertex(path[i], path_index=i) - if type(path[i - 1]) is not Edge: - if type(path[i - 1]) is dict: - if Direction.IN not in path[i-1]: + last_path = path[i - 1] + if type(last_path) is not Edge: + if type(last_path) is dict: + if Direction.IN not in last_path and 'IN' not in last_path: self.add_blank_edge(from_id, get_id(path[i])) else: self.add_blank_edge(from_id, get_id(path[i])) @@ -705,15 +724,18 @@ def insert_elementmap(self, e_map, check_emap=False, path_element=None, index=No :param e_map: A dictionary containing the elementMap representation of a vertex or an edge """ # Handle directed edge elementMap - if Direction.IN in e_map.keys() and Direction.OUT in e_map.keys(): - from_id = get_id(e_map[Direction.OUT]) - to_id = get_id(e_map[Direction.IN]) + if (Direction.IN in e_map.keys() and Direction.OUT in e_map.keys()) \ + or ('IN' in e_map.keys() and 'OUT' in e_map.keys()): + out_prop = Direction.OUT if Direction.OUT in e_map.keys() else 'OUT' + in_prop = Direction.IN if Direction.IN in e_map.keys() else 'IN' + from_id = get_id(e_map[out_prop]) + to_id = get_id(e_map[in_prop]) # Ensure that the default nodes includes with edge elementMaps do not overwrite nodes # with the same ID that have been explicitly inserted. if not self.graph.has_node(from_id): - self.add_vertex(e_map[Direction.OUT], path_index=index-1) + self.add_vertex(e_map[out_prop], path_index=index-1) if not self.graph.has_node(to_id): - self.add_vertex(e_map[Direction.IN], path_index=index+1) + self.add_vertex(e_map[in_prop], path_index=index+1) self.add_path_edge(e_map, from_id, to_id) # Handle vertex elementMap else: From 58ece9582e9ac2d3577a4d8ea92d7ae1e02c13e8 Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Wed, 4 Oct 2023 16:57:02 -0700 Subject: [PATCH 2/2] update changelog --- ChangeLog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/ChangeLog.md b/ChangeLog.md index d1203102..a2e622b6 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming - Added `--explain-type` option to `%%gremlin` ([Link to PR](https://github.com/aws/graph-notebook/pull/503)) - Added general documentation for `%%graph_notebook_config` options ([Link to PR](https://github.com/aws/graph-notebook/pull/504)) +- Added support for Gremlin proxy hosts and visualization of Neptune HTTP results ([Link to PR](https://github.com/aws/graph-notebook/pull/530)) - Modified Dockerfile to support Python 3.10 ([Link to PR](https://github.com/aws/graph-notebook/pull/519)) - Updated Docker documentation with platform-specific run commands ([Link to PR](https://github.com/aws/graph-notebook/pull/502)) - Fixed deprecation warnings in GitHub workflows ([Link to PR](https://github.com/aws/graph-notebook/pull/506))