diff --git a/src/cytnx_torch/bond.py b/src/cytnx_torch/bond.py index df219f3..a4bb90c 100644 --- a/src/cytnx_torch/bond.py +++ b/src/cytnx_torch/bond.py @@ -74,6 +74,10 @@ def redirect(self) -> "AbstractBond": out.bond_type = BondType(-self.bond_type.value) return out + @property + def directional(self) -> bool: + return self.bond_type != BondType.NONE + @dataclass class Bond(AbstractBond): diff --git a/src/cytnx_torch/converter.py b/src/cytnx_torch/converter.py index 1f80afc..986f4af 100644 --- a/src/cytnx_torch/converter.py +++ b/src/cytnx_torch/converter.py @@ -5,7 +5,8 @@ from .bond import Bond, SymBond if TYPE_CHECKING: - from .unitensor.entry import RegularUniTensor, BlockUniTensor, AbstractUniTensor + from .unitensor.entry import RegularUniTensor, BlockUniTensor + from .unitensor.base import AbstractUniTensor @dataclass diff --git a/src/cytnx_torch/linalg/svd.py b/src/cytnx_torch/linalg/svd.py index 4cb8eb8..ba2b945 100644 --- a/src/cytnx_torch/linalg/svd.py +++ b/src/cytnx_torch/linalg/svd.py @@ -3,7 +3,9 @@ from torch import linalg as torch_linalg # noqa F401 -def _svd_regular_tn(A: RegularUniTensor): +def _svd_regular_tn( + A: RegularUniTensor, truncate_dim: int = None, truncate_tol: float = None +): """ svd(A, **kwargs): Singular Value Decomposition of a matrix A. @@ -19,11 +21,41 @@ def _svd_regular_tn(A: RegularUniTensor): V (cytnx.UniTensor): the right singular vectors """ + if A.is_diag: + raise ValueError("TODO SVD is not supported for diagonal tensors for now") + mat_A, cL, cR = A.as_matrix(left_bond_label="_tmp_L_", right_bond_label="_tmp_R_") # get the data: - # u,s,v = torch_linalg.svd(mat_A) + u_dat, s_dat, v_dat = torch_linalg.svd(mat_A.data) # create new bonds: - # new_bond_dim = len(s) - # internal_bond = Bond(new_bond_dim,bond_type=BondType.OUT) + new_bond_dim = len(s_dat) + internal_bond = Bond( + new_bond_dim, + bond_type=BondType.OUT if A.is_directional_bonds else BondType.NONE, + ) + + # construct U tensor: + u_labels = ["_tmp_L_", "_aux_L_"] + u_bonds = [mat_A.bonds[0], internal_bond] + u = RegularUniTensor(labels=u_labels, bonds=u_bonds, is_diag=False, data=u_dat) + + # construct s tensor: + s_labels = ["_aux_L_", "_aux_R_"] + s_bonds = [internal_bond.redirect(), internal_bond] + s = RegularUniTensor(labels=s_labels, bonds=s_bonds, is_diag=True, data=s_dat) + + # construct v tensor: + v_labels = ["_aux_R_", "_tmp_R_"] + v_bonds = [internal_bond, mat_A.bonds[1]] + v = RegularUniTensor(labels=v_labels, bonds=v_bonds, is_diag=False, data=v_dat) + + if truncate_dim or truncate_tol: + # deal truncate + # TODO + pass + + new_mat = u.contract(s).contract(v) + + return cL.contract(new_mat).contract(cR) diff --git a/src/cytnx_torch/unitensor/base.py b/src/cytnx_torch/unitensor/base.py index 90fa9a9..62dbcd6 100644 --- a/src/cytnx_torch/unitensor/base.py +++ b/src/cytnx_torch/unitensor/base.py @@ -144,6 +144,10 @@ def device(self) -> torch.device: def dtype(self) -> torch.dtype: raise NotImplementedError("not implement for abstract type trait.") + @property + def is_directional_bonds(self) -> bool: + return all([b.directional for b in self.bonds]) + @abstractmethod def _repr_body_diagram(self) -> str: raise NotImplementedError("not implement for abstract type trait.") diff --git a/src/cytnx_torch/unitensor/block_unitensor.py b/src/cytnx_torch/unitensor/block_unitensor.py index 64fe491..697d760 100644 --- a/src/cytnx_torch/unitensor/block_unitensor.py +++ b/src/cytnx_torch/unitensor/block_unitensor.py @@ -96,6 +96,10 @@ def _get_symmetries(self) -> Tuple[Symmetry]: else: return tuple() + @property + def is_directional_bonds(self) -> bool: + return True + def _generate_meta(self) -> np.ndarray[int]: qnindices = [np.arange(len(bd._qnums)) for bd in self.bonds] qn_indices_map = np.meshgrid(*qnindices)