diff --git a/src/backend/linalg_internal_cpu/Trace_internal.cpp b/src/backend/linalg_internal_cpu/Trace_internal.cpp index 9a047f492..c18423196 100644 --- a/src/backend/linalg_internal_cpu/Trace_internal.cpp +++ b/src/backend/linalg_internal_cpu/Trace_internal.cpp @@ -26,7 +26,7 @@ namespace cytnx { const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { - cytnx::UniTensor I_UT = cytnx::UniTensor(eye(Ndiag, Tn.dtype()), false, -1); + cytnx::UniTensor I_UT = cytnx::UniTensor::eye(Ndiag, {}, true, Tn.dtype(), Tn.device()); UniTensor UTn = UniTensor(Tn, false, 2); I_UT.set_labels({UTn._impl->_labels[ax1], UTn._impl->_labels[ax2]}); @@ -50,9 +50,8 @@ namespace cytnx { // } } - // TODO: remove Nomp parameter void Trace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -65,7 +64,7 @@ namespace cytnx { } void Trace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -78,7 +77,7 @@ namespace cytnx { } void Trace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -91,7 +90,7 @@ namespace cytnx { } void Trace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -104,7 +103,7 @@ namespace cytnx { } void Trace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -117,7 +116,7 @@ namespace cytnx { } void Trace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -130,7 +129,7 @@ namespace cytnx { } void Trace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -143,7 +142,7 @@ namespace cytnx { } void Trace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -156,7 +155,7 @@ namespace cytnx { } void Trace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -169,7 +168,7 @@ namespace cytnx { } void Trace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, @@ -182,7 +181,7 @@ namespace cytnx { } void Trace_internal_b(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, diff --git a/src/backend/linalg_internal_cpu/Trace_internal.hpp b/src/backend/linalg_internal_cpu/Trace_internal.hpp index 64375a423..43a7150bf 100644 --- a/src/backend/linalg_internal_cpu/Trace_internal.hpp +++ b/src/backend/linalg_internal_cpu/Trace_internal.hpp @@ -13,21 +13,21 @@ namespace cytnx { namespace linalg_internal { void Trace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2); void Trace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2); void Trace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -35,7 +35,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -43,7 +43,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -51,7 +51,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -59,7 +59,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -67,7 +67,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -75,7 +75,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -83,7 +83,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -91,7 +91,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void Trace_internal_b(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, diff --git a/src/backend/linalg_internal_gpu/CMakeLists.txt b/src/backend/linalg_internal_gpu/CMakeLists.txt index c3ca5f0ca..4ef039925 100644 --- a/src/backend/linalg_internal_gpu/CMakeLists.txt +++ b/src/backend/linalg_internal_gpu/CMakeLists.txt @@ -62,6 +62,7 @@ target_sources_local(cytnx cuMod_internal.cu cuDet_internal.cu cuSum_internal.cu + cuTrace_internal.cu cuMaxMin_internal.cu cuTensordot_internal.cu cuQuantumGeSvd_internal.cu diff --git a/src/backend/linalg_internal_gpu/cuTrace_internal.cu b/src/backend/linalg_internal_gpu/cuTrace_internal.cu index 878d840f4..bbdcc9752 100644 --- a/src/backend/linalg_internal_gpu/cuTrace_internal.cu +++ b/src/backend/linalg_internal_gpu/cuTrace_internal.cu @@ -14,12 +14,12 @@ namespace cytnx { template void _trace_2d_gpu(Tensor &out, const Tensor &Tn, const cytnx_uint64 &Ndiag) { - T a = 0; - T *rawdata = Tn.storage().data(); - cytnx_uint64 Ldim = Tn.shape()[1]; - // reduce! - for (cytnx_uint64 i = 0; i < Ndiag; i++) a += rawdata[i * Ldim + i]; - out.storage().at(0) = a; + cytnx::UniTensor I_UT = cytnx::UniTensor::eye(Ndiag, {}, true, Tn.dtype(), Tn.device()); + // similar to _trace_nd_gpu + UniTensor UTn = UniTensor(Tn, false, 2); + I_UT.set_labels({UTn._impl->_labels[0], UTn._impl->_labels[1]}); + + out = Contract(I_UT, UTn).get_block_(); } template @@ -28,149 +28,149 @@ namespace cytnx { const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { - cytnx::UniTensor I_UT = cytnx::UniTensor(zeros(Ndiag, Tn.dtype(), Tn.device()), true, -1); + // currently identical to CPU version + cytnx::UniTensor I_UT = cytnx::UniTensor::eye(Ndiag, {}, true, Tn.dtype(), Tn.device()); - I_UT.set_labels({"0", "1"}); UniTensor UTn = UniTensor(Tn, false, 2); - UTn.set_labels( - vec_cast(vec_range(100, 100 + UTn.labels().size()))); - UTn._impl->_labels[ax1] = "0"; - UTn._impl->_labels[ax2] = "1"; + I_UT.set_labels({UTn._impl->_labels[ax1], UTn._impl->_labels[ax2]}); + out = Contract(I_UT, UTn).get_block_(); } void cuTrace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, Tn, Ndiag); + _trace_2d_gpu(out, Tn, Ndiag); } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, + ax2); } } void cuTrace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, Tn, Ndiag); + _trace_2d_gpu(out, Tn, Ndiag); } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, + ax2); } } void cuTrace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, Tn, Ndiag); + _trace_2d_gpu(out, Tn, Ndiag); } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, Tn, Ndiag); + _trace_2d_gpu(out, Tn, Ndiag); } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, Tn, Ndiag); + _trace_2d_gpu(out, Tn, Ndiag); } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, tn, ndiag); + _trace_2d_gpu(out, tn, ndiag); } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, tn, ndiag); + _trace_2d_gpu(out, tn, ndiag); } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, tn, ndiag); + _trace_2d_gpu(out, tn, ndiag); } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, tn, ndiag); + _trace_2d_gpu(out, tn, ndiag); } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2) { if (is_2d) { - _trace_2d(out, tn, ndiag); + _trace_2d_gpu(out, tn, ndiag); } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); + _trace_nd_gpu(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); } } void cuTrace_internal_b(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem, + const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, diff --git a/src/backend/linalg_internal_gpu/cuTrace_internal.hpp b/src/backend/linalg_internal_gpu/cuTrace_internal.hpp index 300aee615..743dbef87 100644 --- a/src/backend/linalg_internal_gpu/cuTrace_internal.hpp +++ b/src/backend/linalg_internal_gpu/cuTrace_internal.hpp @@ -13,21 +13,21 @@ namespace cytnx { namespace linalg_internal { void cuTrace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2); void cuTrace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, const std::vector &shape, const cytnx_uint64 &ax1, const cytnx_uint64 &ax2); void cuTrace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -35,7 +35,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -43,7 +43,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -51,7 +51,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -59,7 +59,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -67,7 +67,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -75,7 +75,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -83,7 +83,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, @@ -91,7 +91,7 @@ namespace cytnx { const cytnx_uint64 &ax2); void cuTrace_internal_b(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const int &Nomp, + const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, const std::vector &accu, const std::vector &remain_rank_id, diff --git a/src/backend/linalg_internal_interface.cpp b/src/backend/linalg_internal_interface.cpp index a2080c87a..e69472b0e 100644 --- a/src/backend/linalg_internal_interface.cpp +++ b/src/backend/linalg_internal_interface.cpp @@ -988,6 +988,20 @@ namespace cytnx { cuVd_ii[Type.Uint16] = cuVectordot_internal_u16; cuVd_ii[Type.Bool] = cuVectordot_internal_b; + //===================== + cuTrace_ii = vector(N_Type); + + cuTrace_ii[Type.ComplexDouble] = cuTrace_internal_cd; + cuTrace_ii[Type.ComplexFloat] = cuTrace_internal_cf; + cuTrace_ii[Type.Double] = cuTrace_internal_d; + cuTrace_ii[Type.Float] = cuTrace_internal_f; + cuTrace_ii[Type.Uint64] = cuTrace_internal_u64; + cuTrace_ii[Type.Int64] = cuTrace_internal_i64; + cuTrace_ii[Type.Uint32] = cuTrace_internal_u32; + cuTrace_ii[Type.Int32] = cuTrace_internal_i32; + cuTrace_ii[Type.Uint16] = cuTrace_internal_u16; + cuTrace_ii[Type.Int16] = cuTrace_internal_i16; + cuTrace_ii[Type.Bool] = cuTrace_internal_b; //================ cuOuter_ii = vector>(N_Type, vector(N_Type, NULL)); diff --git a/src/backend/linalg_internal_interface.hpp b/src/backend/linalg_internal_interface.hpp index 5835a7ecb..fe9b8bc82 100644 --- a/src/backend/linalg_internal_interface.hpp +++ b/src/backend/linalg_internal_interface.hpp @@ -61,6 +61,7 @@ #include "linalg_internal_gpu/cuOuter_internal.hpp" #include "linalg_internal_gpu/cuPow_internal.hpp" #include "linalg_internal_gpu/cuSum_internal.hpp" + #include "linalg_internal_gpu/cuTrace_internal.hpp" #include "linalg_internal_gpu/cuSvd_internal.hpp" #include "linalg_internal_gpu/cuVectordot_internal.hpp" #include "linalg_internal_gpu/cudaMemcpyTruncation.hpp" @@ -174,8 +175,7 @@ namespace cytnx { const boost::intrusive_ptr &, const cytnx_uint64 &); typedef void (*Tracefunc_oii)(const bool &, Tensor &, const Tensor &, const cytnx_uint64 &, - const int &, const cytnx_uint64 &, - const std::vector &, + const cytnx_uint64 &, const std::vector &, const std::vector &, const std::vector &, const cytnx_uint64 &, const cytnx_uint64 &); @@ -270,6 +270,7 @@ namespace cytnx { std::vector cuDet_ii; std::vector cuMM_ii; std::vector cuSum_ii; + std::vector cuTrace_ii; std::vector cuTensordot_ii; std::vector cudaMemcpyTruncation_ii; diff --git a/src/linalg/Trace.cpp b/src/linalg/Trace.cpp index 4395df73b..e646aa73b 100644 --- a/src/linalg/Trace.cpp +++ b/src/linalg/Trace.cpp @@ -67,11 +67,18 @@ namespace cytnx { if (shape.size() == 0) { // 2d if (Tn.device() == Device.cpu) - linalg_internal::lii.Trace_ii[Tn.dtype()](true, out, Tn, Ndiag, Device.Ncpus, 0, {}, {}, - {}, 0, + linalg_internal::lii.Trace_ii[Tn.dtype()](true, out, Tn, Ndiag, 0, {}, {}, {}, 0, 0); // only the first 4 args will be used. else { - cytnx_error_msg(true, "[ERROR][Trace] GPU is under developing.%s", "\n"); + #ifdef UNI_GPU + checkCudaErrors(cudaSetDevice(Tn.device())); + linalg_internal::lii.cuTrace_ii[Tn.dtype()](true, out, Tn, Ndiag, 0, {}, {}, {}, 0, + 0); // only the first 4 args will be used. + #else + cytnx_error_msg(true, "[Trace] fatal error,%s", + "try to call the gpu section without CUDA support.\n"); + return out; + #endif } } else { // nd @@ -85,10 +92,18 @@ namespace cytnx { } // std::cout << "entry Trace" << std::endl; if (Tn.device() == Device.cpu) - linalg_internal::lii.Trace_ii[Tn.dtype()](false, out, Tn, Ndiag, Nelem, Device.Ncpus, - accu, remain_rank_id, shape, ax1, ax2); + linalg_internal::lii.Trace_ii[Tn.dtype()](false, out, Tn, Ndiag, Nelem, accu, + remain_rank_id, shape, ax1, ax2); else { - cytnx_error_msg(true, "[ERROR][Trace] GPU is under developing.%s", "\n"); + #ifdef UNI_GPU + checkCudaErrors(cudaSetDevice(Tn.device())); + linalg_internal::lii.cuTrace_ii[Tn.dtype()](false, out, Tn, Ndiag, Nelem, accu, + remain_rank_id, shape, ax1, ax2); + #else + cytnx_error_msg(true, "[Trace] fatal error,%s", + "try to call the gpu section without CUDA support.\n"); + return out; + #endif } out.reshape_(shape); } diff --git a/tests/gpu/BlockUniTensor_test.cpp b/tests/gpu/BlockUniTensor_test.cpp index 2d42a592a..8879c9c09 100644 --- a/tests/gpu/BlockUniTensor_test.cpp +++ b/tests/gpu/BlockUniTensor_test.cpp @@ -1,7 +1,6 @@ #include "BlockUniTensor_test.h" TEST_F(BlockUniTensorTest, gpu_Trace) { - GTEST_SKIP() << "Calculating the trace of a matrix is not implemented on the GPU."; // std::cout<