Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions src/backend/linalg_internal_cpu/Trace_internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace cytnx {
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2) {
cytnx::UniTensor I_UT = cytnx::UniTensor(eye(Ndiag, Tn.dtype()), false, -1);
cytnx::UniTensor I_UT = cytnx::UniTensor::eye(Ndiag, {}, true, Tn.dtype(), Tn.device());

UniTensor UTn = UniTensor(Tn, false, 2);
I_UT.set_labels({UTn._impl->_labels[ax1], UTn._impl->_labels[ax2]});
Expand All @@ -50,9 +50,8 @@ namespace cytnx {
// }
}

// TODO: remove Nomp parameter
void Trace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem,
const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -65,7 +64,7 @@ namespace cytnx {
}

void Trace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem,
const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -78,7 +77,7 @@ namespace cytnx {
}

void Trace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem,
const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -91,7 +90,7 @@ namespace cytnx {
}

void Trace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem,
const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -104,7 +103,7 @@ namespace cytnx {
}

void Trace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem,
const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -117,7 +116,7 @@ namespace cytnx {
}

void Trace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &tn,
const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem,
const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -130,7 +129,7 @@ namespace cytnx {
}

void Trace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &tn,
const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem,
const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -143,7 +142,7 @@ namespace cytnx {
}

void Trace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &tn,
const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem,
const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -156,7 +155,7 @@ namespace cytnx {
}

void Trace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &tn,
const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem,
const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -169,7 +168,7 @@ namespace cytnx {
}

void Trace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &tn,
const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem,
const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand All @@ -182,7 +181,7 @@ namespace cytnx {
}

void Trace_internal_b(const bool &is_2d, Tensor &out, const Tensor &tn,
const cytnx_uint64 &ndiag, const int &nomp, const cytnx_uint64 &nelem,
const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
Expand Down
22 changes: 11 additions & 11 deletions src/backend/linalg_internal_cpu/Trace_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,85 +13,85 @@ namespace cytnx {
namespace linalg_internal {

void Trace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem,
const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp, const cytnx_uint64 &Nelem,
const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem,
const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
const std::vector<cytnx_int64> &shape, const cytnx_uint64 &ax1,
const cytnx_uint64 &ax2);

void Trace_internal_b(const bool &is_2d, Tensor &out, const Tensor &Tn,
const cytnx_uint64 &Ndiag, const int &Nomp,
const cytnx_uint64 &Ndiag,

const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
const std::vector<cytnx_uint64> &remain_rank_id,
Expand Down
1 change: 1 addition & 0 deletions src/backend/linalg_internal_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ target_sources_local(cytnx
cuMod_internal.cu
cuDet_internal.cu
cuSum_internal.cu
cuTrace_internal.cu
cuMaxMin_internal.cu
cuTensordot_internal.cu
cuQuantumGeSvd_internal.cu
Expand Down
Loading
Loading