-
Notifications
You must be signed in to change notification settings - Fork 215
Expand file tree
/
Copy pathgenerate_knowledge_graph.py
More file actions
127 lines (101 loc) · 3.82 KB
/
generate_knowledge_graph.py
File metadata and controls
127 lines (101 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from pyvis.network import Network
from dotenv import load_dotenv
import os
import asyncio
# Module-level setup: runs once at import time and builds the shared LLM
# client and graph transformer used by extract_graph_data().
# Load variables from a .env file into the process environment.
load_dotenv()
# Read the OpenAI API key from the environment (None if it is not set).
# NOTE(review): api_key is not passed to ChatOpenAI below; presumably the
# client picks up OPENAI_API_KEY from the environment on its own - confirm.
api_key = os.getenv("OPENAI_API_KEY")
# Chat model configured with temperature 0 and the "gpt-4o" model name.
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
# Transformer that uses the chat model above to turn documents into graph documents.
graph_transformer = LLMGraphTransformer(llm=llm)
# Extract graph data from input text
async def extract_graph_data(text):
    """
    Turn raw text into graph documents using the module-level graph transformer.

    The text is wrapped in a single ``Document`` and passed to
    ``graph_transformer.aconvert_to_graph_documents``; the awaited result is
    returned unchanged.

    Args:
        text (str): Input text to be processed into graph format.

    Returns:
        list: The transformer's output for that one document (expected to be
            GraphDocument objects carrying nodes and relationships).
    """
    source_document = Document(page_content=text)
    return await graph_transformer.aconvert_to_graph_documents([source_document])
def visualize_graph(graph_documents):
    """
    Render a knowledge graph with PyVis and save it as ``knowledge_graph.html``.

    Only the first entry of ``graph_documents`` is visualized. A relationship
    is drawn only when both its source and target ids appear in that
    document's node list, and a node is drawn only when it takes part in at
    least one such relationship (isolated nodes are left out).

    Args:
        graph_documents (list): GraphDocument-like objects exposing ``nodes``
            and ``relationships``. An empty list yields an empty graph
            instead of raising ``IndexError``.

    Returns:
        pyvis.network.Network: The populated network, or ``None`` if the HTML
            file could not be saved.
    """
    net = Network(
        height="1200px",
        width="100%",
        directed=True,
        notebook=False,
        bgcolor="#222222",
        font_color="white",
        filter_menu=True,
        cdn_resources="remote",
    )

    # Guard against an empty result so callers get an empty graph, not a crash.
    if graph_documents:
        nodes = graph_documents[0].nodes
        relationships = graph_documents[0].relationships
    else:
        nodes, relationships = [], []

    # Lookup of known nodes by id; dict order follows the node list order.
    node_dict = {node.id: node for node in nodes}

    # Keep only edges whose endpoints are both known nodes.
    valid_edges = [
        rel
        for rel in relationships
        if rel.source.id in node_dict and rel.target.id in node_dict
    ]
    valid_node_ids = set()
    for rel in valid_edges:
        valid_node_ids.update((rel.source.id, rel.target.id))

    # Iterate the ordered dict (not the set) so node insertion order is
    # deterministic from run to run.
    for node_id, node in node_dict.items():
        if node_id not in valid_node_ids:
            continue
        try:
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except Exception as e:
            # Best effort: report and skip the node rather than abort the graph.
            print(f"Skipping node {node_id!r}: {e!r}")

    for rel in valid_edges:
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except Exception as e:
            # Best effort: e.g. an endpoint whose node was skipped above.
            print(f"Skipping edge {rel.source.id!r} -> {rel.target.id!r}: {e!r}")

    # Configure graph layout and physics.
    net.set_options("""
    {
        "physics": {
            "forceAtlas2Based": {
                "gravitationalConstant": -100,
                "centralGravity": 0.01,
                "springLength": 200,
                "springConstant": 0.08
            },
            "minVelocity": 0.75,
            "solver": "forceAtlas2Based"
        }
    }
    """)

    output_file = "knowledge_graph.html"
    try:
        net.save_graph(output_file)
        print(f"Graph saved to {os.path.abspath(output_file)}")
        return net
    except Exception as e:
        print(f"Error saving graph: {e}")
        return None
def generate_knowledge_graph(text):
    """
    Build a knowledge graph from text and render it with PyVis.

    The asynchronous extraction step is driven to completion with
    ``asyncio.run`` and its result is passed straight to ``visualize_graph``.

    Args:
        text (str): Input text to convert into a knowledge graph.

    Returns:
        pyvis.network.Network: The value returned by ``visualize_graph`` for
            the extracted graph documents.
    """
    extracted = asyncio.run(extract_graph_data(text))
    return visualize_graph(extracted)