Skip to content

Commit e41874b

Browse files
committed
log,pow:修复
1 parent c03ed9b commit e41874b

21 files changed

Lines changed: 1700 additions & 205 deletions

File tree

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -7,6 +7,8 @@
77
| sum | float32, float64 | T2 = sum(T1, dims=[1,2]) | sum@float32 T1 1 2 -> T2 |
88
| matmul | float32, float64 | T3 = T1 @ T2 | matmul@float32 T1 T2 -> T3 |
99
| concat | float32, float64 | T3 = concat([T1, T2], axis=3) | concat@float32 T1 T2 3 -> T3 |
10+
| pow_scalar | float32, float64 | T2 = T1 ^ 2.0 | pow_scalar@float32 T1 2.0 -> T2 |
11+
| pow | float32, float64 | T3 = T1 ^ T2 | pow@float32 T1 T2 -> T3 |
1012
| max_scalar | float32, float64 | T2 = max(T1, 0.0) | max_scalar@float32 T1 0.0 -> T2 |
1113
| exp | float32, float64 | T2 = exp(T1) | exp@float32 T1 -> T2 |
1214
| min_scalar | float32, float64 | B= min(A, 1.0) | min_scalar@float32 A 1.0 -> B |

excuter/op-mem-ompsimd/src/client/main.cpp

Lines changed: 16 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,6 @@
11
#include <mutex>
22
#include <thread>
3+
#include <cstdlib>
34

45
#include <deepx/tensorfunc/init.hpp>
56
#include "deepx/op/op.hpp"
@@ -10,6 +11,12 @@
1011
using namespace deepx::tensorfunc;
1112
using namespace deepx::mem;
1213

14+
// 从环境变量读取IR日志配置
15+
bool kIrLog = []() {
16+
const char* env = std::getenv("DEEPX_IR_LOG");
17+
return env != nullptr && (strcmp(env, "1") == 0 || strcasecmp(env, "true") == 0);
18+
}();
19+
1320
int main()
1421
{
1522
Mem mem;
@@ -42,17 +49,22 @@ int main()
4249
if (!tasks.empty()) {
4350
deepx::op::Op op = tasks.front();
4451
tasks.pop();
45-
cout << "~" << op.to_string()<< endl;
46-
std::string resp=to_string(op.id);
52+
53+
// 根据kIrLog标志决定是否打印op信息
54+
if (kIrLog) {
55+
cout << "~" << op.to_string() << endl;
56+
}
57+
58+
std::string resp = to_string(op.id);
4759
resp+="recv_at:";
4860
resp+=to_string(op.recv_at.time_since_epoch().count());
4961
if (opfactory.ops.find(op.name)==opfactory.ops.end()){
50-
cout<<"<op> "<<op.name<<" not found"<<endl;
62+
cerr<<"<op> "<<op.name<<" not found"<<endl;
5163
resp+="error op not found";
5264
}
5365
auto &type_map = opfactory.ops.find(op.name)->second;
5466
if (type_map.find(op.dtype)==type_map.end()){
55-
cout<<"<op>"<<op.name<<" "<<op.dtype<<" not found"<<endl;
67+
cerr<<"<op>"<<op.name<<" "<<op.dtype<<" not found"<<endl;
5668
resp+="error dtype not found";
5769
}
5870
auto src = type_map.find(op.dtype)->second;

excuter/op-mem-ompsimd/src/deepx/op/elementwise.hpp

Lines changed: 16 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -452,12 +452,16 @@ namespace deepx::op
452452
class Pow : public Op
453453
{
454454
public:
455+
Pow(){
456+
this->init("pow",deepx::dtype<T>::name(), {}, {}, false, {}, {});
457+
}
455458
Pow(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
456459
this->init("pow",deepx::dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
457460
}
458461
Pow(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
459462
this->init("pow",deepx::dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
460463
}
464+
//已验证,2025-03-06,lipeng
461465
void forward(mem::Mem &mem) override
462466
{
463467
auto a = mem.gettensor<T>(this->args[0]).get();
@@ -515,27 +519,27 @@ namespace deepx::op
515519
}
516520
void forward(mem::Mem &mem) override
517521
{
518-
auto a = mem.gettensor<T>(this->args[0]).get();
519-
auto b = mem.getarg<T>(this->args[1]);
520-
auto c = mem.gettensor<T>(this->returns[0]);
521-
deepx::tensorfunc::pow(*a, b, *c);
522+
auto A = mem.gettensor<T>(this->args[0]).get();
523+
auto b = this->getarg<T>(1,mem);
524+
auto C = mem.gettensor<T>(this->returns[0]);
525+
deepx::tensorfunc::pow_scalar(*A, b, *C);
522526
}
523527
void backward(mem::Mem &mem) override
524528
{
525529
// 需要用到前向传播的输入、输出和标量指数
526-
auto a = mem.gettensor<T>(this->args[0]).get();
527-
auto b = mem.getarg<T>(this->args[1]); // 标量指数
528-
auto c = mem.gettensor<T>(this->returns[0]).get(); // c = a^b
529-
auto a_grad = mem.gettensor<T>(this->args_grad[0]).get();
530-
auto c_grad = mem.gettensor<T>(this->returns_grad[0]).get();
530+
auto A = mem.gettensor<T>(this->args[0]).get();
531+
auto b = this->getarg<T>(1,mem); // 标量指数
532+
auto C = mem.gettensor<T>(this->returns[0]).get(); // c = a^b
533+
auto A_grad = mem.gettensor<T>(this->args_grad[0]).get();
534+
auto C_grad = mem.gettensor<T>(this->returns_grad[0]).get();
531535

532536
// 标量幂运算的反向传播:
533537
// 对于 c = a^b,其中b是标量
534538
// ∂L/∂a = ∂L/∂c * ∂c/∂a = c_grad * b * a^(b-1)
535539
// = c_grad * b * (c/a) 【因为c=a^b,所以a^(b-1)=c/a】
536-
deepx::tensorfunc::div(*c, *a, *a_grad); // temp = c/a
537-
deepx::tensorfunc::mul(*a_grad, b, *a_grad); // temp = b * (c/a)
538-
deepx::tensorfunc::mul(*a_grad, *c_grad, *a_grad); // a_grad = c_grad * b * (c/a)
540+
deepx::tensorfunc::div(*C, *A, *A_grad); // temp = c/a
541+
deepx::tensorfunc::mul(*A_grad, b, *A_grad); // temp = b * (c/a)
542+
deepx::tensorfunc::mul(*A_grad, *C_grad, *A_grad); // a_grad = c_grad * b * (c/a)
539543
// 标量b不需要计算梯度
540544
}
541545
void setexample() override {

excuter/op-mem-ompsimd/src/deepx/op/opfactory.cpp

Lines changed: 10 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -109,6 +109,14 @@ namespace deepx::op
109109
opfactory.add_op(Exp<float>());
110110
opfactory.add_op(Exp<double>());
111111
}
112+
void register_pow(OpFactory &opfactory){
113+
opfactory.add_op(Pow<float>());
114+
opfactory.add_op(Pow<double>());
115+
}
116+
void register_pow_scalar(OpFactory &opfactory){
117+
opfactory.add_op(Pow_scalar<float>());
118+
opfactory.add_op(Pow_scalar<double>());
119+
}
112120
void register_elementwise_op(OpFactory &opfactory){
113121
register_add(opfactory);
114122
register_add_scalar(opfactory);
@@ -120,6 +128,8 @@ namespace deepx::op
120128
register_rdiv_scalar(opfactory);
121129
register_sqrt(opfactory);
122130
register_exp(opfactory);
131+
register_pow(opfactory);
132+
register_pow_scalar(opfactory);
123133
}
124134
//concat
125135

excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise.hpp

Lines changed: 27 additions & 55 deletions
Original file line number · Diff line number · Diff line change
@@ -810,78 +810,51 @@ namespace deepx::tensorfunc
810810
}
811811
}
812812

813+
//C=A^value
814+
// highway 不支持POW
813815
template <typename T>
814-
void pow(const Tensor<T> &input, const T value, Tensor<T> &output)
816+
void pow_scalar(const Tensor<T> &input, const T value, Tensor<T> &output)
815817
{
816818
if (input.shape == output.shape)
817819
{
818-
output.shape.rangeParallel(output.shape.dim - 1, [&input, &output, &value](int i)
820+
output.shape.rangeParallel(output.shape.dim , [&input, &output, &value](int i)
819821
{
820-
int shape_last=output.shape[-1];
821-
const ScalableTag<T> tag;
822-
const size_t lanes = Lanes(tag);
823-
size_t j=0;
824-
825-
// 1. 处理前置未对齐部分
826-
while (j < shape_last && !IsAligned(tag,input.data + i + j)) {
827-
output.data[i+j] = std::pow(input.data[i+j], value);
828-
++j;
829-
}
830-
831-
// 2. 处理中间对齐部分
832-
size_t aligned_end=shape_last-(shape_last%lanes);
833-
for (; j+lanes<=aligned_end; j += lanes )
834-
{
835-
auto vec = Load(tag, input.data + i + j);
836-
auto scalar = Set(tag, value);
837-
auto vec_result = Pow(vec, scalar);
838-
Store(vec_result, tag, output.data + i + j);
839-
}
822+
output.data[i] = std::pow(input.data[i], value);
823+
});
824+
}
825+
else
826+
{
827+
throw std::invalid_argument("shape mismatch");
828+
}
829+
}
840830

841-
// 3. 处理尾部剩余元素
842-
for (;j<shape_last;j++)
843-
{
844-
output.data[i+j] = std::pow(input.data[i+j], value);
845-
} });
831+
//C=A^B
832+
template <typename T>
833+
void pow(const Tensor<T> &A, Tensor<T> &B,Tensor<T> &C)
834+
{
835+
if (A.shape == B.shape && A.shape == C.shape)
836+
{
837+
C.shape.rangeParallel(C.shape.dim , [&A, &B, &C](int i){
838+
C.data[i] = std::pow(A.data[i], B.data[i]);
839+
});
846840
}
847841
else
848842
{
849843
throw std::invalid_argument("shape mismatch");
850844
}
851845
}
852846

847+
//hwy库没有log函数,所以只能用std::log
848+
853849
template <typename T>
854850
void log(const Tensor<T> &input, Tensor<T> &output)
855851
{
856852
if (input.shape == output.shape)
857853
{
858-
output.shape.rangeParallel(output.shape.dim - 1, [&input, &output](int i)
859-
{
860-
int shape_last=output.shape[-1];
861-
const ScalableTag<T> tag;
862-
const size_t lanes = Lanes(tag);
863-
size_t j=0;
864-
865-
// 1. 处理前置未对齐部分
866-
while (j < shape_last && !IsAligned(tag,input.data + i + j)) {
867-
output.data[i+j] = std::log(input.data[i+j]);
868-
++j;
869-
}
870-
871-
// 2. 处理中间对齐部分
872-
size_t aligned_end=shape_last-(shape_last%lanes);
873-
for (; j+lanes<=aligned_end; j += lanes )
874-
{
875-
auto vec = Load(tag, input.data + i + j);
876-
auto vec_result = Log(vec);
877-
Store(vec_result, tag, output.data + i + j);
878-
}
879-
880-
// 3. 处理尾部剩余元素
881-
for (;j<shape_last;j++)
882-
{
883-
output.data[i+j] = std::log(input.data[i+j]);
884-
}
854+
output.shape.rangeParallel(output.shape.dim , [&input, &output](int i){
855+
856+
output.data[i] = std::log(input.data[i]);
857+
885858
});
886859
}
887860
else
@@ -898,7 +871,6 @@ namespace deepx::tensorfunc
898871
{
899872
output.shape.rangeParallel(output.shape.dim , [&input, &output](int i)
900873
{
901-
902874
output.data[i] = std::exp(input.data[i]);
903875

904876
});

front/py/deepx/nn/functional/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -5,7 +5,7 @@
55
from .init import *
66
from .reduce import reduce_max,reduce_min,sum,prod,mean
77
from .transpose import transpose,reshape
8-
from .activite import relu,sigmoid,swish,swiglu
8+
from .activite import relu,sigmoid,swish
99

1010
__all__ = [
1111
"newtensor",
@@ -15,5 +15,5 @@
1515
"matmul",
1616
"max","min","sum","prod","mean",
1717
"transpose","reshape",
18-
"relu","sigmoid","swish","swiglu",
18+
"relu","sigmoid","swish",
1919
]

0 commit comments

Comments
 (0)