Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 62 additions & 63 deletions src/BlockFermionicUniTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,13 @@ namespace cytnx {

// if exists:
if (std::all_of(tot_qns.begin(), tot_qns.end(), [](const int &i) { return i == 0; })) {
// get size & init block!
// init block!
for (cytnx_int32 i = 0; i < Loc.size(); i++) {
size[i] = this->_bonds[i]._impl->_degs[Loc[i]];
}
if (!no_alloc) {
// cytnx_uint64 blockNelem = 1;
for (cytnx_int32 i = 0; i < Loc.size(); i++) {
size[i] = this->_bonds[i]._impl->_degs[Loc[i]];
// blockNelem *= size[i];
}
this->_blocks.push_back(zeros(size, dtype, device));
// blocklens.push_back(blockNelem);
// blocksizes.push_back(size);
// totblocksize += blockNelem;
} else {
for (cytnx_int32 i = 0; i < Loc.size(); i++) {
size[i] = this->_bonds[i]._impl->_degs[Loc[i]];
}
this->_blocks.push_back(Tensor(size, dtype, device, false));
}
// push its loc
Expand Down Expand Up @@ -1366,6 +1358,7 @@ namespace cytnx {
// output instance;
BlockFermionicUniTensor *tmp = new BlockFermionicUniTensor();
BlockFermionicUniTensor *Rtn = (BlockFermionicUniTensor *)rhs.get();
const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype());
std::vector<string> out_labels;
std::vector<Bond> out_bonds;
cytnx_int64 out_rowrank;
Expand All @@ -1381,7 +1374,7 @@ namespace cytnx {
vec_concatenate_(out_labels, this->_labels, rhs->_labels);

// cout << out_bonds;
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false);
tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false);

// tmp->_name = this->_name + "+" + rhs->_name;

Expand Down Expand Up @@ -1551,6 +1544,7 @@ namespace cytnx {
} else {
BlockFermionicUniTensor *tmp = new BlockFermionicUniTensor();
BlockFermionicUniTensor *Rtn = (BlockFermionicUniTensor *)rhs.get();
const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype());
std::vector<string> out_labels;
std::vector<Bond> out_bonds;
cytnx_int64 out_rowrank;
Expand All @@ -1569,20 +1563,21 @@ namespace cytnx {
for (cytnx_uint64 i = 0; i < comm_idx2.size(); i++)
if (comm_idx2[i] < rhs->_rowrank) out_rowrank--;

#ifdef UNI_MKL
// Initialize!!
if (true or
(this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and
(this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or
this->is_diag() or Rtn->is_diag()) {
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false,
false);
} else {
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, true);
}
#else
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, false);
#endif
// #ifdef UNI_MKL
// // Initialize!!
// if (true or
// (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and
// (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or
// this->is_diag() or Rtn->is_diag()) {
// tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(),
// false, false);
// } else {
// tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(),
// false, true);
// }
// #else
Comment on lines +1566 to +1578
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason to leave the commented code? We can revert it from git history if we need it again in the future. If there is a reason to keep it, I suggest leaving a comment explaining the reason.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not delete these parts because I did not write them, and I am not sure what is happening here. If no one needs this currently, we can certainly delete it.

tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false);
// #endif

// now, build the itoi table:
std::vector<std::vector<cytnx_uint64>> itoiL_common(this->_blocks.size()),
Expand Down Expand Up @@ -1675,6 +1670,7 @@ namespace cytnx {
std::vector<Scalar> betas(Rtn->_blocks.size(), 0.0);

BlockFermionicUniTensor *tmp_Rtn = Rtn;
bool tmp_rtn_is_casted = false;

// check if all sub-tensor are same dtype and device
if (User_debug) {
Expand Down Expand Up @@ -1709,13 +1705,14 @@ namespace cytnx {
}
#ifdef UNI_MKL
// If the dtype of this and Rtn are different, we need to cast to the common dtype
if (this->dtype() != Rtn->dtype()) {
if (Rtn->dtype() != common_dtype) {
BlockFermionicUniTensor *tmpp = Rtn->clone_meta(true, true);
tmpp->_blocks.resize(Rtn->_blocks.size());
for (cytnx_int64 blk = 0; blk < Rtn->_blocks.size(); blk++) {
tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(this->dtype());
tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(common_dtype);
}
tmp_Rtn = tmpp;
tmp_rtn_is_casted = true;
}
// First select left block to do gemm
for (cytnx_int64 a = 0; a < this->_blocks.size(); a++) {
Expand Down Expand Up @@ -1757,44 +1754,46 @@ namespace cytnx {
reshaped[targ_b] = true;
betas[binx] = 0.0;
}
// prepare to call gemm_batch
if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and
(tmp->dtype() != Type.Void and this->dtype() != Type.Void and
tmp_Rtn->dtype() != Type.Void)) {
ms[binx] = this->_blocks[a].shape()[0];
ns[binx] = tmp_Rtn->_blocks[b].shape()[1];
ks[binx] = comm_dim;
LMems[binx] = this->_blocks[a].storage()._impl->data();
RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data();
CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data();
// // prepare to call gemm_batch
// if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4)
// and
// (tmp->dtype() != Type.Void and this->dtype() != Type.Void and
// tmp_Rtn->dtype() != Type.Void)) {
// ms[binx] = this->_blocks[a].shape()[0];
// ns[binx] = tmp_Rtn->_blocks[b].shape()[1];
// ks[binx] = comm_dim;
// LMems[binx] = this->_blocks[a].storage()._impl->data();
// RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data();
// CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data();
// } else {
// fermionic signs included here
if (signfliplhs[a] != signfliprhs[b]) {
tmp->_blocks[targ_b] -= linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b])
.reshape(tmp->_blocks[targ_b].shape());
} else {
// fermionic signs included here
if (signfliplhs[a] != signfliprhs[b]) {
tmp->_blocks[targ_b] -= linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b])
.reshape(tmp->_blocks[targ_b].shape());
} else {
tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b])
.reshape(tmp->_blocks[targ_b].shape());
}
tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b])
.reshape(tmp->_blocks[targ_b].shape());
}
// }
}
// mkl_set_interface_layer(MKL_INTERFACE_ILP64);

blas_int group_count = itoiR_idx.size();
if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and
(tmp->dtype() != Type.Void and this->dtype() != Type.Void and
tmp_Rtn->dtype() != Type.Void)) {
group_size.resize(group_count, 1);
// TODOfermions: alphas need to include sign factors!
cytnx_error_msg(true,
"[ERROR] Fermionic sign flips not implemented yet in Gemm_Batch "
"contracition. One needs to change the signs of the alphas.%s",
"\n")
linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas,
(const void **)LMems.data(), (const void **)RMems.data(),
betas, (void **)CMems.data(), group_count, group_size,
this->dtype(), tmp->device());
}
// if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4)
// and
// (tmp->dtype() != Type.Void and this->dtype() != Type.Void and
// tmp_Rtn->dtype() != Type.Void)) {
// blas_int group_count = itoiR_idx.size();
// group_size.resize(group_count, 1);
// // TODOfermions: alphas need to include sign factors!
// cytnx_error_msg(true,
// "[ERROR] Fermionic sign flips not implemented yet in Gemm_Batch "
// "contracition. One needs to change the signs of the alphas.%s",
// "\n")
// linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas,
// (const void **)LMems.data(), (const void **)RMems.data(),
// betas, (void **)CMems.data(), group_count, group_size,
// common_dtype, tmp->device());
// }
// restore the shape&permutation of this->_blocks[a]
for (cytnx_uint64 binx = 0; binx < itoiR_idx.size(); binx++) {
cytnx_uint64 b = itoiR_idx[binx];
Expand All @@ -1816,7 +1815,7 @@ namespace cytnx {
}

// if Rtn dtype is casted, delete the tmp_Rtn
if (this->dtype() != Rtn->dtype()) {
if (tmp_rtn_is_casted) {
delete tmp_Rtn;
}
}
Expand Down
106 changes: 53 additions & 53 deletions src/BlockUniTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,13 @@ namespace cytnx {

// if exists:
if (std::all_of(tot_qns.begin(), tot_qns.end(), [](const int &i) { return i == 0; })) {
// get size & init block!
// init block!
for (cytnx_int32 i = 0; i < Loc.size(); i++) {
size[i] = this->_bonds[i]._impl->_degs[Loc[i]];
}
if (!no_alloc) {
// cytnx_uint64 blockNelem = 1;
for (cytnx_int32 i = 0; i < Loc.size(); i++) {
size[i] = this->_bonds[i]._impl->_degs[Loc[i]];
// blockNelem *= size[i];
}
this->_blocks.push_back(zeros(size, dtype, device));
// blocklens.push_back(blockNelem);
// blocksizes.push_back(size);
// totblocksize += blockNelem;
} else {
for (cytnx_int32 i = 0; i < Loc.size(); i++) {
size[i] = this->_bonds[i]._impl->_degs[Loc[i]];
}
this->_blocks.push_back(Tensor(size, dtype, device, false));
}
// push its loc
Expand Down Expand Up @@ -745,6 +737,7 @@ namespace cytnx {
// output instance;
BlockUniTensor *tmp = new BlockUniTensor();
BlockUniTensor *Rtn = (BlockUniTensor *)rhs.get();
const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype());
std::vector<string> out_labels;
std::vector<Bond> out_bonds;
cytnx_int64 out_rowrank;
Expand All @@ -760,7 +753,7 @@ namespace cytnx {
vec_concatenate_(out_labels, this->_labels, rhs->_labels);

// cout << out_bonds;
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false);
tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false);

// tmp->_name = this->_name + "+" + rhs->_name;

Expand Down Expand Up @@ -906,6 +899,7 @@ namespace cytnx {
} else {
BlockUniTensor *tmp = new BlockUniTensor();
BlockUniTensor *Rtn = (BlockUniTensor *)rhs.get();
const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype());
std::vector<string> out_labels;
std::vector<Bond> out_bonds;
cytnx_int64 out_rowrank;
Expand All @@ -924,20 +918,21 @@ namespace cytnx {
for (cytnx_uint64 i = 0; i < comm_idx2.size(); i++)
if (comm_idx2[i] < rhs->_rowrank) out_rowrank--;

#ifdef UNI_MKL
// Initialize!!
if (true or
(this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and
(this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or
this->is_diag() or Rtn->is_diag()) {
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false,
false);
} else {
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, true);
}
#else
tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, false);
#endif
// #ifdef UNI_MKL
// // Initialize!!
// if (true or
// (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and
// (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or
// this->is_diag() or Rtn->is_diag()) {
// tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(),
// false, false);
// } else {
// tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(),
// false, true);
// }
// #else
tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false);
// #endif

// now, build the itoi table:
std::vector<std::vector<cytnx_uint64>> itoiL_common(this->_blocks.size()),
Expand Down Expand Up @@ -1012,6 +1007,7 @@ namespace cytnx {
std::vector<Scalar> betas(Rtn->_blocks.size(), 0.0);

BlockUniTensor *tmp_Rtn = Rtn;
bool tmp_rtn_is_casted = false;

// check if all sub-tensor are same dtype and device
if (User_debug) {
Expand Down Expand Up @@ -1046,13 +1042,14 @@ namespace cytnx {
}
#ifdef UNI_MKL
// If the dtype of this and Rtn are different, we need to cast to the common dtype
if (this->dtype() != Rtn->dtype()) {
if (Rtn->dtype() != common_dtype) {
BlockUniTensor *tmpp = Rtn->clone_meta(true, true);
tmpp->_blocks.resize(Rtn->_blocks.size());
for (cytnx_int64 blk = 0; blk < Rtn->_blocks.size(); blk++) {
tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(this->dtype());
tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(common_dtype);
}
tmp_Rtn = tmpp;
tmp_rtn_is_casted = true;
}
// First select left block to do gemm
for (cytnx_int64 a = 0; a < this->_blocks.size(); a++) {
Expand Down Expand Up @@ -1094,32 +1091,35 @@ namespace cytnx {
reshaped[targ_b] = true;
betas[binx] = 0.0;
}
// prepare to call gemm_batch
if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and
(tmp->dtype() != Type.Void and this->dtype() != Type.Void and
tmp_Rtn->dtype() != Type.Void)) {
ms[binx] = this->_blocks[a].shape()[0];
ns[binx] = tmp_Rtn->_blocks[b].shape()[1];
ks[binx] = comm_dim;
LMems[binx] = this->_blocks[a].storage()._impl->data();
RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data();
CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data();
} else {
tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b])
.reshape(tmp->_blocks[targ_b].shape());
}
// // prepare to call gemm_batch
// if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4)
// and
// (tmp->dtype() != Type.Void and this->dtype() != Type.Void and
// tmp_Rtn->dtype() != Type.Void)) {
// ms[binx] = this->_blocks[a].shape()[0];
// ns[binx] = tmp_Rtn->_blocks[b].shape()[1];
// ks[binx] = comm_dim;
// LMems[binx] = this->_blocks[a].storage()._impl->data();
// RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data();
// CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data();
// } else {
tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b])
.reshape(tmp->_blocks[targ_b].shape());
// }
}
// mkl_set_interface_layer(MKL_INTERFACE_ILP64);

blas_int group_count = itoiR_idx.size();
if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and
(tmp->dtype() != Type.Void and this->dtype() != Type.Void and
tmp_Rtn->dtype() != Type.Void)) {
group_size.resize(group_count, 1);
linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, (const void **)LMems.data(),
(const void **)RMems.data(), betas, (void **)CMems.data(),
group_count, group_size, this->dtype(), tmp->device());
}
// if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4)
// and
// (tmp->dtype() != Type.Void and this->dtype() != Type.Void and
// tmp_Rtn->dtype() != Type.Void)) {
// blas_int group_count = itoiR_idx.size();
// group_size.resize(group_count, 1);
// linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, (const void
// **)LMems.data(),
// (const void **)RMems.data(), betas, (void **)CMems.data(),
// group_count, group_size, common_dtype, tmp->device());
// }
// restore the shape&permutation of this->_blocks[a]
for (cytnx_uint64 binx = 0; binx < itoiR_idx.size(); binx++) {
cytnx_uint64 b = itoiR_idx[binx];
Expand All @@ -1141,7 +1141,7 @@ namespace cytnx {
}

// if Rtn dtype is casted, delete the tmp_Rtn
if (this->dtype() != Rtn->dtype()) {
if (tmp_rtn_is_casted) {
delete tmp_Rtn;
}
}
Expand Down
Loading
Loading