diff --git a/include/Accessor.hpp b/include/Accessor.hpp index f9b4d8ed3..52b1dd47a 100644 --- a/include/Accessor.hpp +++ b/include/Accessor.hpp @@ -84,6 +84,9 @@ namespace cytnx { Accessor(const Accessor &rhs); // copy assignment: Accessor &operator=(const Accessor &rhs); + + // check equality + bool operator==(const Accessor &rhs) const; ///@endcond int type() const { return this->_type; } diff --git a/include/Tensor.hpp b/include/Tensor.hpp index 5a4abf5f4..e7645c419 100644 --- a/include/Tensor.hpp +++ b/include/Tensor.hpp @@ -1008,10 +1008,13 @@ namespace cytnx { /** @brief get elements using Accessor (C++ API) / slices (python API) @param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements. + @param[out] removed the indices that were removed from the original shape of the Tensor are + pushed to the end of this vector. Usually, an empty vector should be passed. @return [Tensor] @see \link cytnx::Accessor Accessor\endlink for cordinate with Accessor in C++ API. @note - 1. the return will be a new Tensor instance, which not share memory with the current Tensor. + The return will be a new Tensor instance, which does not share memory with the current + Tensor. ## Equivalently: One can also using more intruisive way to get the slice using [] operator. 
@@ -1026,9 +1029,16 @@ namespace cytnx { #### output> \verbinclude example/Tensor/get.py.out */ + Tensor get(const std::vector &accessors, + std::vector &removed) const { + Tensor out; + out._impl = this->_impl->get(accessors, removed); + return out; + } Tensor get(const std::vector &accessors) const { Tensor out; - out._impl = this->_impl->get(accessors); + std::vector removed; + out._impl = this->_impl->get(accessors, removed); return out; } diff --git a/include/UniTensor.hpp b/include/UniTensor.hpp index d241fdb51..b0910bec7 100644 --- a/include/UniTensor.hpp +++ b/include/UniTensor.hpp @@ -684,12 +684,12 @@ namespace cytnx { if (this->is_diag()) { cytnx_error_msg( in.shape() != this->_block.shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in.clone(); } else { cytnx_error_msg( in.shape() != this->shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in.clone(); } } @@ -711,12 +711,12 @@ namespace cytnx { if (this->is_diag()) { cytnx_error_msg( in.shape() != this->_block.shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in; } else { cytnx_error_msg( in.shape() != this->shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in; } } @@ -731,16 +731,9 @@ namespace cytnx { true, "[ERROR][DenseUniTensor] try to put_block using qnum on a non-symmetry UniTensor%s", "\n"); } - // this will only work on non-symm tensor (DenseUniTensor) - boost::intrusive_ptr get(const std::vector &accessors) 
{ - boost::intrusive_ptr out(new DenseUniTensor()); - out->Init_by_Tensor(this->_block.get(accessors), false, 0); // wrapping around. - return out; - } - // this will only work on non-symm tensor (DenseUniTensor) - void set(const std::vector &accessors, const Tensor &rhs) { - this->_block.set(accessors, rhs); - } + // these two methods only work on non-symm tensor (DenseUniTensor) + boost::intrusive_ptr get(const std::vector &accessors); + void set(const std::vector &accessors, const Tensor &rhs); void reshape_(const std::vector &new_shape, const cytnx_uint64 &rowrank = 0); boost::intrusive_ptr reshape(const std::vector &new_shape, @@ -1700,7 +1693,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '+' between Scalar and BlockUniTensor.\n %s " "\n", - "This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) " + "This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) " "to do operation on blocks."); } @@ -1713,7 +1706,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '-' between Scalar and BlockUniTensor.\n %s " "\n", - "This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) " + "This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) " "to do operation on blocks."); } void lSub_(const Scalar &lhs) { @@ -1721,7 +1714,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '-' between Scalar and BlockUniTensor.\n %s " "\n", - "This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) " + "This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) " "to do operation on blocks."); } @@ -1733,7 +1726,7 @@ namespace cytnx { "[ERROR] cannot perform elementwise arithmetic '/' between Scalar and BlockUniTensor.\n %s " "\n", "This operation would cause division by zero on non-block elements. 
[Suggest] Avoid or use " - "get/set_block(s) to do operation on blocks."); + "get/put_block(s) to do operation on blocks."); } void from_(const boost::intrusive_ptr &rhs, const bool &force, const cytnx_double &tol); @@ -2490,7 +2483,7 @@ namespace cytnx { "BlockFermionicUniTensor.\n %s " "\n", "This operation would destroy the block structure. [Suggest] Avoid or use " - "get/set_block(s) to do operation on blocks."); + "get/put_block(s) to do operation on blocks."); } void Mul_(const boost::intrusive_ptr &rhs); @@ -2503,7 +2496,7 @@ namespace cytnx { "BlockFermionicUniTensor.\n %s " "\n", "This operation would destroy the block structure. [Suggest] Avoid or use " - "get/set_block(s) to do operation on blocks."); + "get/put_block(s) to do operation on blocks."); } void lSub_(const Scalar &lhs) { cytnx_error_msg(true, @@ -2511,7 +2504,7 @@ namespace cytnx { "BlockFermionicUniTensor.\n %s " "\n", "This operation would destroy the block structure. [Suggest] Avoid or use " - "get/set_block(s) to do operation on blocks."); + "get/put_block(s) to do operation on blocks."); } void Div_(const boost::intrusive_ptr &rhs); @@ -2522,7 +2515,7 @@ namespace cytnx { "BlockFermionicUniTensor.\n %s " "\n", "This operation would cause division by zero on non-block elements. " - "[Suggest] Avoid or use get/set_block(s) to do operation on blocks."); + "[Suggest] Avoid or use get/put_block(s) to do operation on blocks."); } void from_(const boost::intrusive_ptr &rhs, const bool &force); @@ -4280,11 +4273,69 @@ namespace cytnx { in.permute_(new_order); return *this; } + + /** + @brief get elements using Accessor (C++ API) / slices (python API) + @param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements. + @return [UniTensor] + @see Tensor::get, UniTensor::operator[] + @note + 1. The return will be a new UniTensor instance, which does not share memory with the current + UniTensor. + + 2. Equivalently, one can also use the [] operator to access elements. + + 3. 
For diagonal UniTensors, the accessor list can have either one element (to address the + diagonal elements), or two elements (in this case, the output will be a non-diagonal + UniTensor). + */ UniTensor get(const std::vector &accessors) const { UniTensor out; out._impl = this->_impl->get(accessors); return out; } + + /** + @brief get elements using Accessor (C++ API) / slices (python API) + @see get() + */ + UniTensor operator[](const std::vector &accessors) const { + UniTensor out; + out._impl = this->_impl->get(accessors); + return out; + } + UniTensor operator[](const std::initializer_list &accessors) const { + std::vector acc_in = accessors; + return this->get(acc_in); + } + UniTensor operator[](const std::vector &accessors) const { + std::vector acc_in; + for (cytnx_int64 i = 0; i < accessors.size(); i++) { + acc_in.push_back(cytnx::Accessor(accessors[i])); + } + return this->get(acc_in); + } + UniTensor operator[](const std::initializer_list &accessors) const { + std::vector acc_in = accessors; + return (*this)[acc_in]; + } + + /** + @brief set elements using Accessor (C++ API) / slices (python API) + @param[in] accessors the Accessor (C++ API) / slices (python API) to set the elements. + @param[in] rhs the tensor containing the values to set. + @return [UniTensor] + @see Tensor::set, UniTensor::operator[], UniTensor::get + @note + 1. This method modifies the elements of the current UniTensor in place and returns a + reference to the current UniTensor. + + 2. Equivalently, one can also use the [] operator to access elements. + + 3. For diagonal UniTensors, the accessor list can have either one element (to address the + diagonal elements; rhs must be one-dimensional), or two elements (in this case, the + output will be a non-diagonal + UniTensor; rhs must be two-dimensional). 
+ */ UniTensor &set(const std::vector &accessors, const Tensor &rhs) { this->_impl->set(accessors, rhs); return *this; diff --git a/include/backend/Tensor_impl.hpp b/include/backend/Tensor_impl.hpp index 66e151d21..a497e9751 100644 --- a/include/backend/Tensor_impl.hpp +++ b/include/backend/Tensor_impl.hpp @@ -187,7 +187,12 @@ namespace cytnx { return this->_storage.at(RealRank); } - boost::intrusive_ptr get(const std::vector &accessors); + boost::intrusive_ptr get(const std::vector &accessors, + std::vector &removed); + boost::intrusive_ptr get(const std::vector &accessors) { + std::vector removed; + return this->get(accessors, removed); + } [[deprecated("Use Tensor_impl::get instead")]] boost::intrusive_ptr get_deprecated( const std::vector &accessors); void set(const std::vector &accessors, diff --git a/pybind/unitensor_py.cpp b/pybind/unitensor_py.cpp index 67e49e623..adb669feb 100644 --- a/pybind/unitensor_py.cpp +++ b/pybind/unitensor_py.cpp @@ -316,8 +316,26 @@ void unitensor_binding(py::module &m) { std::vector accessors; if (self.is_diag()){ if (py::isinstance(locators)) { - cytnx_error_msg(true, - "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); + py::tuple Args = locators.cast(); + cytnx_error_msg(Args.size() > 2, + "[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n"); + cytnx_uint64 cnt = 0; + // mixing of slice and ints + for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) { + cnt++; + // check type: + if (py::isinstance(Args[axis])) { + py::slice sls = Args[axis].cast(); + if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop), + cytnx_int64(step))); + } else { + accessors.push_back(cytnx::Accessor(Args[axis].cast())); + } + } + // cytnx_error_msg(true, + // 
"[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); } else if (py::isinstance(locators)) { py::slice sls = locators.cast(); if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength)) @@ -385,8 +403,26 @@ void unitensor_binding(py::module &m) { std::vector accessors; if (self.is_diag()){ if (py::isinstance(locators)) { - cytnx_error_msg(true, - "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); + py::tuple Args = locators.cast(); + cytnx_error_msg(Args.size() > 2, + "[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n"); + cytnx_uint64 cnt = 0; + // mixing of slice and ints + for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) { + cnt++; + // check type: + if (py::isinstance(Args[axis])) { + py::slice sls = Args[axis].cast(); + if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop), + cytnx_int64(step))); + } else { + accessors.push_back(cytnx::Accessor(Args[axis].cast())); + } + } + // cytnx_error_msg(true, + // "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); } else if (py::isinstance(locators)) { py::slice sls = locators.cast(); if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength)) @@ -453,8 +489,26 @@ void unitensor_binding(py::module &m) { std::vector accessors; if (self.is_diag()){ if (py::isinstance(locators)) { - cytnx_error_msg(true, - "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); + py::tuple Args = locators.cast(); + cytnx_error_msg(Args.size() > 
2, + "[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n"); + cytnx_uint64 cnt = 0; + // mixing of slice and ints + for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) { + cnt++; + // check type: + if (py::isinstance(Args[axis])) { + py::slice sls = Args[axis].cast(); + if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop), + cytnx_int64(step))); + } else { + accessors.push_back(cytnx::Accessor(Args[axis].cast())); + } + } + // cytnx_error_msg(true, + // "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); } else if (py::isinstance(locators)) { py::slice sls = locators.cast(); if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength)) diff --git a/src/Accessor.cpp b/src/Accessor.cpp index 08df57830..85b2908c5 100644 --- a/src/Accessor.cpp +++ b/src/Accessor.cpp @@ -114,6 +114,14 @@ namespace cytnx { return *this; } + // check equality + bool Accessor::operator==(const Accessor &rhs) const { + bool out = (this->_type == rhs._type) && (this->_min == rhs._min) && (this->_max == rhs._max) && + (this->loc == rhs.loc) && (this->_step == rhs._step) && + (this->idx_list == rhs.idx_list); + return out; + } + // get the real len from dim // if _type is all, pos will be null, and len == dim // if _type is range, pos will be the locator, and len == len(pos) diff --git a/src/DenseUniTensor.cpp b/src/DenseUniTensor.cpp index 3686ecd26..d6a5ca99c 100644 --- a/src/DenseUniTensor.cpp +++ b/src/DenseUniTensor.cpp @@ -640,6 +640,72 @@ namespace cytnx { free(rlbl); free(buffer); } + + boost::intrusive_ptr DenseUniTensor::get(const std::vector &accessors) { + if (accessors.empty()) return this->clone_meta(); + DenseUniTensor *out = this->clone_meta(); + std::vector removed; // 
bonds to be removed + if (this->_is_diag) { + if (accessors.size() == 1) { + out->_block = this->_block.get(accessors, removed); + if (removed.empty()) { // change dimension of bonds + for (cytnx_int64 idx = out->_bonds.size() - 1; idx >= 0; idx--) { + out->_bonds[idx]._impl->_dim = out->_block.shape()[0]; + } + } else { // erase all bonds + out->_is_braket_form = false; + out->_is_diag = false; + out->_rowrank = 0; + out->_labels = std::vector(); + out->_bonds = std::vector(); + } + } else { // convert to non-diagonal UniTensor + cytnx_error_msg(accessors.size() > 2, + "[ERROR][DenseUniTensor][get] For diagonal UniTensors, only one or two " + "accessor elements are allowed.%s", + "\n"); + out->_block = this->_block; + out->to_dense_(); + return out->get(accessors); + } + } else { // non-diagonal + out->_block = this->_block.get(accessors, removed); + for (cytnx_int64 idx = removed.size() - 1; idx >= 0; idx--) { + out->_labels.erase(out->_labels.begin() + removed[idx]); + out->_bonds.erase(out->_bonds.begin() + removed[idx]); + if (removed[idx] < this->_rowrank) out->_rowrank--; + } + // adapt dimensions on bonds + auto dims = out->_block.shape(); + for (cytnx_int64 idx = 0; idx < out->_bonds.size(); idx++) { + out->_bonds[idx]._impl->_dim = dims[idx]; + } + // update_braket + if (out->is_tag() && !out->_is_braket_form) { + out->_is_braket_form = out->_update_braket(); + } + } + return boost::intrusive_ptr(out); + } + + void DenseUniTensor::set(const std::vector &accessors, const Tensor &rhs) { + if (accessors.empty()) return; + if (this->_is_diag) { + if (accessors.size() == 1) { + this->_block.set(accessors, rhs); + } else { // convert to non-diagonal UniTensor + cytnx_error_msg(accessors.size() > 2, + "[ERROR][DenseUniTensor][set] For diagonal UniTensors, only one or two " + "accessor elements are allowed.%s", + "\n"); + this->to_dense_(); + this->_block.set(accessors, rhs); + } + } else { // non-diagonal + this->_block.set(accessors, rhs); + } + } + void 
DenseUniTensor::reshape_(const std::vector &new_shape, const cytnx_uint64 &rowrank) { cytnx_error_msg(this->is_tag(), diff --git a/src/backend/Tensor_impl.cpp b/src/backend/Tensor_impl.cpp index 74e92b887..651945591 100644 --- a/src/backend/Tensor_impl.cpp +++ b/src/backend/Tensor_impl.cpp @@ -158,8 +158,8 @@ namespace cytnx { // shadow new: // - boost::intrusive_ptr Tensor_impl::get( - const std::vector &accessors) { + boost::intrusive_ptr Tensor_impl::get(const std::vector &accessors, + std::vector &removed) { cytnx_error_msg(accessors.size() > this->_shape.size(), "%s", "The input indexes rank is out of range! (>Tensor's rank)."); @@ -234,10 +234,10 @@ namespace cytnx { // permute back: std::vector new_mapper(this->_mapper.begin(), this->_mapper.end()); std::vector new_shape; - std::vector remove_id; + // std::vector removed; for (unsigned int i = 0; i < out->_shape.size(); i++) { if (out->shape()[i] == 1 && (acc[i].type() == Accessor::Singl)) - remove_id.push_back(this->_mapper[this->_invmapper[i]]); + removed.push_back(this->_mapper[this->_invmapper[i]]); else new_shape.push_back(out->shape()[i]); } @@ -247,8 +247,8 @@ namespace cytnx { // cout << "inv_mapper" << endl; // cout << this->_invmapper << endl; - // cout << "remove_id" << endl; - // cout << remove_id << endl; + // cout << "removed" << endl; + // cout << removed << endl; // cout << "out shape raw" << endl; // cout << out->shape() << endl; @@ -262,10 +262,10 @@ namespace cytnx { std::vector perm; for (unsigned int i = 0; i < new_mapper.size(); i++) { perm.push_back(new_mapper[i]); - for (unsigned int j = 0; j < remove_id.size(); j++) { - if (new_mapper[i] > remove_id[j]) + for (unsigned int j = 0; j < removed.size(); j++) { + if (new_mapper[i] > removed[j]) perm.back() -= 1; - else if (new_mapper[i] == remove_id[j]) { + else if (new_mapper[i] == removed[j]) { perm.pop_back(); break; } @@ -371,10 +371,10 @@ namespace cytnx { // permute input to currect pos std::vector new_mapper(this->_mapper.begin(), 
this->_mapper.end()); std::vector new_shape; - std::vector remove_id; + std::vector removed; for (unsigned int i = 0; i < get_shape.size(); i++) { if (acc[i].type() == Accessor::Singl) - remove_id.push_back(this->_mapper[this->_invmapper[i]]); + removed.push_back(this->_mapper[this->_invmapper[i]]); else new_shape.push_back(get_shape[i]); } @@ -386,10 +386,10 @@ namespace cytnx { for (unsigned int i = 0; i < new_mapper.size(); i++) { perm.push_back(new_mapper[i]); - for (unsigned int j = 0; j < remove_id.size(); j++) { - if (new_mapper[i] > remove_id[j]) + for (unsigned int j = 0; j < removed.size(); j++) { + if (new_mapper[i] > removed[j]) perm.back() -= 1; - else if (new_mapper[i] == remove_id[j]) { + else if (new_mapper[i] == removed[j]) { perm.pop_back(); break; } diff --git a/tests/Accessor_test.cpp b/tests/Accessor_test.cpp index 4da2fba72..b17c5e037 100644 --- a/tests/Accessor_test.cpp +++ b/tests/Accessor_test.cpp @@ -144,3 +144,19 @@ TEST_F(AccessorTest, AllGenerator) { cytnx::Accessor acc = cytnx::Accessor::all(); ASSERT_EQ(acc.type(), cytnx::Accessor::All); } + +TEST_F(AccessorTest, Equality) { + EXPECT_TRUE(single == cytnx::Accessor(5)); + EXPECT_TRUE(all == cytnx::Accessor::all()); + EXPECT_TRUE(range == cytnx::Accessor::range(1, 4, 2)); + EXPECT_TRUE(tilend == cytnx::Accessor::tilend(2, 1)); + EXPECT_TRUE(step == cytnx::Accessor::step(3)); + EXPECT_TRUE(list == cytnx::Accessor({0, 2, 3})); + EXPECT_FALSE(single == cytnx::Accessor(4)); + EXPECT_FALSE(range == cytnx::Accessor::range(1, 6, 2)); + EXPECT_FALSE(tilend == cytnx::Accessor::tilend(2, 2)); + EXPECT_FALSE(step == cytnx::Accessor::step(4)); + EXPECT_FALSE(list == cytnx::Accessor({0, 1, 3})); + EXPECT_FALSE(range == all); + EXPECT_FALSE(tilend == single); +} diff --git a/tests/DenseUniTensor_test.cpp b/tests/DenseUniTensor_test.cpp index 230ffc8e5..85b1a7fa3 100644 --- a/tests/DenseUniTensor_test.cpp +++ b/tests/DenseUniTensor_test.cpp @@ -453,10 +453,174 @@ TEST_F(DenseUniTensorTest, 
get_index_not_exist) { } /*=====test info===== -describe:test get_index, but input is uninitialized Unitensor +describe:test get_index, but input is uninitialized UniTensor ====================*/ TEST_F(DenseUniTensorTest, get_index_uninit) { EXPECT_EQ(ut_uninit.get_index(""), -1); } +/*=====test info===== +describe:test get on diagonal UniTensor with one accessor keeps diagonal form +====================*/ +TEST_F(DenseUniTensorTest, get_diag_single_accessor_keeps_diag) { + auto ut = ut_complex_diag.clone(); + UniTensor out = ut.get({Accessor::range(1, 4, 2)}); + + EXPECT_TRUE(out.is_diag()); + EXPECT_EQ(out.name(), ut.name()); + EXPECT_EQ(out.labels(), ut.labels()); + EXPECT_EQ(out.rowrank(), 1); + EXPECT_EQ(out.shape(), std::vector({2, 2})); + EXPECT_EQ(out.at({0}), cytnx_complex128(1.0, 0.0)); + EXPECT_EQ(out.at({1}), cytnx_complex128(3.0, 0.0)); + + // source should remain unchanged + EXPECT_TRUE(AreEqUniTensorMeta(ut_complex_diag, ut)); + EXPECT_TRUE(AreEqUniTensor(ut_complex_diag, ut)); +} + +/*=====test info===== +describe:test set on diagonal UniTensor with one accessor keeps diagonal form +====================*/ +TEST_F(DenseUniTensorTest, set_diag_single_accessor_keeps_diag) { + auto ut = ut_complex_diag.clone(); + auto rhs = Tensor({2}, Type.ComplexDouble); + rhs.at({0}) = cytnx_complex128(10.0, 1.0); + rhs.at({1}) = cytnx_complex128(20.0, 2.0); + + ut.set({Accessor::range(1, 4, 2)}, rhs); + + EXPECT_TRUE(ut.is_diag()); + EXPECT_EQ(ut.name(), ut_complex_diag.name()); + EXPECT_EQ(ut.labels(), ut_complex_diag.labels()); + EXPECT_EQ(ut.shape(), std::vector({4, 4})); + EXPECT_EQ(ut.rowrank(), 1); + EXPECT_EQ(ut.at({0}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(ut.at({1}), cytnx_complex128(10.0, 1.0)); + EXPECT_EQ(ut.at({2}), cytnx_complex128(2.0, 0.0)); + EXPECT_EQ(ut.at({3}), cytnx_complex128(20.0, 2.0)); +} + +/*=====test info===== +describe:test slicing diagonal UniTensor with one accessor keeps metadata +====================*/ 
+TEST_F(DenseUniTensorTest, slice_diag_single_accessor) { + auto ut = ut_complex_diag.clone(); + UniTensor out = ut[{Accessor::range(0, 3, 2)}]; + + EXPECT_TRUE(out.is_diag()); + EXPECT_EQ(out.name(), ut.name()); + EXPECT_EQ(out.labels(), ut.labels()); + EXPECT_EQ(out.rowrank(), 1); + EXPECT_EQ(out.shape(), std::vector({2, 2})); + EXPECT_EQ(out.at({0}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({1}), cytnx_complex128(2.0, 0.0)); + + // source should remain unchanged + EXPECT_TRUE(AreEqUniTensorMeta(ut_complex_diag, ut)); + EXPECT_TRUE(AreEqUniTensor(ut_complex_diag, ut)); +} + +/*=====test info===== +describe:test set on diagonal UniTensor with two accessors converts to dense +====================*/ +TEST_F(DenseUniTensorTest, set_diag_two_accessors_convert_dense) { + auto ut = ut_complex_diag.clone(); + auto rhs = Tensor({4}, Type.ComplexDouble); + rhs.at({0}) = cytnx_complex128(4.0, 1.0); + rhs.at({1}) = cytnx_complex128(5.0, 2.0); + rhs.at({2}) = cytnx_complex128(6.0, 3.0); + rhs.at({3}) = cytnx_complex128(7.0, 4.0); + + ut.set({Accessor(1), Accessor::all()}, rhs); + + EXPECT_FALSE(ut.is_diag()); + EXPECT_EQ(ut.name(), ut_complex_diag.name()); + EXPECT_EQ(ut.labels(), ut_complex_diag.labels()); + EXPECT_EQ(ut.shape(), std::vector({4, 4})); + EXPECT_EQ(ut.rowrank(), 1); + EXPECT_EQ(ut.at({0, 0}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(ut.at({1, 0}), cytnx_complex128(4.0, 1.0)); + EXPECT_EQ(ut.at({1, 1}), cytnx_complex128(5.0, 2.0)); + EXPECT_EQ(ut.at({1, 2}), cytnx_complex128(6.0, 3.0)); + EXPECT_EQ(ut.at({1, 3}), cytnx_complex128(7.0, 4.0)); + EXPECT_EQ(ut.at({2, 2}), cytnx_complex128(2.0, 0.0)); + EXPECT_EQ(ut.at({3, 3}), cytnx_complex128(3.0, 0.0)); +} + +/*=====test info===== +describe:test slicing diagonal UniTensor with two different accessors converts to dense and keeps +metadata +====================*/ +TEST_F(DenseUniTensorTest, slice_diag_two_accessors_convert_dense) { + auto ut = ut_complex_diag.clone(); + UniTensor out = ut[{Accessor(1), 
Accessor::all()}]; + + EXPECT_FALSE(out.is_diag()); + EXPECT_EQ(out.name(), ut.name()); + std::vector newlabels = {"col"}; + EXPECT_EQ(out.labels(), newlabels); + EXPECT_EQ(out.rowrank(), 0); + EXPECT_EQ(out.shape(), std::vector({4})); + EXPECT_EQ(out.at({0}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({1}), cytnx_complex128(1.0, 0.0)); + EXPECT_EQ(out.at({2}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({3}), cytnx_complex128(0.0, 0.0)); + + // source should remain unchanged + EXPECT_TRUE(AreEqUniTensorMeta(ut_complex_diag, ut)); + EXPECT_TRUE(AreEqUniTensor(ut_complex_diag, ut)); +} + +/*=====test info===== +describe:test get on diagonal UniTensor with two accessors converts to dense +====================*/ +TEST_F(DenseUniTensorTest, get_diag_two_accessors_convert_dense) { + auto ut_complex_diag_snapshot = ut_complex_diag.clone(); + UniTensor out = ut_complex_diag.get({Accessor::all(), Accessor::range(1, 4, 2)}); + + EXPECT_FALSE(out.is_diag()); + EXPECT_EQ(out.name(), ut_complex_diag.name()); + std::vector newlabels = {"row", "col"}; + EXPECT_EQ(out.labels(), newlabels); + EXPECT_EQ(out.rowrank(), 1); + EXPECT_EQ(out.shape(), std::vector({4, 2})); + EXPECT_EQ(out.at({0, 0}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({1, 0}), cytnx_complex128(1.0, 0.0)); + EXPECT_EQ(out.at({2, 0}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({3, 0}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({0, 1}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({1, 1}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({2, 1}), cytnx_complex128(0.0, 0.0)); + EXPECT_EQ(out.at({3, 1}), cytnx_complex128(3.0, 0.0)); + + // source should remain unchanged + EXPECT_TRUE(AreEqUniTensorMeta(ut_complex_diag, ut_complex_diag_snapshot)); + EXPECT_TRUE(AreEqUniTensor(ut_complex_diag, ut_complex_diag_snapshot)); +} + +/*=====test info===== +describe:test slicing non-diagonal UniTensor keeps name labels and rowrank metadata +====================*/ +TEST_F(DenseUniTensorTest, 
slice_nondiag_keeps_metadata) { + auto ut = utar345.clone(); + UniTensor out = ut[{Accessor::all(), Accessor::range(1, 4, 2), Accessor(3)}]; + + EXPECT_FALSE(out.is_diag()); + EXPECT_EQ(out.name(), ut.name()); + std::vector newlabels = {"a", "b"}; + EXPECT_EQ(out.labels(), newlabels); + EXPECT_EQ(out.shape(), std::vector({3, 2})); + EXPECT_EQ(out.at({0, 0}), cytnx_complex128(8.0, 0.0)); + EXPECT_EQ(out.at({0, 1}), cytnx_complex128(18.0, 0.0)); + EXPECT_EQ(out.at({1, 0}), cytnx_complex128(28.0, 0.0)); + EXPECT_EQ(out.at({1, 1}), cytnx_complex128(38.0, 0.0)); + EXPECT_EQ(out.at({2, 0}), cytnx_complex128(48.0, 0.0)); + EXPECT_EQ(out.at({2, 1}), cytnx_complex128(58.0, 0.0)); + + // source should remain unchanged + EXPECT_TRUE(AreEqUniTensorMeta(utar345, ut)); + EXPECT_TRUE(AreEqUniTensor(utar345, ut)); +} + /*=====test info===== describe:test bonds ====================*/ @@ -480,7 +644,7 @@ TEST_F(DenseUniTensorTest, shape) { } TEST_F(DenseUniTensorTest, shape_diag) { - EXPECT_EQ(std::vector({2, 2}), ut_complex_diag.shape()); + EXPECT_EQ(std::vector({4, 4}), ut_complex_diag.shape()); EXPECT_TRUE(ut_complex_diag.is_diag()); } @@ -2135,7 +2299,7 @@ TEST_F(DenseUniTensorTest, Add_diag_diag) { } /*=====test info===== -describe:test adding two UniTensor with different shape but not contain 1-element Unitensor. +describe:test adding two UniTensor with different shape but not contain 1-element UniTensor. ====================*/ TEST_F(DenseUniTensorTest, Add_UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); @@ -2353,7 +2517,7 @@ TEST_F(DenseUniTensorTest, Add__self) { } /*=====test info===== -describe:test adding two UniTensor with different shape but not contain 1-element Unitensor. +describe:test adding two UniTensor with different shape but not contain 1-element UniTensor. 
====================*/ TEST_F(DenseUniTensorTest, Add__UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); @@ -2605,7 +2769,7 @@ TEST_F(DenseUniTensorTest, Sub_diag_diag) { } /*=====test info===== -describe:test subing two UniTensor with different shape but not contain 1-element Unitensor. +describe:test subing two UniTensor with different shape but not contain 1-element UniTensor. ====================*/ TEST_F(DenseUniTensorTest, Sub_UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); @@ -2822,7 +2986,7 @@ TEST_F(DenseUniTensorTest, Sub__self) { } /*=====test info===== -describe:test subing two UniTensor with different shape but not contain 1-element Unitensor. +describe:test subing two UniTensor with different shape but not contain 1-element UniTensor. ====================*/ TEST_F(DenseUniTensorTest, Sub__UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); @@ -3072,7 +3236,7 @@ TEST_F(DenseUniTensorTest, Mul_diag_diag) { } /*=====test info===== -describe:test muling two UniTensor with different shape but not contain 1-element Unitensor. +describe:test muling two UniTensor with different shape but not contain 1-element UniTensor. ====================*/ TEST_F(DenseUniTensorTest, Mul_UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); @@ -3291,7 +3455,7 @@ TEST_F(DenseUniTensorTest, Mul__self) { } /*=====test info===== -describe:test muling two UniTensor with different shape but not contain 1-element Unitensor. +describe:test muling two UniTensor with different shape but not contain 1-element UniTensor. ====================*/ TEST_F(DenseUniTensorTest, Mul__UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); @@ -3543,7 +3707,7 @@ TEST_F(DenseUniTensorTest, Div_diag_diag) { } /*=====test info===== -describe:test diving two UniTensor with different shape but not contain 1-element Unitensor. +describe:test diving two UniTensor with different shape but not contain 1-element UniTensor. 
====================*/ TEST_F(DenseUniTensorTest, Div_UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); @@ -3756,7 +3920,7 @@ TEST_F(DenseUniTensorTest, Div__self) { } /*=====test info===== -describe:test diving two UniTensor with different shape but not contain 1-element Unitensor. +describe:test diving two UniTensor with different shape but not contain 1-element UniTensor. ====================*/ TEST_F(DenseUniTensorTest, Div__UT_UT_rank_error) { auto ut1 = UniTensor({Bond(1), Bond(2)}); diff --git a/tests/DenseUniTensor_test.h b/tests/DenseUniTensor_test.h index 1a3970e5f..73285c57f 100644 --- a/tests/DenseUniTensor_test.h +++ b/tests/DenseUniTensor_test.h @@ -27,6 +27,7 @@ class DenseUniTensorTest : public ::testing::Test { UniTensor ut_complex_diag; Bond phy = Bond(2, BD_IN); Bond aux = Bond(1, BD_IN); + Bond bond4 = Bond(4, BD_IN); DenseUniTensor dut; Tensor tzero345 = zeros({3, 4, 5}).astype(Type.ComplexDouble); Tensor tar345 = arange({3 * 4 * 5}).reshape({3, 4, 5}).astype(Type.ComplexDouble); @@ -56,6 +57,7 @@ class DenseUniTensorTest : public ::testing::Test { utzero345 = UniTensor(zeros(3 * 4 * 5)).reshape({3, 4, 5}).astype(Type.ComplexDouble); utone345 = UniTensor(ones(3 * 4 * 5)).reshape({3, 4, 5}).astype(Type.ComplexDouble); utar345 = UniTensor(arange(3 * 4 * 5)).reshape({3, 4, 5}).astype(Type.ComplexDouble); + utar345.set_labels({"a", "b", "c"}).set_name("utar345").set_rowrank_(2); utzero3456 = UniTensor(zeros(3 * 4 * 5 * 6)).reshape({3, 4, 5, 6}).astype(Type.ComplexDouble); utone3456 = UniTensor(ones(3 * 4 * 5 * 6)).reshape({3, 4, 5, 6}).astype(Type.ComplexDouble); utar3456 = UniTensor(arange(3 * 4 * 5 * 6)).reshape({3, 4, 5, 6}).astype(Type.ComplexDouble); @@ -66,7 +68,9 @@ class DenseUniTensorTest : public ::testing::Test { for (size_t i = 0; i < 3 * 4 * 5 * 6; i++) utarcomplex3456.at({i}) = cytnx_complex128(i, i); utarcomplex3456 = utarcomplex3456.reshape({3, 4, 5, 6}).astype(Type.ComplexDouble); ut_complex_diag = - UniTensor({phy, 
phy.redirect()}, {"1", "2"}, 1, Type.ComplexDouble, Device.cpu, true); + UniTensor({bond4, bond4.redirect()}, {"row", "col"}, 1, Type.ComplexDouble, Device.cpu, true); + ut_complex_diag.put_block(arange(4).astype(Type.ComplexDouble)); + ut_complex_diag.set_name("ut_complex_diag"); ut1 = ut1.Load(data_dir + "denseutensor1.cytnx").astype(Type.ComplexDouble); ut2 = ut2.Load(data_dir + "denseutensor2.cytnx").astype(Type.ComplexDouble); diff --git a/tests/test_tools.cpp b/tests/test_tools.cpp index 74bedb06a..d8d7e6e92 100644 --- a/tests/test_tools.cpp +++ b/tests/test_tools.cpp @@ -522,6 +522,20 @@ namespace cytnx { return AreNearlyEqUniTensor(Ut1, Ut2, 0); } + bool AreEqUniTensorMeta(const UniTensor& Ut1, const UniTensor& Ut2) { + if (Ut1.uten_type() != Ut2.uten_type()) return false; + if (Ut1.name() != Ut2.name()) return false; + if (Ut1.labels() != Ut2.labels()) return false; + if (Ut1.rowrank() != Ut2.rowrank()) return false; + if (Ut1.rank() != Ut2.rank()) return false; + if (Ut1.shape() != Ut2.shape()) return false; + if (Ut1.is_diag() != Ut2.is_diag()) return false; + if (Ut1.is_tag() != Ut2.is_tag()) return false; + if (Ut1.is_braket_form() != Ut2.is_braket_form()) return false; + if (Ut1.bonds() != Ut2.bonds()) return false; + return true; + } + // UniTensor void InitUniTensorUniform(UniTensor& UT, unsigned int rand_seed) { auto dtype = UT.dtype(); diff --git a/tests/test_tools.h b/tests/test_tools.h index 89903a1ab..0528f4ab4 100644 --- a/tests/test_tools.h +++ b/tests/test_tools.h @@ -53,6 +53,7 @@ namespace cytnx { bool AreNearlyEqUniTensor(const UniTensor& Ut1, const UniTensor& Ut2, const cytnx_double tol = 0); bool AreEqUniTensor(const UniTensor& Ut1, const UniTensor& Ut2); + bool AreEqUniTensorMeta(const UniTensor& Ut1, const UniTensor& Ut2); bool AreElemSame(const Tensor& T1, const std::vector& idices1, const Tensor& T2, const std::vector& idices2);