-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_example.py
More file actions
executable file
·82 lines (62 loc) · 2.89 KB
/
run_example.py
File metadata and controls
executable file
·82 lines (62 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import torch
from lib.eval.eval_synth import EvalSynth
from lib.synthesizer.synthesizer_typed import TopLevelSynthesizerTyped
from lib.utils.csv_utils import read_csv_to_dict, List, Dict
from lib.eval.benchmark import Benchmark
from lib.utils.benchmark_utils import create_benchmark
from lib.utils.data_utils import get_data
from lib.program import Program
from lib.neural_parser.parser import QueryNeuralParser
from lib.synthesizer.program_synthesizer_typed import ProgramSynthesizerTyped
from parse_args import args
from lib.grammar.typed_cfg import TypedCFG
"""
This is the user-query handler script for the user study; it is intended to connect to the Flask front end.
"""
class UserQueryHandler:
    """Handle natural-language user queries for the user study.

    Wraps the typed top-level synthesizer and returns synthesized programs
    as Vega-Lite spec dicts.
    """

    def __init__(self):
        # All project resources live under datavis/; restore the cwd even if
        # setup fails partway through (the original chdir pair was not
        # exception-safe and could leave the process in the wrong directory).
        os.chdir("datavis")
        try:
            # Select the requested GPU when available, otherwise fall back to CPU.
            if torch.cuda.is_available() and args.gpu is not None:
                args.device = torch.device('cuda:' + args.gpu)
            else:
                args.device = 'cpu'
            grammar = TypedCFG()
            self.synth = TopLevelSynthesizerTyped(
                QueryNeuralParser(args), ProgramSynthesizerTyped(grammar), timeout=300)
            data_dir = "eval/data"
            # Template benchmark; nl/bname are filled in per query by run_user_query.
            self.benchmark = Benchmark(dataname="cars", bname="user", nl="", benchmark_set="chi21")
            # Only the parsed data table is needed here; the synthesis
            # constraint and falx-format data are unused.
            data, _data_constraint, _falx_data = get_data(
                data_dir, self.benchmark, mode='synth', generate_synthesis_constraint=True)
            self.benchmark.data = data
            # Gives each user query a unique benchmark name.
            self.counter = 0
        finally:
            os.chdir("..")

    def run_user_query_benchmark(self, query, dataset_name):
        """Evaluate ``query`` against the pre-defined "chi21" benchmark set.

        Returns a list of Vega-Lite spec dicts for the synthesized programs.
        Raises ValueError when ``query`` matches no known benchmark.
        (``dataset_name`` is currently unused; kept for interface compatibility.)
        """
        os.chdir("datavis")
        try:
            benchmark_set = "chi21"
            eval_engine = EvalSynth()
            benchmarks = read_csv_to_dict(eval_engine.get_benchmark_path(benchmark_set))
            benchmark = None
            for b in benchmarks:
                if b["query"] == query:
                    benchmark = create_benchmark(b, benchmark_set)
                    break  # first match wins; benchmark queries are assumed unique
            if benchmark is None:
                # Previously an unmatched query fell through to an unbound
                # local and crashed with UnboundLocalError; fail clearly instead.
                raise ValueError("no benchmark found for query: " + repr(query))
            print(query)
            res = eval_engine.eval(benchmark)
            synthesized_prog_vlspec: List[Dict] = [prog.to_vega_lite() for prog in res.output]
            return synthesized_prog_vlspec
        finally:
            os.chdir("..")

    def run_user_query(self, query, k):
        """Synthesize up to ``k`` programs for the natural-language ``query``.

        Returns a list of Vega-Lite spec dicts.
        """
        os.chdir("datavis")
        try:
            self.benchmark.nl = query
            # Unique per-query name so synthesis outputs do not collide.
            self.benchmark.bname = "user-" + str(self.counter)
            self.counter += 1
            res: List[Program] = self.synth.synthesize(self.benchmark, k=k)
            return [prog.to_vega_lite() for prog in res]
        finally:
            os.chdir("..")
if __name__ == '__main__':
    # Smoke test: push two sample queries through the synthesizer and dump
    # the resulting Vega-Lite specs.
    handler = UserQueryHandler()
    demo_queries = [
        ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~1st~~~~~~~~~~~~~~~~~~~~~~~",
         'how does MPG compare to displacement, broken out by region?'),
        ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~2nd~~~~~~~~~~~~~~~~~~~~~~~",
         'show the relationship between acceleration and cylinders'),
    ]
    for banner, nl_query in demo_queries:
        specs = handler.run_user_query(nl_query, 10)
        print(banner)
        print(specs)
    # res = run_user_query_benchmark('how does MPG compare to displacement, broken out by region?', "Cars")