@@ -810,78 +810,51 @@ namespace deepx::tensorfunc
810810 }
811811 }
812812
813+ // C=A^value
814+ // highway 不支持POW
813815 template <typename T>
814- void pow (const Tensor<T> &input, const T value, Tensor<T> &output)
816+ void pow_scalar (const Tensor<T> &input, const T value, Tensor<T> &output)
815817 {
816818 if (input.shape == output.shape )
817819 {
818- output.shape .rangeParallel (output.shape .dim - 1 , [&input, &output, &value](int i)
820+ output.shape .rangeParallel (output.shape .dim , [&input, &output, &value](int i)
819821 {
820- int shape_last=output.shape [-1 ];
821- const ScalableTag<T> tag;
822- const size_t lanes = Lanes (tag);
823- size_t j=0 ;
824-
825- // 1. 处理前置未对齐部分
826- while (j < shape_last && !IsAligned (tag,input.data + i + j)) {
827- output.data [i+j] = std::pow (input.data [i+j], value);
828- ++j;
829- }
830-
831- // 2. 处理中间对齐部分
832- size_t aligned_end=shape_last-(shape_last%lanes);
833- for (; j+lanes<=aligned_end; j += lanes )
834- {
835- auto vec = Load (tag, input.data + i + j);
836- auto scalar = Set (tag, value);
837- auto vec_result = Pow (vec, scalar);
838- Store (vec_result, tag, output.data + i + j);
839- }
822+ output.data [i] = std::pow (input.data [i], value);
823+ });
824+ }
825+ else
826+ {
827+ throw std::invalid_argument (" shape mismatch" );
828+ }
829+ }
840830
841- // 3. 处理尾部剩余元素
842- for (;j<shape_last;j++)
843- {
844- output.data [i+j] = std::pow (input.data [i+j], value);
845- } });
831+ // C=A^B
832+ template <typename T>
833+ void pow (const Tensor<T> &A, Tensor<T> &B,Tensor<T> &C)
834+ {
835+ if (A.shape == B.shape && A.shape == C.shape )
836+ {
837+ C.shape .rangeParallel (C.shape .dim , [&A, &B, &C](int i){
838+ C.data [i] = std::pow (A.data [i], B.data [i]);
839+ });
846840 }
847841 else
848842 {
849843 throw std::invalid_argument (" shape mismatch" );
850844 }
851845 }
852846
847+ // hwy库没有log函数,所以只能用std::log
848+
853849 template <typename T>
854850 void log (const Tensor<T> &input, Tensor<T> &output)
855851 {
856852 if (input.shape == output.shape )
857853 {
858- output.shape .rangeParallel (output.shape .dim - 1 , [&input, &output](int i)
859- {
860- int shape_last=output.shape [-1 ];
861- const ScalableTag<T> tag;
862- const size_t lanes = Lanes (tag);
863- size_t j=0 ;
864-
865- // 1. 处理前置未对齐部分
866- while (j < shape_last && !IsAligned (tag,input.data + i + j)) {
867- output.data [i+j] = std::log (input.data [i+j]);
868- ++j;
869- }
870-
871- // 2. 处理中间对齐部分
872- size_t aligned_end=shape_last-(shape_last%lanes);
873- for (; j+lanes<=aligned_end; j += lanes )
874- {
875- auto vec = Load (tag, input.data + i + j);
876- auto vec_result = Log (vec);
877- Store (vec_result, tag, output.data + i + j);
878- }
879-
880- // 3. 处理尾部剩余元素
881- for (;j<shape_last;j++)
882- {
883- output.data [i+j] = std::log (input.data [i+j]);
884- }
854+ output.shape .rangeParallel (output.shape .dim , [&input, &output](int i){
855+
856+ output.data [i] = std::log (input.data [i]);
857+
885858 });
886859 }
887860 else
@@ -898,7 +871,6 @@ namespace deepx::tensorfunc
898871 {
899872 output.shape .rangeParallel (output.shape .dim , [&input, &output](int i)
900873 {
901-
902874 output.data [i] = std::exp (input.data [i]);
903875
904876 });
0 commit comments