1- #include " op.hpp"
21
2+ #include < iomanip>
3+ #include < sstream>
4+ #include < ctime>
5+
6+ #include " op.hpp"
37namespace deepx ::op
48{
5- // 与deepx/front/py/deepx/nn/deepxir.py对应
9+ // 与deepx/front/py/deepx/nn/deepxir.py对应
610
7- // 新格式示例:mul@float32 a(a_grad) b(b_grad) -> a(a_grad) //id=1 create_time=1714512000 send_time=1714512000 recv_time=1714512000
8- void Op::load (const string &input) {
11+ // 新格式示例:mul@float32 a(a_grad) b(b_grad) -> a(a_grad) //id=1 create_time=1714512000 send_time=1714512000 recv_time=1714512000
12+ void Op::load (const string &input)
13+ {
914 // 分割元数据部分
1015 size_t meta_pos = input.find (" //" );
1116 string body = input.substr (0 , meta_pos);
1217 string meta = (meta_pos != string::npos) ? input.substr (meta_pos + 2 ) : " " ;
1318
1419 // 解析操作主体
1520 size_t arrow_pos = body.find (" ->" );
16- if (arrow_pos == string::npos) {
21+ if (arrow_pos == string::npos)
22+ {
1723 arrow_pos = body.find (" <-" );
18- if (arrow_pos != string::npos) {
19- grad = true ; // 反向传播标记
24+ if (arrow_pos != string::npos)
25+ {
26+ grad = true ; // 反向传播标记
2027 }
2128 }
22-
23- if (arrow_pos == string::npos) {
29+
30+ if (arrow_pos == string::npos)
31+ {
2432 throw runtime_error (" Invalid IR format: missing arrow" );
2533 }
2634
@@ -29,23 +37,32 @@ namespace deepx::op
2937
3038 // 解析操作名和数据类型
3139 size_t at_pos = head.find (' @' );
32- if (at_pos != string::npos) {
40+ if (at_pos != string::npos)
41+ {
3342 name = head.substr (0 , at_pos);
3443 size_t space_pos = head.find (' ' , at_pos);
35- if (space_pos != string::npos) {
44+ if (space_pos != string::npos)
45+ {
3646 dtype = head.substr (at_pos + 1 , space_pos - at_pos - 1 );
3747 head = head.substr (space_pos + 1 );
38- } else {
48+ }
49+ else
50+ {
3951 dtype = head.substr (at_pos + 1 );
4052 head.clear ();
4153 }
42- } else {
54+ }
55+ else
56+ {
4357 size_t space_pos = head.find (' ' );
44- if (space_pos != string::npos) {
58+ if (space_pos != string::npos)
59+ {
4560 name = head.substr (0 , space_pos);
4661 head = head.substr (space_pos + 1 );
4762 dtype = " any" ;
48- } else {
63+ }
64+ else
65+ {
4966 name = head;
5067 head.clear ();
5168 dtype = " any" ;
@@ -55,53 +72,68 @@ namespace deepx::op
5572 // 解析输入参数
5673 stringstream head_ss (head);
5774 string token;
58- while (head_ss >> token) {
59- size_t bracket = token.find (' (' );
60- if (bracket != string::npos && token.back () == ' )' ) {
75+ while (head_ss >> token)
76+ {
77+ size_t bracket = token.find (' (' );
78+ if (bracket != string::npos && token.back () == ' )' )
79+ {
6180 args.push_back (token.substr (0 , bracket));
6281 args_grad.push_back (token.substr (bracket + 1 , token.size () - bracket - 2 ));
63- } else {
82+ }
83+ else
84+ {
6485 args.push_back (token);
65- args_grad.emplace_back (" " ); // 保持梯度与参数数量一致
86+ args_grad.emplace_back (" " ); // 保持梯度与参数数量一致
6687 }
6788 }
6889
6990 // 解析输出参数
7091 stringstream tail_ss (tail);
71- while (tail_ss >> token) {
92+ while (tail_ss >> token)
93+ {
7294 size_t bracket = token.find (' (' );
73- if (bracket != string::npos && token.back () == ' )' ) {
95+ if (bracket != string::npos && token.back () == ' )' )
96+ {
7497 returns.push_back (token.substr (0 , bracket));
7598 returns_grad.push_back (token.substr (bracket + 1 , token.size () - bracket - 2 ));
76- } else {
99+ }
100+ else
101+ {
77102 returns.push_back (token);
78- returns_grad.emplace_back (" " ); // 保持梯度与参数数量一致
103+ returns_grad.emplace_back (" " ); // 保持梯度与参数数量一致
79104 }
80105 }
81106
82107 // 解析元数据
83- if (!meta.empty ()) {
108+ if (!meta.empty ())
109+ {
84110 stringstream meta_ss (meta);
85111 string key, value;
86- while (meta_ss >> key) {
112+ while (meta_ss >> key)
113+ {
87114 size_t eq_pos = key.find (' =' );
88- if (eq_pos != string::npos) {
115+ if (eq_pos != string::npos)
116+ {
89117 value = key.substr (eq_pos + 1 );
90118 key = key.substr (0 , eq_pos);
91-
92- if (key == " id" ) {
119+
120+ if (key == " id" )
121+ {
93122 id = stoi (value);
94- } else if (key == " created_at" ) {
123+ }
124+ else if (key == " created_at" )
125+ {
95126 created_at = system_clock::from_time_t (stod (value));
96- } else if (key == " sent_at" ) {
127+ }
128+ else if (key == " sent_at" )
129+ {
97130 sent_at = system_clock::from_time_t (stod (value));
98131 }
99132 }
100133 }
101134 }
102135 }
103136
104-
105137 void Op::init (const string &opname,
106138 const string &dtype,
107139 const vector<string> &args,
@@ -148,4 +180,56 @@ namespace deepx::op
148180 }
149181 }
150182 }
183+ static std::string format_time (const system_clock::time_point &tp)
184+ {
185+ using namespace std ::chrono;
186+ auto ms = duration_cast<microseconds>(tp.time_since_epoch ());
187+ auto sec = duration_cast<seconds>(ms);
188+ ms -= sec;
189+
190+ std::time_t t = sec.count ();
191+ std::tm tm;
192+ localtime_r (&t, &tm); // 线程安全版本
193+
194+ std::ostringstream oss;
195+ oss << std::put_time (&tm, " %Y-%m-%d %H:%M:%S" )
196+ << ' .' << std::setfill (' 0' ) << std::setw (6 ) << ms.count ();
197+ return oss.str ();
198+ }
199+ std::string Op::to_string (bool show_extra) const
200+ {
201+ std::stringstream ss;
202+ ss << name << " @" << dtype;
203+ for (size_t i = 0 ; i < args.size (); ++i)
204+ {
205+ if (grad)
206+ {
207+ ss << " " << args[i] << " (:+)" << args_grad[i];
208+ }
209+ else
210+ {
211+ ss << " " << args[i];
212+ }
213+ }
214+ ss << " ->" ;
215+ for (size_t i = 0 ; i < returns.size (); ++i)
216+ {
217+ if (grad)
218+ {
219+ ss << " " << returns[i] << " (:+)" << returns_grad[i];
220+ }
221+ else
222+ {
223+ ss << " " << returns[i];
224+ }
225+ }
226+ if (show_extra)
227+ {
228+ ss << " //id=" << id
229+ << " created_at=" << format_time (created_at)
230+ << " sent_at=" << format_time (sent_at)
231+ << " recv_at=" << format_time (recv_at);
232+ }
233+ return ss.str ();
234+ }
151235}
0 commit comments