Commit 3588989

-Svd_truncate calls Svd, and then uses the dtype of U for the truncation instead of that of S (always real) or Tin (can be Int, Bool, etc)

-some renaming of variable truc to trunc
manuschneider committed Oct 2, 2024
1 parent f9781a9 commit 3588989
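
For context, a minimal sketch of the dispatch issue this commit fixes. It assumes, consistent with the call sites below, that Svd returns {S, U, vT} with S always real; the table and kernel names here are illustrative, not the Cytnx API:

#include <cstdio>
#include <functional>
#include <map>

// Illustrative dtype tags and kernels; not Cytnx types.
enum class Dtype { Double, ComplexDouble };

void truncate_real(const char *what) { std::printf("real kernel on %s\n", what); }
void truncate_complex(const char *what) { std::printf("complex kernel on %s\n", what); }

int main() {
  std::map<Dtype, std::function<void(const char *)>> kernel = {
    {Dtype::Double, truncate_real}, {Dtype::ComplexDouble, truncate_complex}};

  // After an Svd of a complex tensor: S is real, U and vT are complex.
  Dtype S_dtype = Dtype::Double;
  Dtype U_dtype = Dtype::ComplexDouble;

  kernel[S_dtype]("U/vT");  // old behavior: picks the real kernel for complex U/vT
  kernel[U_dtype]("U/vT");  // fixed: dispatch on U's dtype
  return 0;
}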
Showing 2 changed files with 71 additions and 70 deletions.
136 changes: 68 additions & 68 deletions src/backend/linalg_internal_cpu/memcpyTruncation.cpp
@@ -14,62 +14,62 @@ namespace cytnx {
                             const cytnx_uint64 &keepdim, const double &err, const bool &is_U,
                             const bool &is_vT, const unsigned int &return_err,
                             const unsigned int &mindim) {
-      // determine the truc_dim
+      // determine the trunc_dim
       cytnx_uint64 Kdim = keepdim;
       cytnx_uint64 nums = S.storage().size();
       if (nums < keepdim) {
         Kdim = nums;
       }
-      cytnx_uint64 truc_dim = Kdim;
+      cytnx_uint64 trunc_dim = Kdim;
       for (cytnx_int64 i = Kdim - 1; i >= 0; i--) {
-        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and truc_dim - 1 >= mindim) {
-          truc_dim--;
+        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and trunc_dim - 1 >= mindim) {
+          trunc_dim--;
         } else {
           break;
         }
       }
-      if (truc_dim == 0) {
-        truc_dim = 1;
+      if (trunc_dim == 0) {
+        trunc_dim = 1;
       }
-      if (truc_dim != nums) {
+      if (trunc_dim != nums) {
        // perform the manual truncation

-        Tensor newS = Tensor({truc_dim}, S.dtype(), S.device());
+        Tensor newS = Tensor({trunc_dim}, S.dtype(), S.device());
         memcpy((cytnx_double *)newS._impl->storage()._impl->Mem,
-               (cytnx_double *)S._impl->storage()._impl->Mem, truc_dim * sizeof(cytnx_double));
+               (cytnx_double *)S._impl->storage()._impl->Mem, trunc_dim * sizeof(cytnx_double));
         if (is_U) {
-          Tensor newU = Tensor({U.shape()[0], truc_dim}, U.dtype(), U.device());
+          Tensor newU = Tensor({U.shape()[0], trunc_dim}, U.dtype(), U.device());

           int src = 0;
           int dest = 0;
           // copy with strides.
           for (int i = 0; i < U.shape()[0]; i++) {
             memcpy((cytnx_complex128 *)newU._impl->storage()._impl->Mem + src,
                    (cytnx_complex128 *)U._impl->storage()._impl->Mem + dest,
-                   truc_dim * sizeof(cytnx_complex128));
-            src += truc_dim;
+                   trunc_dim * sizeof(cytnx_complex128));
+            src += trunc_dim;
             dest += U.shape()[1];
           }
           U = newU;
         }
         if (is_vT) {
-          Tensor newvT = Tensor({truc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
+          Tensor newvT = Tensor({trunc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
           // simply copy a new one dropping the tail.
           memcpy((cytnx_complex128 *)newvT._impl->storage()._impl->Mem,
                  (cytnx_complex128 *)vT._impl->storage()._impl->Mem,
-                 vT.shape()[1] * truc_dim * sizeof(cytnx_complex128));
+                 vT.shape()[1] * trunc_dim * sizeof(cytnx_complex128));
           vT = newvT;
         }
         if (return_err == 1) {
           Tensor newterr = Tensor({1}, S.dtype(), S.device());
           ((cytnx_double *)newterr._impl->storage()._impl->Mem)[0] =
-            ((cytnx_double *)S._impl->storage()._impl->Mem)[truc_dim];
+            ((cytnx_double *)S._impl->storage()._impl->Mem)[trunc_dim];
           terr = newterr;
         } else if (return_err) {
-          cytnx_uint64 discared_dim = S.shape()[0] - truc_dim;
+          cytnx_uint64 discared_dim = S.shape()[0] - trunc_dim;
           Tensor newterr = Tensor({discared_dim}, S.dtype(), S.device());
           memcpy((cytnx_double *)newterr._impl->storage()._impl->Mem,
-                 (cytnx_double *)S._impl->storage()._impl->Mem + truc_dim,
+                 (cytnx_double *)S._impl->storage()._impl->Mem + trunc_dim,
                  discared_dim * sizeof(cytnx_double));
           terr = newterr;
         }
@@ -81,62 +81,62 @@ namespace cytnx {
                             const cytnx_uint64 &keepdim, const double &err, const bool &is_U,
                             const bool &is_vT, const unsigned int &return_err,
                             const unsigned int &mindim) {
-      // determine the truc_dim
+      // determine the trunc_dim
       cytnx_uint64 Kdim = keepdim;
       cytnx_uint64 nums = S.storage().size();
       if (nums < keepdim) {
         Kdim = nums;
       }
-      cytnx_uint64 truc_dim = Kdim;
+      cytnx_uint64 trunc_dim = Kdim;
       for (cytnx_int64 i = Kdim - 1; i >= 0; i--) {
-        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and truc_dim - 1 >= mindim) {
-          truc_dim--;
+        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and trunc_dim - 1 >= mindim) {
+          trunc_dim--;
         } else {
           break;
         }
       }
-      if (truc_dim == 0) {
-        truc_dim = 1;
+      if (trunc_dim == 0) {
+        trunc_dim = 1;
       }
-      if (truc_dim != nums) {
+      if (trunc_dim != nums) {
        // perform the manual truncation

-        Tensor newS = Tensor({truc_dim}, S.dtype(), S.device());
+        Tensor newS = Tensor({trunc_dim}, S.dtype(), S.device());
         memcpy((cytnx_double *)newS._impl->storage()._impl->Mem,
-               (cytnx_double *)S._impl->storage()._impl->Mem, truc_dim * sizeof(cytnx_double));
+               (cytnx_double *)S._impl->storage()._impl->Mem, trunc_dim * sizeof(cytnx_double));
         if (is_U) {
-          Tensor newU = Tensor({U.shape()[0], truc_dim}, U.dtype(), U.device());
+          Tensor newU = Tensor({U.shape()[0], trunc_dim}, U.dtype(), U.device());

           int src = 0;
           int dest = 0;
           // copy with strides.
           for (int i = 0; i < U.shape()[0]; i++) {
             memcpy((cytnx_complex64 *)newU._impl->storage()._impl->Mem + src,
                    (cytnx_complex64 *)U._impl->storage()._impl->Mem + dest,
-                   truc_dim * sizeof(cytnx_complex64));
-            src += truc_dim;
+                   trunc_dim * sizeof(cytnx_complex64));
+            src += trunc_dim;
             dest += U.shape()[1];
           }
           U = newU;
         }
         if (is_vT) {
-          Tensor newvT = Tensor({truc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
+          Tensor newvT = Tensor({trunc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
           // simply copy a new one dropping the tail.
           memcpy((cytnx_complex64 *)newvT._impl->storage()._impl->Mem,
                  (cytnx_complex64 *)vT._impl->storage()._impl->Mem,
-                 vT.shape()[1] * truc_dim * sizeof(cytnx_complex64));
+                 vT.shape()[1] * trunc_dim * sizeof(cytnx_complex64));
           vT = newvT;
         }
         if (return_err == 1) {
           Tensor newterr = Tensor({1}, S.dtype(), S.device());
           ((cytnx_double *)newterr._impl->storage()._impl->Mem)[0] =
-            ((cytnx_double *)S._impl->storage()._impl->Mem)[truc_dim];
+            ((cytnx_double *)S._impl->storage()._impl->Mem)[trunc_dim];
           terr = newterr;
         } else if (return_err) {
-          cytnx_uint64 discared_dim = S.shape()[0] - truc_dim;
+          cytnx_uint64 discared_dim = S.shape()[0] - trunc_dim;
           Tensor newterr = Tensor({discared_dim}, S.dtype(), S.device());
           memcpy((cytnx_double *)newterr._impl->storage()._impl->Mem,
-                 (cytnx_double *)S._impl->storage()._impl->Mem + truc_dim,
+                 (cytnx_double *)S._impl->storage()._impl->Mem + trunc_dim,
                  discared_dim * sizeof(cytnx_double));
           terr = newterr;
         }
@@ -148,62 +148,62 @@ namespace cytnx {
                             const cytnx_uint64 &keepdim, const double &err, const bool &is_U,
                             const bool &is_vT, const unsigned int &return_err,
                             const unsigned int &mindim) {
-      // determine the truc_dim
+      // determine the trunc_dim
       cytnx_uint64 Kdim = keepdim;
       cytnx_uint64 nums = S.storage().size();
       if (nums < keepdim) {
         Kdim = nums;
       }
-      cytnx_uint64 truc_dim = Kdim;
+      cytnx_uint64 trunc_dim = Kdim;
       for (cytnx_int64 i = Kdim - 1; i >= 0; i--) {
-        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and truc_dim - 1 >= mindim) {
-          truc_dim--;
+        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and trunc_dim - 1 >= mindim) {
+          trunc_dim--;
         } else {
           break;
         }
       }
-      if (truc_dim == 0) {
-        truc_dim = 1;
+      if (trunc_dim == 0) {
+        trunc_dim = 1;
      }
-      if (truc_dim != nums) {
+      if (trunc_dim != nums) {
        // perform the manual truncation

-        Tensor newS = Tensor({truc_dim}, S.dtype(), S.device());
+        Tensor newS = Tensor({trunc_dim}, S.dtype(), S.device());
         memcpy((cytnx_double *)newS._impl->storage()._impl->Mem,
-               (cytnx_double *)S._impl->storage()._impl->Mem, truc_dim * sizeof(cytnx_double));
+               (cytnx_double *)S._impl->storage()._impl->Mem, trunc_dim * sizeof(cytnx_double));
         if (is_U) {
-          Tensor newU = Tensor({U.shape()[0], truc_dim}, U.dtype(), U.device());
+          Tensor newU = Tensor({U.shape()[0], trunc_dim}, U.dtype(), U.device());

           int src = 0;
           int dest = 0;
           // copy with strides.
           for (int i = 0; i < U.shape()[0]; i++) {
             memcpy((cytnx_double *)newU._impl->storage()._impl->Mem + src,
                    (cytnx_double *)U._impl->storage()._impl->Mem + dest,
-                   truc_dim * sizeof(cytnx_double));
-            src += truc_dim;
+                   trunc_dim * sizeof(cytnx_double));
+            src += trunc_dim;
             dest += U.shape()[1];
           }
           U = newU;
         }
         if (is_vT) {
-          Tensor newvT = Tensor({truc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
+          Tensor newvT = Tensor({trunc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
           // simply copy a new one dropping the tail.
           memcpy((cytnx_double *)newvT._impl->storage()._impl->Mem,
                  (cytnx_double *)vT._impl->storage()._impl->Mem,
-                 vT.shape()[1] * truc_dim * sizeof(cytnx_double));
+                 vT.shape()[1] * trunc_dim * sizeof(cytnx_double));
           vT = newvT;
         }
         if (return_err == 1) {
           Tensor newterr = Tensor({1}, S.dtype(), S.device());
           ((cytnx_double *)newterr._impl->storage()._impl->Mem)[0] =
-            ((cytnx_double *)S._impl->storage()._impl->Mem)[truc_dim];
+            ((cytnx_double *)S._impl->storage()._impl->Mem)[trunc_dim];
           terr = newterr;
         } else if (return_err) {
-          cytnx_uint64 discared_dim = S.shape()[0] - truc_dim;
+          cytnx_uint64 discared_dim = S.shape()[0] - trunc_dim;
           Tensor newterr = Tensor({discared_dim}, S.dtype(), S.device());
           memcpy((cytnx_double *)newterr._impl->storage()._impl->Mem,
-                 (cytnx_double *)S._impl->storage()._impl->Mem + truc_dim,
+                 (cytnx_double *)S._impl->storage()._impl->Mem + trunc_dim,
                  discared_dim * sizeof(cytnx_double));
           terr = newterr;
         }
@@ -215,62 +215,62 @@ namespace cytnx {
                             const cytnx_uint64 &keepdim, const double &err, const bool &is_U,
                             const bool &is_vT, const unsigned int &return_err,
                             const unsigned int &mindim) {
-      // determine the truc_dim
+      // determine the trunc_dim
       cytnx_uint64 Kdim = keepdim;
       cytnx_uint64 nums = S.storage().size();
       if (nums < keepdim) {
         Kdim = nums;
       }
-      cytnx_uint64 truc_dim = Kdim;
+      cytnx_uint64 trunc_dim = Kdim;
      for (cytnx_int64 i = Kdim - 1; i >= 0; i--) {
-        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and truc_dim - 1 >= mindim) {
-          truc_dim--;
+        if (((cytnx_double *)S._impl->storage()._impl->Mem)[i] < err and trunc_dim - 1 >= mindim) {
+          trunc_dim--;
         } else {
           break;
         }
       }
-      if (truc_dim == 0) {
-        truc_dim = 1;
+      if (trunc_dim == 0) {
+        trunc_dim = 1;
       }
-      if (truc_dim != nums) {
+      if (trunc_dim != nums) {
        // perform the manual truncation

-        Tensor newS = Tensor({truc_dim}, S.dtype(), S.device());
+        Tensor newS = Tensor({trunc_dim}, S.dtype(), S.device());
         memcpy((cytnx_double *)newS._impl->storage()._impl->Mem,
-               (cytnx_double *)S._impl->storage()._impl->Mem, truc_dim * sizeof(cytnx_double));
+               (cytnx_double *)S._impl->storage()._impl->Mem, trunc_dim * sizeof(cytnx_double));
         if (is_U) {
-          Tensor newU = Tensor({U.shape()[0], truc_dim}, U.dtype(), U.device());
+          Tensor newU = Tensor({U.shape()[0], trunc_dim}, U.dtype(), U.device());

           int src = 0;
           int dest = 0;
           // copy with strides.
           for (int i = 0; i < U.shape()[0]; i++) {
             memcpy((cytnx_float *)newU._impl->storage()._impl->Mem + src,
                    (cytnx_float *)U._impl->storage()._impl->Mem + dest,
-                   truc_dim * sizeof(cytnx_float));
-            src += truc_dim;
+                   trunc_dim * sizeof(cytnx_float));
+            src += trunc_dim;
             dest += U.shape()[1];
           }
           U = newU;
         }
         if (is_vT) {
-          Tensor newvT = Tensor({truc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
+          Tensor newvT = Tensor({trunc_dim, vT.shape()[1]}, vT.dtype(), vT.device());
           // simply copy a new one dropping the tail.
           memcpy((cytnx_float *)newvT._impl->storage()._impl->Mem,
                  (cytnx_float *)vT._impl->storage()._impl->Mem,
-                 vT.shape()[1] * truc_dim * sizeof(cytnx_float));
+                 vT.shape()[1] * trunc_dim * sizeof(cytnx_float));
           vT = newvT;
         }
         if (return_err == 1) {
           Tensor newterr = Tensor({1}, S.dtype(), S.device());
           ((cytnx_double *)newterr._impl->storage()._impl->Mem)[0] =
-            ((cytnx_double *)S._impl->storage()._impl->Mem)[truc_dim];
+            ((cytnx_double *)S._impl->storage()._impl->Mem)[trunc_dim];
           terr = newterr;
         } else if (return_err) {
-          cytnx_uint64 discared_dim = S.shape()[0] - truc_dim;
+          cytnx_uint64 discared_dim = S.shape()[0] - trunc_dim;
           Tensor newterr = Tensor({discared_dim}, S.dtype(), S.device());
           memcpy((cytnx_double *)newterr._impl->storage()._impl->Mem,
-                 (cytnx_double *)S._impl->storage()._impl->Mem + truc_dim,
+                 (cytnx_double *)S._impl->storage()._impl->Mem + trunc_dim,
                  discared_dim * sizeof(cytnx_double));
           terr = newterr;
         }
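
Aside from the truc-to-trunc rename, all four overloads above share the same rule for choosing the truncated dimension. A standalone sketch of that rule (the helper name is ours, not part of the library): scan the singular values from the smallest kept one, drop values below err while at least mindim remain, and always keep at least one.

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch of the trunc_dim rule used in each overload; not library code.
uint64_t determine_trunc_dim(const std::vector<double> &S, uint64_t keepdim,
                             double err, uint64_t mindim) {
  uint64_t kdim = std::min<uint64_t>(keepdim, S.size());
  uint64_t trunc_dim = kdim;
  // Walk from the smallest kept singular value toward the largest.
  for (int64_t i = static_cast<int64_t>(kdim) - 1; i >= 0; --i) {
    if (S[i] < err && trunc_dim - 1 >= mindim)
      --trunc_dim;  // drop this value, but never shrink past mindim
    else
      break;
  }
  return trunc_dim == 0 ? 1 : trunc_dim;  // always keep at least one value
}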
5 changes: 3 additions & 2 deletions src/linalg/Svd_truncate.cpp
@@ -22,7 +22,8 @@ namespace cytnx {

       Tensor terr({1}, Tin.dtype(), Tin.device());

-      cytnx::linalg_internal::lii.memcpyTruncation_ii[tmps[0].dtype()](
+      // dtype should be that of U (or Vt) here, since S is real and Tin could be Int, Bool etc.
+      cytnx::linalg_internal::lii.memcpyTruncation_ii[tmps[1].dtype()](
         tmps[1], tmps[2], tmps[0], terr, keepdim, err, is_UvT, is_UvT, return_err, mindim);

       std::vector<Tensor> outT;
@@ -39,7 +40,7 @@ namespace cytnx {
       std::vector<Tensor> tmps = Svd(Tin, is_UvT);
       Tensor terr({1}, Tin.dtype(), Tin.device());

-      cytnx::linalg_internal::lii.cudaMemcpyTruncation_ii[tmps[0].dtype()](
+      cytnx::linalg_internal::lii.cudaMemcpyTruncation_ii[tmps[1].dtype()](
         tmps[1], tmps[2], tmps[0], terr, keepdim, err, is_UvT, is_UvT, return_err, mindim);

       std::vector<Tensor> outT;
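
To see one concrete consequence of dispatching on the wrong dtype, a standalone illustration (not library code): a kernel specialized for real doubles computes its byte counts with sizeof(cytnx_double), which is only half of what a complex128 buffer of the same length needs, so part of U or vT would simply never be copied.

#include <complex>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const size_t trunc_dim = 4;
  std::vector<std::complex<double>> row(trunc_dim, {1.0, 2.0});
  std::vector<std::complex<double>> out(trunc_dim, {0.0, 0.0});

  // Byte count as a real-double kernel would compute it: half the data.
  std::memcpy(out.data(), row.data(), trunc_dim * sizeof(double));
  std::printf("real-sized copy:    out[2] = (%g, %g)\n", out[2].real(),
              out[2].imag());  // (0, 0) -- the tail was never copied

  // Byte count matching the element type of U/vT.
  std::memcpy(out.data(), row.data(), trunc_dim * sizeof(std::complex<double>));
  std::printf("complex-sized copy: out[2] = (%g, %g)\n", out[2].real(),
              out[2].imag());  // (1, 2)
  return 0;
}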
