Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft svd #17

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/cytnx_torch/bond.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def redirect(self) -> "AbstractBond":
out.bond_type = BondType(-self.bond_type.value)
return out

@property
def directional(self) -> bool:
return self.bond_type != BondType.NONE


@dataclass
class Bond(AbstractBond):
Expand Down
3 changes: 2 additions & 1 deletion src/cytnx_torch/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from .bond import Bond, SymBond

if TYPE_CHECKING:
from .unitensor.entry import RegularUniTensor, BlockUniTensor, AbstractUniTensor
from .unitensor.entry import RegularUniTensor, BlockUniTensor
from .unitensor.base import AbstractUniTensor


@dataclass
Expand Down
40 changes: 36 additions & 4 deletions src/cytnx_torch/linalg/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from torch import linalg as torch_linalg # noqa F401


def _svd_regular_tn(A: RegularUniTensor):
def _svd_regular_tn(
A: RegularUniTensor, truncate_dim: int = None, truncate_tol: float = None
):
"""
svd(A, **kwargs):
Singular Value Decomposition of a matrix A.
Expand All @@ -19,11 +21,41 @@ def _svd_regular_tn(A: RegularUniTensor):
V (cytnx.UniTensor): the right singular vectors
"""

if A.is_diag:
raise ValueError("TODO SVD is not supported for diagonal tensors for now")

mat_A, cL, cR = A.as_matrix(left_bond_label="_tmp_L_", right_bond_label="_tmp_R_")

# get the data:
# u,s,v = torch_linalg.svd(mat_A)
u_dat, s_dat, v_dat = torch_linalg.svd(mat_A.data)

# create new bonds:
# new_bond_dim = len(s)
# internal_bond = Bond(new_bond_dim,bond_type=BondType.OUT)
new_bond_dim = len(s_dat)
internal_bond = Bond(
new_bond_dim,
bond_type=BondType.OUT if A.is_directional_bonds else BondType.NONE,
)

# construct U tensor:
u_labels = ["_tmp_L_", "_aux_L_"]
u_bonds = [mat_A.bonds[0], internal_bond]
u = RegularUniTensor(labels=u_labels, bonds=u_bonds, is_diag=False, data=u_dat)

# construct s tensor:
s_labels = ["_aux_L_", "_aux_R_"]
s_bonds = [internal_bond.redirect(), internal_bond]
s = RegularUniTensor(labels=s_labels, bonds=s_bonds, is_diag=True, data=s_dat)

# construct v tensor:
v_labels = ["_aux_R_", "_tmp_R_"]
v_bonds = [internal_bond, mat_A.bonds[1]]
v = RegularUniTensor(labels=v_labels, bonds=v_bonds, is_diag=False, data=v_dat)

if truncate_dim or truncate_tol:
# deal truncate
# TODO
pass

new_mat = u.contract(s).contract(v)

return cL.contract(new_mat).contract(cR)
4 changes: 4 additions & 0 deletions src/cytnx_torch/unitensor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def device(self) -> torch.device:
def dtype(self) -> torch.dtype:
raise NotImplementedError("not implement for abstract type trait.")

@property
def is_directional_bonds(self) -> bool:
return all([b.directional for b in self.bonds])

@abstractmethod
def _repr_body_diagram(self) -> str:
raise NotImplementedError("not implement for abstract type trait.")
Expand Down
4 changes: 4 additions & 0 deletions src/cytnx_torch/unitensor/block_unitensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def _get_symmetries(self) -> Tuple[Symmetry]:
else:
return tuple()

@property
def is_directional_bonds(self) -> bool:
return True

def _generate_meta(self) -> np.ndarray[int]:
qnindices = [np.arange(len(bd._qnums)) for bd in self.bonds]
qn_indices_map = np.meshgrid(*qnindices)
Expand Down