diff --git a/python/drisk_api/graph_client.py b/python/drisk_api/graph_client.py index ee80a45..16cf2c0 100644 --- a/python/drisk_api/graph_client.py +++ b/python/drisk_api/graph_client.py @@ -3,6 +3,7 @@ from uuid import UUID, uuid4 import requests +import warnings from .drisk_api import PyGraphDiff @@ -197,6 +198,44 @@ def get_nodes(self, node_ids: List[UUID]) -> Dict[UUID, "Node"]: UUID(id): Node(self, UUID(id), **data["properties"]) for id, data in node_data.items() } + + def get_all_node_ids(self) -> List[UUID]: + """ + Retrieve all node IDs in the graph. + + Returns + ------- + List[UUID]: List of every node ID in the graph. + + """ + url = f"{self.url}/{self.graph_id}/nodes" + r = requests.get(url, headers={"Authorization": self.auth_token}) + if r.status_code >= 300: + raise EdgeException(r.status_code, r.text) + return [UUID(node["id"]) for node in r.json()] + + def get_node_id_by_label(self, label: str) -> UUID: + """ + Retrieve a node by its label. + + Args: + label (str): the label of the node to be retrieved. + + Returns + ------- + UUID: the node ID corresponding to the node label given. Raises warning if multiple nodes with that label are found. + + """ + all_node_ids = self.get_all_node_ids() + instances = 0 + for node_id in all_node_ids: + node = self.get_node(node_id) + if node and node.label() == label: + instances += 1 + if instances > 1: + warnings.warn("Multiple instances of nodes with the label specified.") + return node + return None def get_successors( self,