Skip to content

Commit

Permalink
Merge branch 'dev-master' into updateContracts
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
j9263178 committed Sep 13, 2023
2 parents d03a2a5 + e44c67e commit 5f3cfea
Show file tree
Hide file tree
Showing 16 changed files with 1,057 additions and 206 deletions.
10 changes: 10 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
coverage:
status:
project:
default:
target: auto
threshold: 100%
patch:
default:
target: auto
threshold: 100%
180 changes: 32 additions & 148 deletions src/linalg/Gesvd_truncate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,62 +19,28 @@ namespace cytnx {
std::vector<Tensor> Gesvd_truncate(const Tensor &Tin, const cytnx_uint64 &keepdim,
const double &err, const bool &is_U, const bool &is_vT,
const unsigned int &return_err) {
cytnx_error_msg(Tin.shape().size() != 2,
"[Gesvd_truncate] error, Gesvd_truncate can only operate on rank-2 Tensor.%s",
"\n");

if (Tin.device() == Device.cpu) {
std::vector<Tensor> tmps = Gesvd(Tin, is_U, is_vT);
Tensor terr({1}, Tin.dtype(), Tin.device());

cytnx_uint64 id = 0;
cytnx_uint64 Kdim = keepdim;
cytnx::linalg_internal::lii.memcpyTruncation_ii[Tin.dtype()](
tmps[1], tmps[2], tmps[0], terr, keepdim, err, is_U, is_vT, return_err);

Storage ts = tmps[0].storage();

if (ts.size() < keepdim) {
Kdim = ts.size();
}

cytnx_uint64 truc_dim = Kdim;
for (cytnx_int64 i = Kdim - 1; i >= 0; i--) {
if (ts.at(i) < err) {
truc_dim--;
} else {
break;
}
}

if (truc_dim == 0) {
truc_dim = 1;
}
/// std::cout << truc_dim << std::endl;
// cytnx_error_msg(tmps[0].shape()[0] < keepdim,"[ERROR] keepdim should be <= the valid # of
// singular value, %d!\n",tmps[0].shape()[0]);`
Tensor terr({1}, Type.Double);

if (truc_dim != ts.size()) {
if (return_err == 1)
terr = tmps[id](truc_dim);
else if (return_err)
terr = tmps[id].get({ac::tilend(truc_dim)});

tmps[id] = tmps[id].get({ac::range(0, truc_dim)});

if (is_U) {
id++;
tmps[id] = tmps[id].get({ac::all(), ac::range(0, truc_dim)});
}
if (is_vT) {
id++;
tmps[id] = tmps[id].get({ac::range(0, truc_dim), ac::all()});
}
}
if (return_err) tmps.push_back(terr);
std::vector<Tensor> outT;
outT.push_back(tmps[0]);
if (is_U) outT.push_back(tmps[1]);
if (is_vT) outT.push_back(tmps[2]);
if (return_err) outT.push_back(terr);

return tmps;
return outT;

} else {
#ifdef UNI_GPU
#ifdef UNI_CUQUANTUM
cytnx_error_msg(
Tin.shape().size() != 2,
"[Gesvd_truncate] error, Gesvd_truncate can only operate on rank-2 Tensor.%s", "\n");

Tensor in = Tin.contiguous();

Expand All @@ -88,10 +54,14 @@ namespace cytnx {
U.Init({in.shape()[0], n_singlu}, in.dtype(), in.device());
vT.Init({n_singlu, in.shape()[1]}, in.dtype(), in.device());
terr.Init({1}, in.dtype(), in.device());

cytnx::linalg_internal::lii.cuQuantumGeSvd_ii[in.dtype()](in, keepdim, err, return_err, U,
S, vT, terr);
std::vector<Tensor> outT;

cytnx::linalg_internal::lii.cudaMemcpyTruncation_ii[in.dtype()](
U, vT, S, terr, keepdim, err, is_U, is_vT, return_err);

std::vector<Tensor> outT;
outT.push_back(S);
if (is_U) outT.push_back(U);
if (is_vT) outT.push_back(vT);
Expand All @@ -100,9 +70,19 @@ namespace cytnx {
return outT;

#else
cytnx_error_msg(true, "[Gesvd_truncate] fatal error,%s",
"try to call the cuquantum section without cuQunatum support.\n");
return std::vector<Tensor>();
std::vector<Tensor> tmps = Gesvd(Tin, is_U, is_vT);
Tensor terr({1}, Tin.dtype(), Tin.device());

cytnx::linalg_internal::lii.cudaMemcpyTruncation_ii[Tin.dtype()](
tmps[1], tmps[2], tmps[0], terr, keepdim, err, is_U, is_vT, return_err);

std::vector<Tensor> outT;
outT.push_back(tmps[0]);
if (is_U) outT.push_back(tmps[1]);
if (is_vT) outT.push_back(tmps[2]);
if (return_err) outT.push_back(terr);

return outT;
#endif
#else
cytnx_error_msg(true, "[Gesvd_truncate] fatal error,%s",
Expand All @@ -119,83 +99,6 @@ namespace cytnx {
using namespace std;
typedef Accessor ac;

#ifdef UNI_GPU
#ifdef UNI_CUQUANTUM
void _cuquantum_gesvdj_truncate_Dense_UT(std::vector<UniTensor> &outCyT,
const cytnx::UniTensor &Tin,
const cytnx_uint64 &keepdim, const double &err,
const bool &is_U, const bool &is_vT,
const unsigned int &return_err) {
// Retrieve tensor from UniTensor
Tensor tmp;
if (Tin.is_contiguous())
tmp = Tin.get_block_();
else {
tmp = Tin.get_block();
tmp.contiguous_();
}

vector<cytnx_uint64> tmps = tmp.shape();
vector<cytnx_int64> oldshape(tmps.begin(), tmps.end());
tmps.clear();
vector<string> oldlabel = Tin.labels();

// collapse as Matrix:
cytnx_int64 rowdim = 1;
for (cytnx_uint64 i = 0; i < Tin.rowrank(); i++) rowdim *= Tin.shape()[i];

// pass to tensor API
vector<Tensor> outT = cytnx::linalg::Gesvd_truncate(tmp.reshape({rowdim, -1}), keepdim, err,
is_U, is_vT, return_err);

// set output
int t = 0;
outCyT.resize(outT.size());
cytnx::UniTensor &Cy_S = outCyT[t];
cytnx::Bond newBond(outT[0].shape()[0]);
Cy_S.Init({newBond, newBond}, {string("_aux_L"), string("_aux_R")}, 1, Type.Double,
Tin.device(),
true); // it is just reference so no hurt to alias ^^
Cy_S.put_block_(outT[t]);
t++;

if (is_U) {
cytnx::UniTensor &Cy_U = outCyT[t];
// shape
vector<cytnx_int64> shapeU = vec_clone(oldshape, Tin.rowrank());
shapeU.push_back(-1);

outT[t].reshape_(shapeU);

Cy_U.Init(outT[t], false, Tin.rowrank());
vector<string> labelU = vec_clone(oldlabel, Tin.rowrank());
labelU.push_back(Cy_S.labels()[0]);
Cy_U.set_labels(labelU);
t++; // U
}

if (is_vT) {
cytnx::UniTensor &Cy_vT = outCyT[t];
// shape
vector<cytnx_int64> shapevT(Tin.rank() - Tin.rowrank() + 1);
shapevT[0] = -1;
memcpy(&shapevT[1], &oldshape[Tin.rowrank()], sizeof(cytnx_int64) * (shapevT.size() - 1));

outT[t].reshape_(shapevT);

Cy_vT.Init(outT[t], false, 1);
vector<string> labelvT(shapevT.size());
labelvT[0] = Cy_S.labels()[1];
std::copy(oldlabel.begin() + Tin.rowrank(), oldlabel.end(), labelvT.begin() + 1);
Cy_vT.set_labels(labelvT);
t++; // vT
}

if (return_err) outCyT.back().Init(outT.back(), false, 0);
}
#endif
#endif

void _gesvd_truncate_Dense_UT(std::vector<UniTensor> &outCyT, const cytnx::UniTensor &Tin,
const cytnx_uint64 &keepdim, const double &err, const bool &is_U,
const bool &is_vT, const unsigned int &return_err) {
Expand Down Expand Up @@ -458,28 +361,9 @@ namespace cytnx {

std::vector<UniTensor> outCyT;
if (Tin.uten_type() == UTenType.Dense) {
if (Tin.device() == Device.cpu) {
_gesvd_truncate_Dense_UT(outCyT, Tin, keepdim, err, is_U, is_vT, return_err);
} else {
#ifdef UNI_GPU
#ifdef UNI_CUQUANTUM
_cuquantum_gesvdj_truncate_Dense_UT(outCyT, Tin, keepdim, err, is_U, is_vT, return_err);
#else
cytnx_error_msg(true, "[cuQuantumSvd] fatal error,%s",
"try to call the cuquantum section without cuQunatum support.\n");
return std::vector<cytnx::UniTensor>();
#endif

#else
cytnx_error_msg(true, "[cuQuantumSvd] fatal error,%s",
"try to call the gpu section without CUDA support.\n");
return std::vector<cytnx::UniTensor>();
#endif
}

_gesvd_truncate_Dense_UT(outCyT, Tin, keepdim, err, is_U, is_vT, return_err);
} else if (Tin.uten_type() == UTenType.Block) {
_gesvd_truncate_Block_UT(outCyT, Tin, keepdim, err, is_U, is_vT, return_err);

} else {
cytnx_error_msg(true, "[ERROR] only support gesvd for Dense and Block UniTensor.%s", "\n");
}
Expand Down
68 changes: 31 additions & 37 deletions src/linalg/Svd_truncate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "Tensor.hpp"
#include "UniTensor.hpp"
#include "algo.hpp"
#include "linalg_internal_interface.hpp"

namespace cytnx {
namespace linalg {
Expand All @@ -12,53 +13,46 @@ namespace cytnx {
const double &err, const bool &is_UvT,
const unsigned int &return_err) {
cytnx_error_msg(return_err < 0, "[ERROR] return_err can only be positive int%s", "\n");
std::vector<Tensor> tmps = Svd(Tin, is_UvT);
if (Tin.device() == Device.cpu) {
std::vector<Tensor> tmps = Svd(Tin, is_UvT);

cytnx_uint64 id = 0;
cytnx_uint64 Kdim = keepdim;
Tensor terr({1}, Tin.dtype(), Tin.device());

Storage ts = tmps[0].storage();
cytnx::linalg_internal::lii.memcpyTruncation_ii[Tin.dtype()](
tmps[1], tmps[2], tmps[0], terr, keepdim, err, is_UvT, is_UvT, return_err);

if (ts.size() < keepdim) {
Kdim = ts.size();
}

cytnx_uint64 truc_dim = Kdim;
for (cytnx_int64 i = Kdim - 1; i >= 0; i--) {
if (ts.at(i) < err) {
truc_dim--;
} else {
break;
std::vector<Tensor> outT;
outT.push_back(tmps[0]);
if (is_UvT) {
outT.push_back(tmps[1]);
outT.push_back(tmps[2]);
}
}
if (return_err) outT.push_back(terr);

if (truc_dim == 0) {
truc_dim = 1;
}
/// std::cout << truc_dim << std::endl;
// cytnx_error_msg(tmps[0].shape()[0] < keepdim,"[ERROR] keepdim should be <= the valid # of
// singular value, %d!\n",tmps[0].shape()[0]);
Tensor terr({1}, Type.Double);

if (truc_dim != ts.size()) {
if (return_err == 1)
terr = tmps[id](truc_dim);
else if (return_err)
terr = tmps[id].get({ac::tilend(truc_dim)});
return outT;
} else {
#ifdef UNI_GPU
std::vector<Tensor> tmps = Svd(Tin, is_UvT);
Tensor terr({1}, Tin.dtype(), Tin.device());

tmps[id] = tmps[id].get({ac::range(0, truc_dim)});
cytnx::linalg_internal::lii.cudaMemcpyTruncation_ii[Tin.dtype()](
tmps[1], tmps[2], tmps[0], terr, keepdim, err, is_UvT, is_UvT, return_err);

std::vector<Tensor> outT;
outT.push_back(tmps[0]);
if (is_UvT) {
id++;
tmps[id] = tmps[id].get({ac::all(), ac::range(0, truc_dim)});

id++;
tmps[id] = tmps[id].get({ac::range(0, truc_dim), ac::all()});
outT.push_back(tmps[1]);
outT.push_back(tmps[2]);
}
if (return_err) outT.push_back(terr);

return outT;
#else
cytnx_error_msg(true, "[Svd_truncate] fatal error,%s",
"try to call the gpu section without CUDA support.\n");
return std::vector<Tensor>();
#endif
}
if (return_err) tmps.push_back(terr);

return tmps;
}
} // namespace linalg
} // namespace cytnx
Expand Down
4 changes: 4 additions & 0 deletions src/linalg/linalg_internal_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ target_sources_local(cytnx
Axpy_internal.hpp
Ger_internal.hpp

memcpyTruncation.hpp

Add_internal.cpp
iAdd_internal.cpp
Arithmetic_internal.cpp
Expand Down Expand Up @@ -85,4 +87,6 @@ target_sources_local(cytnx

Axpy_internal.cpp
Ger_internal.cpp

memcpyTruncation.cpp
)
Loading

0 comments on commit 5f3cfea

Please sign in to comment.