From e9dbc8250b3f297b7ef9faf71687cb06253fc271 Mon Sep 17 00:00:00 2001 From: Ying-Jer Kao Date: Mon, 30 Mar 2026 18:29:09 +0800 Subject: [PATCH 1/2] Fix mixed real/complex block contract dtype promotion (#758) --- src/BlockFermionicUniTensor.cpp | 21 ++++++++++++--------- src/BlockUniTensor.cpp | 21 ++++++++++++--------- tests/BlockUniTensor_test.cpp | 25 +++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/src/BlockFermionicUniTensor.cpp b/src/BlockFermionicUniTensor.cpp index 095c1bea0..209af67d5 100644 --- a/src/BlockFermionicUniTensor.cpp +++ b/src/BlockFermionicUniTensor.cpp @@ -1366,6 +1366,7 @@ namespace cytnx { // output instance; BlockFermionicUniTensor *tmp = new BlockFermionicUniTensor(); BlockFermionicUniTensor *Rtn = (BlockFermionicUniTensor *)rhs.get(); + const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype()); std::vector out_labels; std::vector out_bonds; cytnx_int64 out_rowrank; @@ -1381,7 +1382,7 @@ namespace cytnx { vec_concatenate_(out_labels, this->_labels, rhs->_labels); // cout << out_bonds; - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false); // tmp->_name = this->_name + "+" + rhs->_name; @@ -1551,6 +1552,7 @@ namespace cytnx { } else { BlockFermionicUniTensor *tmp = new BlockFermionicUniTensor(); BlockFermionicUniTensor *Rtn = (BlockFermionicUniTensor *)rhs.get(); + const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype()); std::vector out_labels; std::vector out_bonds; cytnx_int64 out_rowrank; @@ -1575,13 +1577,12 @@ namespace cytnx { (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or this->is_diag() or Rtn->is_diag()) { - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, - false); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); } else { - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, true); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, true); } #else - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, false); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); #endif // now, build the itoi table: @@ -1675,6 +1676,7 @@ namespace cytnx { std::vector betas(Rtn->_blocks.size(), 0.0); BlockFermionicUniTensor *tmp_Rtn = Rtn; + bool tmp_rtn_is_casted = false; // check if all sub-tensor are same dtype and device if (User_debug) { @@ -1709,13 +1711,14 @@ namespace cytnx { } #ifdef UNI_MKL // If the dtype of this and Rtn are different, we need to cast to the common dtype - if (this->dtype() != Rtn->dtype()) { + if (Rtn->dtype() != common_dtype) { BlockFermionicUniTensor *tmpp = Rtn->clone_meta(true, true); tmpp->_blocks.resize(Rtn->_blocks.size()); for (cytnx_int64 blk = 0; blk < Rtn->_blocks.size(); blk++) { - tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(this->dtype()); + tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(common_dtype); } tmp_Rtn = tmpp; + tmp_rtn_is_casted = true; } // First select left block to do gemm for (cytnx_int64 a = 0; a < this->_blocks.size(); a++) { @@ -1793,7 +1796,7 @@ namespace cytnx { linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, (const void **)LMems.data(), (const void **)RMems.data(), betas, (void **)CMems.data(), group_count, group_size, - this->dtype(), tmp->device()); + common_dtype, tmp->device()); } // restore the shape&permutation of this->_blocks[a] for (cytnx_uint64 binx = 0; binx < itoiR_idx.size(); binx++) { @@ -1816,7 +1819,7 @@ namespace cytnx { } // if Rtn dtype is casted, delete the tmp_Rtn - if (this->dtype() != Rtn->dtype()) { + if (tmp_rtn_is_casted) { delete tmp_Rtn; } } diff --git a/src/BlockUniTensor.cpp b/src/BlockUniTensor.cpp index 9c1bdf75e..36d95465b 100644 --- a/src/BlockUniTensor.cpp +++ b/src/BlockUniTensor.cpp @@ -745,6 +745,7 @@ namespace cytnx { // output instance; BlockUniTensor *tmp = new BlockUniTensor(); BlockUniTensor *Rtn = (BlockUniTensor *)rhs.get(); + const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype()); std::vector out_labels; std::vector out_bonds; cytnx_int64 out_rowrank; @@ -760,7 +761,7 @@ namespace cytnx { vec_concatenate_(out_labels, this->_labels, rhs->_labels); // cout << out_bonds; - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false); // tmp->_name = this->_name + "+" + rhs->_name; @@ -906,6 +907,7 @@ namespace cytnx { } else { BlockUniTensor *tmp = new BlockUniTensor(); BlockUniTensor *Rtn = (BlockUniTensor *)rhs.get(); + const unsigned int common_dtype = Type.type_promote(this->dtype(), rhs->dtype()); std::vector out_labels; std::vector out_bonds; cytnx_int64 out_rowrank; @@ -930,13 +932,12 @@ namespace cytnx { (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or this->is_diag() or Rtn->is_diag()) { - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, - false); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); } else { - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, true); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, true); } #else - tmp->Init(out_bonds, out_labels, out_rowrank, this->dtype(), this->device(), false, false); + tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); #endif // now, build the itoi table: @@ -1012,6 +1013,7 @@ namespace cytnx { std::vector betas(Rtn->_blocks.size(), 0.0); BlockUniTensor *tmp_Rtn = Rtn; + bool tmp_rtn_is_casted = false; // check if all sub-tensor are same dtype and device if (User_debug) { @@ -1046,13 +1048,14 @@ namespace cytnx { } #ifdef UNI_MKL // If the dtype of this and Rtn are different, we need to cast to the common dtype - if (this->dtype() != Rtn->dtype()) { + if (Rtn->dtype() != common_dtype) { BlockUniTensor *tmpp = Rtn->clone_meta(true, true); tmpp->_blocks.resize(Rtn->_blocks.size()); for (cytnx_int64 blk = 0; blk < Rtn->_blocks.size(); blk++) { - tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(this->dtype()); + tmpp->_blocks[blk] = Rtn->_blocks[blk].astype(common_dtype); } tmp_Rtn = tmpp; + tmp_rtn_is_casted = true; } // First select left block to do gemm for (cytnx_int64 a = 0; a < this->_blocks.size(); a++) { @@ -1118,7 +1121,7 @@ namespace cytnx { group_size.resize(group_count, 1); linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, (const void **)LMems.data(), (const void **)RMems.data(), betas, (void **)CMems.data(), - group_count, group_size, this->dtype(), tmp->device()); + group_count, group_size, common_dtype, tmp->device()); } // restore the shape&permutation of this->_blocks[a] for (cytnx_uint64 binx = 0; binx < itoiR_idx.size(); binx++) { @@ -1141,7 +1144,7 @@ namespace cytnx { } // if Rtn dtype is casted, delete the tmp_Rtn - if (this->dtype() != Rtn->dtype()) { + if (tmp_rtn_is_casted) { delete tmp_Rtn; } } diff --git a/tests/BlockUniTensor_test.cpp b/tests/BlockUniTensor_test.cpp index 4e406f92d..258ef309a 100644 --- a/tests/BlockUniTensor_test.cpp +++ b/tests/BlockUniTensor_test.cpp @@ -568,6 +568,31 @@ TEST_F(BlockUniTensorTest, contract3) { } } +TEST_F(BlockUniTensorTest, contract_mixed_dtype_order_independent) { + // Reproduce issue #758: real(QN) x complex(QN) should not depend on argument order. + UniTensor left_real = UT_contract_L2.astype(Type.Double); + UniTensor right_complex = UT_contract_R2.astype(Type.ComplexDouble); + + left_real.set_labels({"a", "b"}); + right_complex.set_labels({"b", "c"}); + + UniTensor out_real_complex; + UniTensor out_complex_real; + EXPECT_NO_THROW(out_real_complex = left_real.contract(right_complex)); + EXPECT_NO_THROW(out_complex_real = right_complex.contract(left_real)); + + EXPECT_EQ(out_real_complex.dtype(), Type.ComplexDouble); + EXPECT_EQ(out_complex_real.dtype(), Type.ComplexDouble); + + // Cross-check against all-complex references for each contraction ordering. + UniTensor left_complex = left_real.astype(Type.ComplexDouble); + UniTensor right_complex_ref = right_complex.astype(Type.ComplexDouble); + UniTensor ref_real_complex = left_complex.contract(right_complex_ref); + UniTensor ref_complex_real = right_complex_ref.contract(left_complex); + EXPECT_TRUE(AreNearlyEqUniTensor(out_real_complex, ref_real_complex, 1e-10)); + EXPECT_TRUE(AreNearlyEqUniTensor(out_complex_real, ref_complex_real, 1e-10)); +} + TEST_F(BlockUniTensorTest, same_data) { UniTensor B = UT_pB_ans.permute({1, 0, 2}); UniTensor C = B.contiguous(); From 8a92a14df9a133aea4ab5634bccd376302b07259 Mon Sep 17 00:00:00 2001 From: Manuel Schneider Date: Wed, 1 Apr 2026 15:56:27 +0800 Subject: [PATCH 2/2] cleanup of currently unused code branches --- src/BlockFermionicUniTensor.cpp | 110 +++++++++++++++----------------- src/BlockUniTensor.cpp | 91 +++++++++++++------------- 2 files changed, 97 insertions(+), 104 deletions(-) diff --git a/src/BlockFermionicUniTensor.cpp b/src/BlockFermionicUniTensor.cpp index 209af67d5..af526f2f6 100644 --- a/src/BlockFermionicUniTensor.cpp +++ b/src/BlockFermionicUniTensor.cpp @@ -164,21 +164,13 @@ namespace cytnx { // if exists: if (std::all_of(tot_qns.begin(), tot_qns.end(), [](const int &i) { return i == 0; })) { - // get size & init block! + // init block! + for (cytnx_int32 i = 0; i < Loc.size(); i++) { + size[i] = this->_bonds[i]._impl->_degs[Loc[i]]; + } if (!no_alloc) { - // cytnx_uint64 blockNelem = 1; - for (cytnx_int32 i = 0; i < Loc.size(); i++) { - size[i] = this->_bonds[i]._impl->_degs[Loc[i]]; - // blockNelem *= size[i]; - } this->_blocks.push_back(zeros(size, dtype, device)); - // blocklens.push_back(blockNelem); - // blocksizes.push_back(size); - // totblocksize += blockNelem; } else { - for (cytnx_int32 i = 0; i < Loc.size(); i++) { - size[i] = this->_bonds[i]._impl->_degs[Loc[i]]; - } this->_blocks.push_back(Tensor(size, dtype, device, false)); } // push its loc @@ -1571,19 +1563,21 @@ namespace cytnx { for (cytnx_uint64 i = 0; i < comm_idx2.size(); i++) if (comm_idx2[i] < rhs->_rowrank) out_rowrank--; - #ifdef UNI_MKL - // Initialize!! - if (true or - (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and - (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or - this->is_diag() or Rtn->is_diag()) { - tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); - } else { - tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, true); - } - #else + // #ifdef UNI_MKL + // // Initialize!! + // if (true or + // (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and + // (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or + // this->is_diag() or Rtn->is_diag()) { + // tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), + // false, false); + // } else { + // tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), + // false, true); + // } + // #else tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); - #endif + // #endif // now, build the itoi table: std::vector> itoiL_common(this->_blocks.size()), @@ -1760,44 +1754,46 @@ namespace cytnx { reshaped[targ_b] = true; betas[binx] = 0.0; } - // prepare to call gemm_batch - if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and - (tmp->dtype() != Type.Void and this->dtype() != Type.Void and - tmp_Rtn->dtype() != Type.Void)) { - ms[binx] = this->_blocks[a].shape()[0]; - ns[binx] = tmp_Rtn->_blocks[b].shape()[1]; - ks[binx] = comm_dim; - LMems[binx] = this->_blocks[a].storage()._impl->data(); - RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data(); - CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data(); + // // prepare to call gemm_batch + // if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) + // and + // (tmp->dtype() != Type.Void and this->dtype() != Type.Void and + // tmp_Rtn->dtype() != Type.Void)) { + // ms[binx] = this->_blocks[a].shape()[0]; + // ns[binx] = tmp_Rtn->_blocks[b].shape()[1]; + // ks[binx] = comm_dim; + // LMems[binx] = this->_blocks[a].storage()._impl->data(); + // RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data(); + // CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data(); + // } else { + // fermionic signs included here + if (signfliplhs[a] != signfliprhs[b]) { + tmp->_blocks[targ_b] -= linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b]) + .reshape(tmp->_blocks[targ_b].shape()); } else { - // fermionic signs included here - if (signfliplhs[a] != signfliprhs[b]) { - tmp->_blocks[targ_b] -= linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b]) - .reshape(tmp->_blocks[targ_b].shape()); - } else { - tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b]) - .reshape(tmp->_blocks[targ_b].shape()); - } + tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b]) + .reshape(tmp->_blocks[targ_b].shape()); } + // } } // mkl_set_interface_layer(MKL_INTERFACE_ILP64); - blas_int group_count = itoiR_idx.size(); - if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and - (tmp->dtype() != Type.Void and this->dtype() != Type.Void and - tmp_Rtn->dtype() != Type.Void)) { - group_size.resize(group_count, 1); - // TODOfermions: alphas need to include sign factors! - cytnx_error_msg(true, - "[ERROR] Fermionic sign flips not implemented yet in Gemm_Batch " - "contracition. One needs to change the signs of the alphas.%s", - "\n") - linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, - (const void **)LMems.data(), (const void **)RMems.data(), - betas, (void **)CMems.data(), group_count, group_size, - common_dtype, tmp->device()); - } + // if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) + // and + // (tmp->dtype() != Type.Void and this->dtype() != Type.Void and + // tmp_Rtn->dtype() != Type.Void)) { + // blas_int group_count = itoiR_idx.size(); + // group_size.resize(group_count, 1); + // // TODOfermions: alphas need to include sign factors! + // cytnx_error_msg(true, + // "[ERROR] Fermionic sign flips not implemented yet in Gemm_Batch " + // "contracition. One needs to change the signs of the alphas.%s", + // "\n") + // linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, + // (const void **)LMems.data(), (const void **)RMems.data(), + // betas, (void **)CMems.data(), group_count, group_size, + // common_dtype, tmp->device()); + // } // restore the shape&permutation of this->_blocks[a] for (cytnx_uint64 binx = 0; binx < itoiR_idx.size(); binx++) { cytnx_uint64 b = itoiR_idx[binx]; diff --git a/src/BlockUniTensor.cpp b/src/BlockUniTensor.cpp index 36d95465b..26e27f617 100644 --- a/src/BlockUniTensor.cpp +++ b/src/BlockUniTensor.cpp @@ -153,21 +153,13 @@ namespace cytnx { // if exists: if (std::all_of(tot_qns.begin(), tot_qns.end(), [](const int &i) { return i == 0; })) { - // get size & init block! + // init block! + for (cytnx_int32 i = 0; i < Loc.size(); i++) { + size[i] = this->_bonds[i]._impl->_degs[Loc[i]]; + } if (!no_alloc) { - // cytnx_uint64 blockNelem = 1; - for (cytnx_int32 i = 0; i < Loc.size(); i++) { - size[i] = this->_bonds[i]._impl->_degs[Loc[i]]; - // blockNelem *= size[i]; - } this->_blocks.push_back(zeros(size, dtype, device)); - // blocklens.push_back(blockNelem); - // blocksizes.push_back(size); - // totblocksize += blockNelem; } else { - for (cytnx_int32 i = 0; i < Loc.size(); i++) { - size[i] = this->_bonds[i]._impl->_degs[Loc[i]]; - } this->_blocks.push_back(Tensor(size, dtype, device, false)); } // push its loc @@ -926,19 +918,21 @@ namespace cytnx { for (cytnx_uint64 i = 0; i < comm_idx2.size(); i++) if (comm_idx2[i] < rhs->_rowrank) out_rowrank--; - #ifdef UNI_MKL - // Initialize!! - if (true or - (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and - (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or - this->is_diag() or Rtn->is_diag()) { - tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); - } else { - tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, true); - } - #else + // #ifdef UNI_MKL + // // Initialize!! + // if (true or + // (this->dtype() != Type.Double and this->dtype() != Type.ComplexDouble) and + // (this->dtype() != Type.Float and this->dtype() != Type.ComplexFloat) or + // this->is_diag() or Rtn->is_diag()) { + // tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), + // false, false); + // } else { + // tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), + // false, true); + // } + // #else tmp->Init(out_bonds, out_labels, out_rowrank, common_dtype, this->device(), false, false); - #endif + // #endif // now, build the itoi table: std::vector> itoiL_common(this->_blocks.size()), @@ -1097,32 +1091,35 @@ namespace cytnx { reshaped[targ_b] = true; betas[binx] = 0.0; } - // prepare to call gemm_batch - if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and - (tmp->dtype() != Type.Void and this->dtype() != Type.Void and - tmp_Rtn->dtype() != Type.Void)) { - ms[binx] = this->_blocks[a].shape()[0]; - ns[binx] = tmp_Rtn->_blocks[b].shape()[1]; - ks[binx] = comm_dim; - LMems[binx] = this->_blocks[a].storage()._impl->data(); - RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data(); - CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data(); - } else { - tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b]) - .reshape(tmp->_blocks[targ_b].shape()); - } + // // prepare to call gemm_batch + // if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) + // and + // (tmp->dtype() != Type.Void and this->dtype() != Type.Void and + // tmp_Rtn->dtype() != Type.Void)) { + // ms[binx] = this->_blocks[a].shape()[0]; + // ns[binx] = tmp_Rtn->_blocks[b].shape()[1]; + // ks[binx] = comm_dim; + // LMems[binx] = this->_blocks[a].storage()._impl->data(); + // RMems[binx] = tmp_Rtn->_blocks[b].storage()._impl->data(); + // CMems[binx] = tmp->_blocks[targ_b].storage()._impl->data(); + // } else { + tmp->_blocks[targ_b] += linalg::Matmul(this->_blocks[a], tmp_Rtn->_blocks[b]) + .reshape(tmp->_blocks[targ_b].shape()); + // } } // mkl_set_interface_layer(MKL_INTERFACE_ILP64); - blas_int group_count = itoiR_idx.size(); - if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) and - (tmp->dtype() != Type.Void and this->dtype() != Type.Void and - tmp_Rtn->dtype() != Type.Void)) { - group_size.resize(group_count, 1); - linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, (const void **)LMems.data(), - (const void **)RMems.data(), betas, (void **)CMems.data(), - group_count, group_size, common_dtype, tmp->device()); - } + // if (false and (tmp->dtype() <= 4 and this->dtype() <= 4 and tmp_Rtn->dtype() <= 4) + // and + // (tmp->dtype() != Type.Void and this->dtype() != Type.Void and + // tmp_Rtn->dtype() != Type.Void)) { + // blas_int group_count = itoiR_idx.size(); + // group_size.resize(group_count, 1); + // linalg::__Gemm_Batch(transs, transs, ms, ns, ks, alphas, (const void + // **)LMems.data(), + // (const void **)RMems.data(), betas, (void **)CMems.data(), + // group_count, group_size, common_dtype, tmp->device()); + // } // restore the shape&permutation of this->_blocks[a] for (cytnx_uint64 binx = 0; binx < itoiR_idx.size(); binx++) { cytnx_uint64 b = itoiR_idx[binx];