Skip to content

Commit

Permalink
add arg tol for dn to bk conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
j9263178 committed Dec 3, 2023
1 parent c93b318 commit eccb0be
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 13 deletions.
4 changes: 2 additions & 2 deletions cytnx/UniTensor_conti.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,6 @@ def at(self, labels:List[str], locator:List[int]):


@add_method(UniTensor)
def convert_from(self, Tin, force=False):
self.cfrom(Tin,force);
def convert_from(self, Tin, force=False, tol = 1e-14):
self.cfrom(Tin,force,tol);
return self
10 changes: 7 additions & 3 deletions include/UniTensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ namespace cytnx {
virtual const cytnx_int16 &at_for_sparse(const std::vector<cytnx_uint64> &locator,
const cytnx_int16 &aux) const;

virtual void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force,
const cytnx_double &tol);
virtual void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force);

virtual void group_basis_();
Expand Down Expand Up @@ -1700,7 +1702,8 @@ namespace cytnx {
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"operation on the block(s).");
}
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force);
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force,
const cytnx_double &tol);

void group_basis_();

Expand Down Expand Up @@ -4078,8 +4081,9 @@ namespace cytnx {
void _Save(std::fstream &f) const;
/// @endcond

UniTensor &convert_from(const UniTensor &rhs, const bool &force = false) {
this->_impl->from_(rhs._impl, force);
UniTensor &convert_from(const UniTensor &rhs, const bool &force = false,
const cytnx_double &tol = 1e-14) {
this->_impl->from_(rhs._impl, force, tol);
return *this;
}

Expand Down
6 changes: 3 additions & 3 deletions pybind/unitensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1365,10 +1365,10 @@ void unitensor_binding(py::module &m) {
},
py::arg("low"), py::arg("high"), py::arg("seed")= -1)

.def("cfrom", [](UniTensor &self, const UniTensor &in, const bool &force){
self.convert_from(in,force);
.def("cfrom", [](UniTensor &self, const UniTensor &in, const bool &force, const cytnx_double &tol){
self.convert_from(in,force,tol);
},
py::arg("Tin"), py::arg("force") = false)
py::arg("Tin"), py::arg("force") = false, py::arg("tol") = 1e-14)
.def("get_qindices", [](UniTensor &self, const cytnx_uint64 &bidx){return self.get_qindices(bidx);});
; // end of object line

Expand Down
10 changes: 6 additions & 4 deletions src/BlockUniTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1992,7 +1992,8 @@ namespace cytnx {
this->combineBonds(idx_mapper, force);
}

void _BK_from_DN(BlockUniTensor *ths, DenseUniTensor *rhs, const bool &force) {
void _BK_from_DN(BlockUniTensor *ths, DenseUniTensor *rhs, const bool &force,
const cytnx_double &tol) {
if (!force) {
// more checking:
if (int(rhs->bond_(0).type()) != bondType::BD_NONE) {
Expand Down Expand Up @@ -2020,7 +2021,7 @@ namespace cytnx {
elem = rhs->_block.at(cart);
} else {
if (!force)
if (abs(Scalar(rhs->_block.at(cart))) > 1e-14) {
if (abs(Scalar(rhs->_block.at(cart))) > tol) {
cytnx_error_msg(true,
"[ERROR] force = false, trying to convert DenseUT to BlockUT that "
"violate the symmetry structure.%s",
Expand All @@ -2034,12 +2035,13 @@ namespace cytnx {
cytnx_error_msg(true, "[ERROR] BlockUT-> BlockUT not implemented.%s", "\n");
}

void BlockUniTensor::from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force) {
void BlockUniTensor::from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force,
const cytnx_double &tol) {
// checking shape:
cytnx_error_msg(this->shape() != rhs->shape(), "[ERROR][from_] shape does not match.%s", "\n");

if (rhs->uten_type() == UTenType.Dense) {
_BK_from_DN(this, (DenseUniTensor *)(rhs.get()), force);
_BK_from_DN(this, (DenseUniTensor *)(rhs.get()), force, tol);
} else if (rhs->uten_type() == UTenType.Block) {
_BK_from_BK(this, (BlockUniTensor *)(rhs.get()), force);
} else {
Expand Down
3 changes: 2 additions & 1 deletion src/UniTensor_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,8 @@ namespace cytnx {
"\n");
}

void UniTensor_base::from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force) {
void UniTensor_base::from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force,
const cytnx_double &tol) {
cytnx_error_msg(true, "[ERROR] fatal internal, cannot call on a un-initialize UniTensor_base%s",
"\n");
}
Expand Down

0 comments on commit eccb0be

Please sign in to comment.