|
5 | 5 | #include "deepx/tensorfunc/reduce.hpp" |
6 | 6 | #include "deepx/tensorfunc/broadcast.hpp" |
7 | 7 | #include "deepx/tensorfunc/compare.hpp" |
| 8 | +#include "stdutil/num.hpp" |
8 | 9 |
|
9 | 10 | namespace deepx::op |
10 | 11 | { |
@@ -82,14 +83,24 @@ namespace deepx::op |
82 | 83 |
|
83 | 84 | void forward(mem::Mem &mem) override{ |
84 | 85 | auto A=mem.gettensor<T>(this->args[0]); |
85 | | - auto b=mem.getarg<T>(this->args[1]); |
| 86 | + T b; |
| 87 | + if (!is_float(this->args[1])){ |
| 88 | + b=mem.getarg<T>(this->args[1]); |
| 89 | + }else{ |
| 90 | + b=T(atof(this->args[1].c_str())); |
| 91 | + } |
86 | 92 | auto output=mem.gettensor<T>(this->returns[0]); |
87 | 93 | deepx::tensorfunc::max(*A, b, *output); |
88 | 94 | } |
89 | 95 |
|
90 | 96 | void backward(mem::Mem &mem) override{ |
91 | 97 | auto A=mem.gettensor<T>(this->args[0]); |
92 | | - auto b=mem.getarg<T>(this->args[1]); |
| 98 | + T b; |
| 99 | + if (!is_float(this->args[1])){ |
| 100 | + b=mem.getarg<T>(this->args[1]); |
| 101 | + }else{ |
| 102 | + b=T(atof(this->args[1].c_str())); |
| 103 | + } |
93 | 104 | auto A_grad=mem.gettensor<T>(this->args_grad [0]); |
94 | 105 | auto output_grad=mem.gettensor<T>(this->returns_grad[0]); |
95 | 106 | deepx::tensorfunc::max_grad(*A, b, *A_grad, *output_grad); |
@@ -139,14 +150,24 @@ namespace deepx::op |
139 | 150 | } |
140 | 151 | void forward(mem::Mem &mem) override{ |
141 | 152 | auto A=mem.gettensor<T>(this->args[0]); |
142 | | - auto b=mem.getarg<T>(this->args[1]); |
| 153 | + T b; |
| 154 | + if (!is_float(this->args[1])){ |
| 155 | + b=mem.getarg<T>(this->args[1]); |
| 156 | + }else{ |
| 157 | + b=T(atof(this->args[1].c_str())); |
| 158 | + } |
143 | 159 | auto output=mem.gettensor<T>(this->returns[0]); |
144 | 160 | deepx::tensorfunc::min(*A, b, *output); |
145 | 161 | } |
146 | 162 |
|
147 | 163 | void backward(mem::Mem &mem) override{ |
148 | 164 | auto A=mem.gettensor<T>(this->args[0]); |
149 | | - auto b=mem.getarg<T>(this->args[1]); |
| 165 | + T b; |
| 166 | + if (!is_float(this->args[1])){ |
| 167 | + b=mem.getarg<T>(this->args[1]); |
| 168 | + }else{ |
| 169 | + b=T(atof(this->args[1].c_str())); |
| 170 | + } |
150 | 171 | auto A_grad=mem.gettensor<T>(this->args_grad[0]); |
151 | 172 | auto output_grad=mem.gettensor<T>(this->returns_grad[0]); |
152 | 173 | deepx::tensorfunc::min_grad(*A, b, *A_grad, *output_grad); |
|
0 commit comments