From 1ede487db86c1fa0c32c81636bb514c234905d17 Mon Sep 17 00:00:00 2001 From: Ying-Jer Kao Date: Sat, 4 Apr 2026 00:09:08 +0800 Subject: [PATCH] fix: return discarded block singular values --- src/linalg/Gesvd_truncate.cpp | 31 ++++++++++++--- src/linalg/Svd_truncate.cpp | 31 ++++++++++++--- tests/linalg_test/linalg_test.cpp | 64 +++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 12 deletions(-) diff --git a/src/linalg/Gesvd_truncate.cpp b/src/linalg/Gesvd_truncate.cpp index 4b0c29995..1407831ba 100644 --- a/src/linalg/Gesvd_truncate.cpp +++ b/src/linalg/Gesvd_truncate.cpp @@ -19,6 +19,27 @@ namespace cytnx { namespace linalg { + namespace { + UniTensor BuildBlockDiscardedSingularValues(const Tensor &Sall, const cytnx_uint64 smidx, + const unsigned int return_err) { + Tensor terr({1}, Sall.dtype()); + terr.storage().at(0) = 0; + if (smidx == 0) { + return UniTensor(terr); + } + if (return_err == 1) { + terr.storage().at(0) = Sall.storage()(smidx - 1); + return UniTensor(terr); + } + + terr = Tensor({smidx}, Sall.dtype()); + for (cytnx_uint64 i = 0; i < smidx; i++) { + terr.storage().at(i) = Sall.storage()(smidx - 1 - i); + } + return UniTensor(terr); + } + } // namespace + std::vector Gesvd_truncate(const Tensor &Tin, const cytnx_uint64 &keepdim, const double &err, const bool &is_U, const bool &is_vT, const unsigned int &return_err, const cytnx_uint64 &mindim) { @@ -344,10 +365,9 @@ namespace cytnx { // handle return_err! if (return_err == 1) { - outCyT.push_back(UniTensor(Tensor({1}, Smin.dtype()))); - outCyT.back().get_block_().storage().at(0) = Smin; + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } else if (return_err) { - outCyT.push_back(UniTensor(Sall.get({Accessor::tilend(smidx)}))); + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } } // _gesvd_truncate_Block_UT @@ -638,10 +658,9 @@ namespace cytnx { } // handle return_err! if (return_err == 1) { - outCyT.push_back(UniTensor(Tensor({1}, Smin.dtype()))); - outCyT.back().get_block_().storage().at(0) = Smin; + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } else if (return_err) { - outCyT.push_back(UniTensor(Sall.get({Accessor::tilend(smidx)}))); + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } } else { if (return_err >= 1) { diff --git a/src/linalg/Svd_truncate.cpp b/src/linalg/Svd_truncate.cpp index 19147ae2c..a1b0f9908 100644 --- a/src/linalg/Svd_truncate.cpp +++ b/src/linalg/Svd_truncate.cpp @@ -12,6 +12,27 @@ #include "backend/linalg_internal_interface.hpp" namespace cytnx { namespace linalg { + namespace { + UniTensor BuildBlockDiscardedSingularValues(const Tensor &Sall, const cytnx_uint64 smidx, + const unsigned int return_err) { + Tensor terr({1}, Sall.dtype()); + terr.storage().at(0) = 0; + if (smidx == 0) { + return UniTensor(terr); + } + if (return_err == 1) { + terr.storage().at(0) = Sall.storage()(smidx - 1); + return UniTensor(terr); + } + + terr = Tensor({smidx}, Sall.dtype()); + for (cytnx_uint64 i = 0; i < smidx; i++) { + terr.storage().at(i) = Sall.storage()(smidx - 1 - i); + } + return UniTensor(terr); + } + } // namespace + std::vector Svd_truncate(const Tensor &Tin, const cytnx_uint64 &keepdim, const double &err, const bool &is_UvT, const unsigned int &return_err, const cytnx_uint64 &mindim) { @@ -331,10 +352,9 @@ namespace cytnx { // handle return_err! if (return_err == 1) { - outCyT.push_back(UniTensor(Tensor({1}, Smin.dtype()))); - outCyT.back().get_block_().storage().at(0) = Smin; + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } else if (return_err) { - outCyT.push_back(UniTensor(Sall.get({Accessor::tilend(smidx)}))); + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } } // _svd_truncate_Block_UTs @@ -464,10 +484,9 @@ namespace cytnx { } // handle return_err! if (return_err == 1) { - outCyT.push_back(UniTensor(Tensor({1}, Smin.dtype()))); - outCyT.back().get_block_().storage().at(0) = Smin; + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } else if (return_err) { - outCyT.push_back(UniTensor(Sall.get({Accessor::tilend(smidx)}))); + outCyT.push_back(BuildBlockDiscardedSingularValues(Sall, smidx, return_err)); } } else { if (return_err >= 1) { diff --git a/tests/linalg_test/linalg_test.cpp b/tests/linalg_test/linalg_test.cpp index e4deeaf49..4bf5d6f4d 100644 --- a/tests/linalg_test/linalg_test.cpp +++ b/tests/linalg_test/linalg_test.cpp @@ -1,5 +1,15 @@ #include "linalg_test.h" +namespace { + Tensor SortedBlockSingularValues(const UniTensor &S) { + Tensor all_svals = S.get_block_(0); + for (cytnx_int64 i = 1; i < S.Nblocks(); i++) { + all_svals = algo::Concatenate(all_svals, S.get_block_(i)); + } + return algo::Sort(all_svals); + } +} // namespace + TEST_F(linalg_Test, BkUt_Svd_truncate1) { std::vector res = linalg::Svd_truncate(svd_T, 200, 0, true); std::vector vnm_S; @@ -23,6 +33,60 @@ TEST_F(linalg_Test, BkUt_Svd_truncate2) { auto con_T2 = Contract(Contract(res[1], res[0]), res[2]); } +TEST_F(linalg_Test, BkUt_Svd_truncate_return_err_returns_discarded_values) { + std::vector full = linalg::Svd_truncate(svd_T, 999, 0, true, 0); + std::vector trunc = linalg::Svd_truncate(svd_T, 5, 0, true, 999); + Tensor all_svals = SortedBlockSingularValues(full[0]); + + ASSERT_EQ(full.size(), 3); + ASSERT_EQ(trunc.size(), 4); + ASSERT_EQ(trunc[0].shape()[0], 5); + ASSERT_EQ(all_svals.shape()[0], 400); + ASSERT_EQ(trunc[3].shape()[0], all_svals.shape()[0] - trunc[0].shape()[0]); + + for (cytnx_uint64 i = 0; i < trunc[3].shape()[0]; i++) { + EXPECT_EQ(all_svals.at({trunc[3].shape()[0] - 1 - i}), trunc[3].at({i})); + } +} + +TEST_F(linalg_Test, BkUt_Gesvd_truncate_return_err_returns_discarded_values) { + std::vector full = linalg::Gesvd_truncate(svd_T, 999, 0, true, true, 0); + std::vector trunc = linalg::Gesvd_truncate(svd_T, 5, 0, true, true, 999); + Tensor all_svals = SortedBlockSingularValues(full[0]); + + ASSERT_EQ(full.size(), 3); + ASSERT_EQ(trunc.size(), 4); + ASSERT_EQ(trunc[0].shape()[0], 5); + ASSERT_EQ(all_svals.shape()[0], 400); + ASSERT_EQ(trunc[3].shape()[0], all_svals.shape()[0] - trunc[0].shape()[0]); + + for (cytnx_uint64 i = 0; i < trunc[3].shape()[0]; i++) { + EXPECT_EQ(all_svals.at({trunc[3].shape()[0] - 1 - i}), trunc[3].at({i})); + } +} + +TEST_F(linalg_Test, BkUt_Svd_truncate_return_err_one_returns_first_discarded_value) { + std::vector full = linalg::Svd_truncate(svd_T, 999, 0, true, 0); + std::vector trunc = linalg::Svd_truncate(svd_T, 5, 0, true, 1); + Tensor all_svals = SortedBlockSingularValues(full[0]); + + ASSERT_EQ(full.size(), 3); + ASSERT_EQ(trunc.size(), 4); + ASSERT_EQ(trunc[3].shape()[0], 1); + EXPECT_EQ(all_svals.at({all_svals.shape()[0] - trunc[0].shape()[0] - 1}), trunc[3].at({0})); +} + +TEST_F(linalg_Test, BkUt_Gesvd_truncate_return_err_one_returns_first_discarded_value) { + std::vector full = linalg::Gesvd_truncate(svd_T, 999, 0, true, true, 0); + std::vector trunc = linalg::Gesvd_truncate(svd_T, 5, 0, true, true, 1); + Tensor all_svals = SortedBlockSingularValues(full[0]); + + ASSERT_EQ(full.size(), 3); + ASSERT_EQ(trunc.size(), 4); + ASSERT_EQ(trunc[3].shape()[0], 1); + EXPECT_EQ(all_svals.at({all_svals.shape()[0] - trunc[0].shape()[0] - 1}), trunc[3].at({0})); +} + // TEST_F(linalg_Test, BkUt_Svd_truncate3) { // Bond I = Bond(BD_IN, {Qs(-5), Qs(-3), Qs(-1), Qs(1), Qs(3), Qs(5)}, {1, 4, 10, 9, 5, 1}); // Bond J = Bond(BD_OUT, {Qs(1), Qs(-1)}, {1, 1});