Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ namespace cytnx {
Accessor(const Accessor &rhs);
// copy assignment:
Accessor &operator=(const Accessor &rhs);

// check equality
bool operator==(const Accessor &rhs) const;
///@endcond

int type() const { return this->_type; }
Expand Down
14 changes: 12 additions & 2 deletions include/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1008,10 +1008,13 @@ namespace cytnx {
/**
@brief get elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements.
@param[out] removed the indices that were removed from the original shape of the Tensor are
pushed to the end of this vector. Usually, an empty vector should be passed.
@return [Tensor]
@see \link cytnx::Accessor Accessor\endlink for coordinating with Accessor in the C++ API.
@note
1. the return will be a new Tensor instance, which not share memory with the current Tensor.
The return will be a new Tensor instance, which does not share memory with the current
Tensor.

## Equivalently:
One can also use a more intuitive way to get the slice, using the [] operator.
Expand All @@ -1026,9 +1029,16 @@ namespace cytnx {
#### output>
\verbinclude example/Tensor/get.py.out
*/
/// @brief Slice this Tensor with a list of Accessors, also reporting which
///        axes were removed from the original shape (appended to `removed`).
/// @note  The returned Tensor owns fresh storage; it does not alias `this`.
Tensor get(const std::vector<cytnx::Accessor> &accessors,
           std::vector<cytnx_int64> &removed) const {
  Tensor sliced;
  sliced._impl = this->_impl->get(accessors, removed);
  return sliced;
}
/// @brief Slice this Tensor with a list of Accessors.
/// Convenience overload: the removed-axes bookkeeping produced by the
/// implementation is discarded.
/// @note The return is a new Tensor instance that does not share memory
///       with the current Tensor.
Tensor get(const std::vector<cytnx::Accessor> &accessors) const {
  Tensor out;
  // Single call: the previous code computed the slice twice, immediately
  // overwriting the first result.
  std::vector<cytnx_int64> removed;
  out._impl = this->_impl->get(accessors, removed);
  return out;
}

Expand Down
95 changes: 73 additions & 22 deletions include/UniTensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,12 @@ namespace cytnx {
if (this->is_diag()) {
cytnx_error_msg(
in.shape() != this->_block.shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in.clone();
} else {
cytnx_error_msg(
in.shape() != this->shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in.clone();
}
}
Expand All @@ -711,12 +711,12 @@ namespace cytnx {
if (this->is_diag()) {
cytnx_error_msg(
in.shape() != this->_block.shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in;
} else {
cytnx_error_msg(
in.shape() != this->shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in;
}
}
Expand All @@ -731,16 +731,9 @@ namespace cytnx {
true, "[ERROR][DenseUniTensor] try to put_block using qnum on a non-symmetry UniTensor%s",
"\n");
}
// this will only work on non-symm tensor (DenseUniTensor)
boost::intrusive_ptr<UniTensor_base> get(const std::vector<Accessor> &accessors) {
boost::intrusive_ptr<UniTensor_base> out(new DenseUniTensor());
out->Init_by_Tensor(this->_block.get(accessors), false, 0); // wrapping around.
return out;
}
// this will only work on non-symm tensor (DenseUniTensor)
void set(const std::vector<Accessor> &accessors, const Tensor &rhs) {
this->_block.set(accessors, rhs);
}
// these two methods only work on non-symm tensor (DenseUniTensor)
boost::intrusive_ptr<UniTensor_base> get(const std::vector<Accessor> &accessors);
void set(const std::vector<Accessor> &accessors, const Tensor &rhs);

void reshape_(const std::vector<cytnx_int64> &new_shape, const cytnx_uint64 &rowrank = 0);
boost::intrusive_ptr<UniTensor_base> reshape(const std::vector<cytnx_int64> &new_shape,
Expand Down Expand Up @@ -1700,7 +1693,7 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '+' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) "
"This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) "
"to do operation on blocks.");
}

Expand All @@ -1713,15 +1706,15 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '-' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) "
"This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) "
"to do operation on blocks.");
}
void lSub_(const Scalar &lhs) {
cytnx_error_msg(
true,
"[ERROR] cannot perform elementwise arithmetic '-' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) "
"This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) "
"to do operation on blocks.");
}

Expand All @@ -1733,7 +1726,7 @@ namespace cytnx {
"[ERROR] cannot perform elementwise arithmetic '/' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would cause division by zero on non-block elements. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force,
const cytnx_double &tol);
Expand Down Expand Up @@ -2490,7 +2483,7 @@ namespace cytnx {
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}

void Mul_(const boost::intrusive_ptr<UniTensor_base> &rhs);
Expand All @@ -2503,15 +2496,15 @@ namespace cytnx {
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}
void lSub_(const Scalar &lhs) {
cytnx_error_msg(true,
"[ERROR] cannot perform elementwise arithmetic '-' between Scalar and "
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}

void Div_(const boost::intrusive_ptr<UniTensor_base> &rhs);
Expand All @@ -2522,7 +2515,7 @@ namespace cytnx {
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would cause division by zero on non-block elements. "
"[Suggest] Avoid or use get/set_block(s) to do operation on blocks.");
"[Suggest] Avoid or use get/put_block(s) to do operation on blocks.");
}
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force);

Expand Down Expand Up @@ -4280,11 +4273,69 @@ namespace cytnx {
in.permute_(new_order);
return *this;
}

/**
@brief get elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements.
@return [UniTensor]
@see Tensor::get, UniTensor::operator[]
@note
1. The return will be a new UniTensor instance, which does not share memory with the current
UniTensor.

2. Equivalently, one can also use the [] operator to access elements.

3. For diagonal UniTensors, the accessor list can have either one element (to address the
diagonal elements), or two elements (in this case, the output will be a non-diagonal
UniTensor).
*/
/// @brief get elements using Accessor (C++ API) / slices (python API).
/// @param[in] accessors the Accessors / slices selecting the elements.
/// @return a new UniTensor that does not share memory with this one.
UniTensor get(const std::vector<Accessor> &accessors) const {
  UniTensor sliced;
  sliced._impl = this->_impl->get(accessors);
  return sliced;
}

/**
@brief get elements using Accessor (C++ API) / slices (python API)
@see get()
*/
/// @brief operator form of element access via Accessors.
/// @see get()
UniTensor operator[](const std::vector<cytnx::Accessor> &accessors) const {
  // Delegate to get() so every operator[] overload shares one code path,
  // matching the initializer_list overloads.
  return this->get(accessors);
}
/// @brief brace-list overload of element access via Accessors.
/// @see get()
UniTensor operator[](const std::initializer_list<cytnx::Accessor> &accessors) const {
  const std::vector<cytnx::Accessor> converted(accessors);
  return this->get(converted);
}
/// @brief integer-index overload: each integer picks one element on its axis.
/// @see get()
UniTensor operator[](const std::vector<cytnx_int64> &accessors) const {
  // Build one single-element Accessor per integer index.
  // Range-for avoids the previous signed/unsigned comparison between a
  // cytnx_int64 counter and size(); reserve avoids reallocation.
  std::vector<cytnx::Accessor> acc_in;
  acc_in.reserve(accessors.size());
  for (const cytnx_int64 &idx : accessors) {
    acc_in.push_back(cytnx::Accessor(idx));
  }
  return this->get(acc_in);
}
/// @brief brace-list integer-index overload; forwards to the vector form.
/// @see get()
UniTensor operator[](const std::initializer_list<cytnx_int64> &accessors) const {
  const std::vector<cytnx_int64> idxs(accessors);
  return this->operator[](idxs);
}

/**
@brief set elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to set the elements.
@param[in] rhs the tensor containing the values to set.
@return [UniTensor]
@see Tensor::set, UniTensor::operator[], UniTensor::get
@note
1. The return will be a new UniTensor instance, which does not share memory with the current
UniTensor.

2. Equivalently, one can also use the [] operator to access elements.

3. For diagonal UniTensors, the accessor list can have either one element (to address the
diagonal elements; rhs must be one-dimensional), or two elements (in this case, the
output will be a non-diagonal UniTensor; rhs must be two-dimensional).
*/
UniTensor &set(const std::vector<Accessor> &accessors, const Tensor &rhs) {
this->_impl->set(accessors, rhs);
return *this;
Expand Down
7 changes: 6 additions & 1 deletion include/backend/Tensor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,12 @@ namespace cytnx {
return this->_storage.at(RealRank);
}

boost::intrusive_ptr<Tensor_impl> get(const std::vector<cytnx::Accessor> &accessors);
boost::intrusive_ptr<Tensor_impl> get(const std::vector<cytnx::Accessor> &accessors,
std::vector<cytnx_int64> &removed);
// Convenience overload: callers that do not need the removed-axes
// bookkeeping get a scratch vector supplied for them.
boost::intrusive_ptr<Tensor_impl> get(const std::vector<cytnx::Accessor> &accessors) {
  std::vector<cytnx_int64> discarded;
  return this->get(accessors, discarded);
}
[[deprecated("Use Tensor_impl::get instead")]] boost::intrusive_ptr<Tensor_impl> get_deprecated(
const std::vector<cytnx::Accessor> &accessors);
void set(const std::vector<cytnx::Accessor> &accessors,
Expand Down
66 changes: 60 additions & 6 deletions pybind/unitensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,26 @@ void unitensor_binding(py::module &m) {
std::vector<cytnx::Accessor> accessors;
if (self.is_diag()){
if (py::isinstance<py::tuple>(locators)) {
cytnx_error_msg(true,
"[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
py::tuple Args = locators.cast<py::tuple>();
cytnx_error_msg(Args.size() > 2,
"[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n");
cytnx_uint64 cnt = 0;
// mixing of slice and ints
for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) {
cnt++;
// check type:
if (py::isinstance<py::slice>(Args[axis])) {
py::slice sls = Args[axis].cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength))
throw py::error_already_set();
accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop),
cytnx_int64(step)));
} else {
accessors.push_back(cytnx::Accessor(Args[axis].cast<cytnx_int64>()));
}
}
// cytnx_error_msg(true,
// "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
} else if (py::isinstance<py::slice>(locators)) {
py::slice sls = locators.cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength))
Expand Down Expand Up @@ -385,8 +403,26 @@ void unitensor_binding(py::module &m) {
std::vector<cytnx::Accessor> accessors;
if (self.is_diag()){
if (py::isinstance<py::tuple>(locators)) {
cytnx_error_msg(true,
"[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
py::tuple Args = locators.cast<py::tuple>();
cytnx_error_msg(Args.size() > 2,
"[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n");
cytnx_uint64 cnt = 0;
// mixing of slice and ints
for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) {
cnt++;
// check type:
if (py::isinstance<py::slice>(Args[axis])) {
py::slice sls = Args[axis].cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength))
throw py::error_already_set();
accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop),
cytnx_int64(step)));
} else {
accessors.push_back(cytnx::Accessor(Args[axis].cast<cytnx_int64>()));
}
}
// cytnx_error_msg(true,
// "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
} else if (py::isinstance<py::slice>(locators)) {
py::slice sls = locators.cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength))
Expand Down Expand Up @@ -453,8 +489,26 @@ void unitensor_binding(py::module &m) {
std::vector<cytnx::Accessor> accessors;
if (self.is_diag()){
if (py::isinstance<py::tuple>(locators)) {
cytnx_error_msg(true,
"[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
py::tuple Args = locators.cast<py::tuple>();
cytnx_error_msg(Args.size() > 2,
"[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n");
cytnx_uint64 cnt = 0;
// mixing of slice and ints
for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) {
cnt++;
// check type:
if (py::isinstance<py::slice>(Args[axis])) {
py::slice sls = Args[axis].cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength))
throw py::error_already_set();
accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop),
cytnx_int64(step)));
} else {
accessors.push_back(cytnx::Accessor(Args[axis].cast<cytnx_int64>()));
}
}
// cytnx_error_msg(true,
// "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
} else if (py::isinstance<py::slice>(locators)) {
py::slice sls = locators.cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength))
Expand Down
8 changes: 8 additions & 0 deletions src/Accessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ namespace cytnx {
return *this;
}

// check equality
bool Accessor::operator==(const Accessor &rhs) const {
bool out = (this->_type == rhs._type) && (this->_min == rhs._min) && (this->_max == rhs._max) &&
(this->loc == rhs.loc) && (this->_step == rhs._step) &&
(this->idx_list == rhs.idx_list);
return out;
}

// get the real len from dim
// if _type is all, pos will be null, and len == dim
// if _type is range, pos will be the locator, and len == len(pos)
Expand Down
Loading
Loading