-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcode_executor.py
More file actions
180 lines (150 loc) · 5.94 KB
/
code_executor.py
File metadata and controls
180 lines (150 loc) · 5.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import sys
import io
import os
import glob
import traceback
from contextlib import redirect_stdout, redirect_stderr
from typing import Dict, Any, Optional, List
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class CodeExecutor:
    """Execute generated plotting code and report what it printed and saved.

    NOTE(security): the supplied code is run with ``exec`` and full builtins,
    with no sandboxing of any kind. Only pass code you trust.
    """

    def __init__(self, output_dir: str = "./"):
        # Stored for callers; none of the methods of this class read it.
        self.output_dir = output_dir

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_result() -> Dict[str, Any]:
        """Return the result-dict skeleton shared by the ``execute_*`` methods."""
        return {
            'success': False,
            'output': '',
            'error': '',
            'output_file': None,
            'output_files': [],
            'stderr': '',
        }

    @staticmethod
    def _build_globals(**variables: Any) -> Dict[str, Any]:
        """Build the exec namespace: plotting libraries plus the given data names."""
        import numpy as np
        import seaborn as sns

        namespace = {
            'pd': pd,
            'np': np,
            'plt': plt,
            'matplotlib': matplotlib,
            'sns': sns,
            '__builtins__': __builtins__,
        }
        namespace.update(variables)
        return namespace

    @staticmethod
    def _prepare_code(
        code: str,
        output_filename: str,
        base_filename: Optional[str]
    ) -> "tuple[str, str]":
        """Rewrite output names in *code* and ensure it saves a figure.

        Returns ``(code, target)``, where *target* is the file the figure is
        expected to be written to. It is ``<base_filename>.png`` when
        *base_filename* is given without a ``.png`` suffix, and
        *output_filename* otherwise.
        """
        if base_filename and not base_filename.endswith('.png'):
            target = f'{base_filename}.png'
        else:
            target = output_filename
        if target != 'output.png':
            # NOTE: plain substring replacement. A target containing
            # backslashes or quotes may still need escaping inside the user's
            # own string literal.
            code = code.replace('output.png', target)
        if 'savefig' not in code and 'save(' not in code:
            # Use repr() so quotes/backslashes in the path (e.g. Windows paths)
            # cannot break or alter the generated source line.
            code += f"\nplt.savefig({target!r}, dpi=300, bbox_inches='tight')"
        return code, target

    @staticmethod
    def _snapshot_pngs(directory: str) -> Dict[str, int]:
        """Map each ``*.png`` directly inside *directory* to its mtime in ns."""
        snapshot: Dict[str, int] = {}
        # Escape the directory so glob metacharacters in it are matched literally.
        pattern = os.path.join(glob.escape(directory), '*.png')
        for path in glob.glob(pattern):
            try:
                snapshot[path] = os.stat(path).st_mtime_ns
            except OSError:
                # The file disappeared between glob() and stat(); skip it.
                continue
        return snapshot

    @classmethod
    def _changed_pngs(cls, before: Dict[str, int], directory: str) -> List[str]:
        """Return sorted PNG paths in *directory* that are new or modified since *before*."""
        after = cls._snapshot_pngs(directory)
        return sorted(path for path, mtime in after.items() if before.get(path) != mtime)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def execute_visualization_code(
        self,
        code: str,
        df: pd.DataFrame,
        output_filename: str = "output.png",
        base_filename: Optional[str] = None
    ) -> Dict[str, Any]:
        """Execute visualization code against *df*, supporting multi-figure output.

        The code sees ``df`` plus ``pd``, ``np``, ``plt``, ``matplotlib`` and
        ``sns``. Every literal ``output.png`` in the code is rewritten to the
        target file name. The target is ``<base_filename>.png`` when
        *base_filename* is given without a ``.png`` suffix, and
        *output_filename* otherwise. If the code never calls ``savefig`` or
        ``save(``, a ``plt.savefig`` of the target is appended.

        Args:
            code: Python source to execute.
            df: DataFrame exposed to the code as ``df``.
            output_filename: Target file name when *base_filename* does not apply.
            base_filename: Optional path stem, or a directory. When given, PNGs
                created or modified in its directory are reported.

        Returns:
            A dict with these keys:

            - ``success``
            - ``output``: captured stdout.
            - ``stderr``: captured stderr.
            - ``error``: message plus traceback on failure.
            - ``output_files``: every detected PNG.
            - ``output_file``: the single detected PNG, ``None`` when several
              were detected, or the target when none were.
        """
        result = self._new_result()
        exec_globals = self._build_globals(df=df)
        code, target = self._prepare_code(code, output_filename, base_filename)

        # Directory watched for PNGs produced by this run (only with base_filename).
        watch_dir: Optional[str] = None
        before: Dict[str, int] = {}
        if base_filename:
            watch_dir = (base_filename if os.path.isdir(base_filename)
                         else (os.path.dirname(base_filename) or '.'))
            before = self._snapshot_pngs(watch_dir)

        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()
        try:
            # Start from a clean pyplot state so earlier figures don't leak in.
            plt.clf()
            plt.close('all')
            # Make sure the default savefig target's directory exists.
            target_dir = os.path.dirname(target)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
                exec(code, exec_globals)
            result['success'] = True
            result['output'] = stdout_capture.getvalue()

            new_files = self._changed_pngs(before, watch_dir) if watch_dir else []
            if new_files:
                result['output_files'] = new_files
                result['output_file'] = new_files[0] if len(new_files) == 1 else None
            else:
                result['output_file'] = target
        except Exception as e:
            result['success'] = False
            result['error'] = f"{type(e).__name__}: {str(e)}\n\n{traceback.format_exc()}"
            result['output'] = stdout_capture.getvalue()
        finally:
            result['stderr'] = stderr_capture.getvalue()
            plt.close('all')
        return result

    def execute_combined_visualization(
        self,
        code: str,
        data_dict: Dict[str, pd.DataFrame],
        output_dir: str,
        base_filename: str = "combined"
    ) -> Dict[str, Any]:
        """Execute visualization code that combines several datasets.

        The code sees ``data_dict`` plus ``pd``, ``np``, ``plt``,
        ``matplotlib`` and ``sns``. It runs with the working directory
        temporarily switched to *output_dir*, so relative ``savefig`` paths
        land there. The original directory is always restored.

        Args:
            code: Python source to execute.
            data_dict: Mapping exposed to the code as ``data_dict``.
            output_dir: Directory the code runs in and that is scanned for PNGs.
            base_filename: Accepted for interface compatibility; not used here.

        Returns:
            The same result-dict shape as ``execute_visualization_code``. When
            no PNG was created or modified, ``output_file`` falls back to
            ``<output_dir>/output.png`` if that file exists.

        Note:
            ``os.chdir`` changes the working directory of the whole process.
            This method is therefore not safe to call concurrently from
            several threads.
        """
        result = self._new_result()
        exec_globals = self._build_globals(data_dict=data_dict)
        before = self._snapshot_pngs(output_dir)

        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()
        try:
            plt.clf()
            plt.close('all')
            original_dir = os.getcwd()
            os.chdir(output_dir)
            try:
                with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
                    exec(code, exec_globals)
            finally:
                # Restore the cwd even when the executed code raises.
                os.chdir(original_dir)
            result['success'] = True
            result['output'] = stdout_capture.getvalue()

            new_files = self._changed_pngs(before, output_dir)
            if new_files:
                result['output_files'] = new_files
                result['output_file'] = new_files[0] if len(new_files) == 1 else None
            else:
                # Nothing detected: the code may have used the default name.
                default_output = os.path.join(output_dir, 'output.png')
                if os.path.exists(default_output):
                    result['output_file'] = default_output
        except Exception as e:
            result['success'] = False
            result['error'] = f"{type(e).__name__}: {str(e)}\n\n{traceback.format_exc()}"
            result['output'] = stdout_capture.getvalue()
        finally:
            result['stderr'] = stderr_capture.getvalue()
            plt.close('all')
        return result

    def validate_code(self, code: str) -> Dict[str, Any]:
        """Check that *code* compiles without executing it.

        Returns:
            ``{'valid': bool, 'error': str}``. ``error`` is empty when the
            code is valid.
        """
        result = {'valid': False, 'error': ''}
        try:
            compile(code, '<string>', 'exec')
            result['valid'] = True
        except SyntaxError as e:
            result['error'] = f"语法错误: {str(e)}"
        except Exception as e:
            # e.g. ValueError for source containing null bytes on older Pythons.
            result['error'] = f"编译错误: {str(e)}"
        return result