This repository was archived by the owner on Dec 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy path cli.py
More file actions
145 lines (121 loc) · 4.55 KB
/
cli.py
File metadata and controls
145 lines (121 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import csv
import sys
import json
import logging
from glob import glob
from logging.config import dictConfig
import click
import neo4j
import requests
# Logging setup: every record reaching the root logger ('') is written to
# stdout at INFO level as "<time> - <file>:<line>: <message>".
_LOG_FORMAT = '%(asctime)s - %(filename)s:%(lineno)s: %(message)s'
_STDOUT_HANDLER = {
    'level': 'INFO',
    'class': 'logging.StreamHandler',
    'formatter': 'simple',
    'stream': 'ext://sys.stdout',
}
dictConfig(dict(
    version=1,
    formatters={'simple': {'format': _LOG_FORMAT}},
    handlers={'default': _STDOUT_HANDLER},
    loggers={
        '': {'handlers': ['default'], 'level': 'INFO', 'propagate': True},
    },
))

# Logger used by the commands in this module.
logger = logging.getLogger('cli')
def create_entity_index(neo4j_client, entity_type, property_name):
    """Create a schema index on ``property_name`` for nodes labelled ``entity_type``.

    Args:
        neo4j_client: Neo4j driver; only ``session()`` (used as a context
            manager whose object exposes ``run(query)``) is needed here.
        entity_type: node label to index.
        property_name: node property key to index.
    """
    # Labels and property keys cannot be query parameters, so they are put into
    # the query text as backtick-quoted identifiers (an embedded backtick is
    # escaped by doubling it). This keeps names with spaces/dashes valid.
    label = "`%s`" % entity_type.replace("`", "``")
    prop = "`%s`" % property_name.replace("`", "``")
    # NOTE(review): `CREATE INDEX ON :Label(prop)` is the Neo4j 3.x form; it is
    # deprecated in 4.x and removed in 5.x, where
    # `CREATE INDEX FOR (n:Label) ON (n.prop)` is required.
    with neo4j_client.session() as session:
        session.run(f"CREATE INDEX ON :{label}({prop})")
    # Arguments follow the placeholder order: property first, then label.
    logger.info(
        "created index on property '%s' of entity type `%s`",
        property_name, entity_type
    )
@click.group(context_settings={'help_option_names': ['-h', '--help']})
def main():
    # Root click command group; subcommands attach via @main.command(...).
    # Intentionally no docstring: click would show it as the group's --help text.
    ...
def _cypher_name(name):
    """Return *name* as a backtick-quoted Cypher identifier (label or rel type).

    Labels and relationship types cannot be passed as query parameters, so they
    are interpolated into the query text; quoting them (with embedded backticks
    doubled) keeps names containing spaces, dashes, etc. valid.
    """
    return "`%s`" % name.replace("`", "``")


def _write_in_batches(client, query, rows, batch_size, noun):
    """Run *query* once per batch of *rows*, passing the batch as ``$values``.

    A batch is written as soon as it holds exactly *batch_size* rows and any
    remainder is written at the end; each batch uses its own session. A
    *batch_size* of 0 or less is never reached, so all rows are written in one
    final batch. *noun* only names the rows in the progress log line.
    """
    def flush(batch):
        with client.session() as session:
            session.run(query, {'values': batch})
        logger.info("wrote %d %s in Neo4j server", len(batch), noun)

    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            flush(batch)
            batch = []
    if batch:
        flush(batch)


@main.command("import-to-neo4j")
@click.option("--url", default="bolt://localhost:7687/")
@click.option("--auth", default="neo4j:myneo4j")
@click.option("-d", "--data-dir", required=True)
@click.option("-b", "--batch-size", type=int, default=1000)
@click.option("--dropall", is_flag=True)
def import_to_neo4j(url, auth, data_dir, batch_size, dropall):
    """导入数据到 Neo4j"""
    # The docstring above is read by click as this command's --help text, so it
    # is kept verbatim. In English it reads: "Import data into Neo4j".
    #
    # Expected contents of DATA_DIR:
    #   metadata.json with
    #     "entity-data":   {label: csv file name}             (required)
    #     "relation-data": {"Start|REL|End": csv file name}   (optional)
    #   Entity CSVs: each header's part before the first ':' becomes the node
    #     property key; columns whose key is empty (e.g. ":LABEL") are dropped.
    #   Relation CSVs: ":START_ID" / ":END_ID" columns, matched against the
    #     "id" property of nodes with the start / end label.

    def convert_csv_row(csv_row):
        """Map a CSV row to node properties keyed by the header text before ':'."""
        row = {}
        for header, value in csv_row.items():
            key = header.partition(':')[0]
            if key:
                row[key] = value
        return row

    # Validate input before touching the database, so a bad DATA_DIR can never
    # leave the server wiped by --dropall with nothing imported.
    metadata_file = os.path.join(data_dir, "metadata.json")
    if not os.path.exists(metadata_file):
        logger.error("Cannot found 'metadata.json' in directory '%s'", data_dir)
        sys.exit(1)
    with open(metadata_file, encoding='utf-8') as f:
        metadata = json.load(f)

    # Split on the first ':' only, so the password itself may contain ':'.
    user, password = auth.split(':', 1)
    client = neo4j.GraphDatabase.driver(url, auth=(user, password))
    try:
        if dropall:
            with client.session() as session:
                session.run('MATCH (n) DETACH DELETE n')
            logger.info("Dropped all data in Neo4j server")

        # `$values` is the parameter syntax supported since Neo4j 3.0; the old
        # `{values}` form was removed in 4.0. `SET n = data` copies every
        # property of the converted row onto the new node.
        entity_query_tmpl = 'UNWIND $values as data CREATE (n:%s) SET n = data'
        for entity_type, entity_file in metadata["entity-data"].items():
            create_entity_index(client, entity_type, "id")
            query = entity_query_tmpl % _cypher_name(entity_type)
            # newline='' is what the csv module requires for file objects.
            with open(os.path.join(data_dir, entity_file),
                      newline='', encoding='utf-8') as f:
                rows = (convert_csv_row(row) for row in csv.DictReader(f))
                _write_in_batches(client, query, rows, batch_size, "entities")

        relation_query_tmpl = (
            'UNWIND $values as data '
            'MATCH (a:%s {id:data.start_id}) '
            'MATCH (b:%s {id:data.end_id}) '
            'CREATE (a)-[:%s]->(b)'
        )
        for relation_type, relation_file in metadata.get("relation-data", {}).items():
            start_type, relation, end_type = relation_type.split('|')
            query = relation_query_tmpl % (
                _cypher_name(start_type),
                _cypher_name(end_type),
                _cypher_name(relation),
            )
            with open(os.path.join(data_dir, relation_file),
                      newline='', encoding='utf-8') as f:
                rows = (
                    {"start_id": row[":START_ID"], "end_id": row[":END_ID"]}
                    for row in csv.DictReader(f)
                )
                _write_in_batches(client, query, rows, batch_size, "relations")
    finally:
        # Release the driver's connections even if an import step fails.
        client.close()
if __name__ == "__main__":
    # Launch the click CLI only when this file is executed directly, not on import.
    main()