Skip to content

Commit

Permalink
Merge pull request #478 from Cytnx-dev/haoti/Fix_underline
Browse files Browse the repository at this point in the history
Modify reshape_ and permute_ to return itself.
  • Loading branch information
hunghaoti authored Oct 8, 2024
2 parents 4861097 + b1ea554 commit c266ec2
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
18 changes: 12 additions & 6 deletions include/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,13 +620,13 @@ namespace cytnx {
*/
const bool &is_contiguous() const { return this->_impl->is_contiguous(); }

Tensor permute_(const std::vector<cytnx_uint64> &rnks) {
Tensor &permute_(const std::vector<cytnx_uint64> &rnks) {
this->_impl->permute_(rnks);
return *this;
}
/// @cond
template <class... Ts>
Tensor permute_(const cytnx_uint64 &e1, const Ts &...elems) {
Tensor &permute_(const cytnx_uint64 &e1, const Ts &...elems) {
std::vector<cytnx_uint64> argv = dynamic_arg_uint64_resolver(e1, elems...);
this->_impl->permute_(argv);
return *this;
Expand Down Expand Up @@ -725,21 +725,27 @@ namespace cytnx {
#### output>
\verbinclude example/Tensor/reshape_.py.out
*/
void reshape_(const std::vector<cytnx_int64> &new_shape) { this->_impl->reshape_(new_shape); }
Tensor &reshape_(const std::vector<cytnx_int64> &new_shape) {
this->_impl->reshape_(new_shape);
return *this;
}
/// @cond
void reshape_(const std::vector<cytnx_uint64> &new_shape) {
Tensor &reshape_(const std::vector<cytnx_uint64> &new_shape) {
std::vector<cytnx_int64> shape(new_shape.begin(), new_shape.end());
this->_impl->reshape_(shape);
return *this;
}
void reshape_(const std::initializer_list<cytnx_int64> &new_shape) {
Tensor &reshape_(const std::initializer_list<cytnx_int64> &new_shape) {
std::vector<cytnx_int64> shape = new_shape;
this->_impl->reshape_(shape);
return *this;
}
template <class... Ts>
void reshape_(const cytnx_int64 &e1, const Ts... elems) {
Tensor &reshape_(const cytnx_int64 &e1, const Ts... elems) {
std::vector<cytnx_int64> shape = dynamic_arg_int64_resolver(e1, elems...);
// std::cout << shape << std::endl;
this->_impl->reshape_(shape);
return *this;
}
/// @endcond

Expand Down
10 changes: 7 additions & 3 deletions include/UniTensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2718,8 +2718,9 @@ namespace cytnx {
@param[in] rowrank the row rank after the permutation
@warning \p by_label will be deprecated!
*/
void permute_(const std::vector<cytnx_int64> &mapper, const cytnx_int64 &rowrank = -1) {
UniTensor &permute_(const std::vector<cytnx_int64> &mapper, const cytnx_int64 &rowrank = -1) {
this->_impl->permute_(mapper, rowrank);
return *this;
}

/**
Expand All @@ -2728,8 +2729,9 @@ namespace cytnx {
@param[in] rowrank the row rank after the permutation
@see permute(const std::vector<std::string> &mapper, const cytnx_int64 &rowrank = -1)
*/
void permute_(const std::vector<std::string> &mapper, const cytnx_int64 &rowrank = -1) {
UniTensor &permute_(const std::vector<std::string> &mapper, const cytnx_int64 &rowrank = -1) {
this->_impl->permute_(mapper, rowrank);
return *this;
}

// void permute_( const std::initializer_list<char*> &mapper, const cytnx_int64 &rowrank= -1){
Expand Down Expand Up @@ -3337,8 +3339,10 @@ namespace cytnx {
cannot be UTenType::Block.
@see reshape(const std::vector<cytnx_int64> &new_shape, const cytnx_uint64 &rowrank)
*/
void reshape_(const std::vector<cytnx_int64> &new_shape, const cytnx_uint64 &rowrank = 0) {
UniTensor &reshape_(const std::vector<cytnx_int64> &new_shape,
const cytnx_uint64 &rowrank = 0) {
this->_impl->reshape_(new_shape, rowrank);
return *this;
}

/**
Expand Down
4 changes: 2 additions & 2 deletions pybind/tensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ void tensor_binding(py::module &m) {
[](cytnx::Tensor &self, py::args args) {
std::vector<cytnx::cytnx_uint64> c_args = args.cast<std::vector<cytnx::cytnx_uint64>>();
// std::cout << c_args.size() << std::endl;
self.permute_(c_args);
return &self.permute_(c_args);
})
.def("permute",
[](cytnx::Tensor &self, py::args args) -> cytnx::Tensor {
Expand All @@ -200,7 +200,7 @@ void tensor_binding(py::module &m) {
.def("reshape_",
[](cytnx::Tensor &self, py::args args) {
std::vector<cytnx::cytnx_int64> c_args = args.cast<std::vector<cytnx::cytnx_int64>>();
self.reshape_(c_args);
return &self.reshape_(c_args);
})
.def("reshape",
[](cytnx::Tensor &self, py::args args) -> cytnx::Tensor {
Expand Down
6 changes: 3 additions & 3 deletions pybind/unitensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ void unitensor_binding(py::module &m) {
if (kwargs.contains("rowrank")) rowrank = kwargs["rowrank"].cast<cytnx::cytnx_int64>();
}

self.reshape_(c_args, rowrank);
return &self.reshape_(c_args, rowrank);
})
.def("elem_exists", &UniTensor::elem_exists)
.def("item",
Expand Down Expand Up @@ -592,12 +592,12 @@ void unitensor_binding(py::module &m) {
// [Deprecated by_label!]
.def("permute_", [](UniTensor &self, const std::vector<cytnx_int64> &mapper, const cytnx_int64 &rowrank){

self.permute_(mapper,rowrank);
return &self.permute_(mapper,rowrank);

},py::arg("mapper"), py::arg("rowrank")=(cytnx_int64)(-1))

.def("permute_", [](UniTensor &self, const std::vector<std::string> &mapper, const cytnx_int64 &rowrank){
self.permute_(mapper,rowrank);
return &self.permute_(mapper,rowrank);
},py::arg("mapper"), py::arg("rowrank")=(cytnx_int64)(-1))

.def("permute", [](UniTensor &self, const std::vector<cytnx_int64> &mapper, const cytnx_int64 &rowrank){
Expand Down

0 comments on commit c266ec2

Please sign in to comment.