Skip to content

Commit

Permalink
add descriptions of functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bojunliu0818 committed Jun 21, 2024
1 parent 89f4130 commit a6b4c0f
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tsdart/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def __init__(self, epsilon=1e-6, mode='regularize', symmetrized=False):
self._symmetrized = symmetrized

def forward(self, data):
""" Compute VAMP2 loss.
Parameters
----------
data : tuple
Softmax probabilities of batch of transition pairs.
Returns
-------
VAMP2 loss
"""

assert len(data) == 2

koopman = estimate_koopman_matrix(data[0], data[1], epsilon=self._epsilon, mode=self._mode, symmetrized=self._symmetrized)
Expand Down Expand Up @@ -86,6 +98,21 @@ def __init__(self, feat_dim, n_states, device, proto_update_factor=0.5, scaling_
self.scaling_temperature = scaling_temperature

def forward(self, features, labels):
""" Compute dispersion loss.
Parameters
----------
features : torch.Tensor
Hyperspherical embeddings of a batch of data.
labels : torch.Tensor
Metastable states of a batch of data.
Returns
-------
loss : torch.Tensor
Dispersion loss
"""

prototypes = self.prototypes.to(device=self.device)
for i in range(len(labels)):
Expand Down Expand Up @@ -143,6 +170,22 @@ def __init__(self, n_states, device, scaling_temperature=0.1):
self.scaling_temperature = scaling_temperature

def forward(self, features, labels):
""" Compute dispersion loss.
Parameters
----------
features : torch.Tensor
Hyperspherical embeddings of a batch of data.
labels : torch.Tensor
Metastable states of a batch of data.
Returns
-------
prototypes : torch.Tensor
State center vectors of shape [n_states, feat_dim].
"""

with torch.no_grad():
proxy_labels = torch.arange(0, self.n_states).to(device=self.device)
labels = labels.contiguous().view(-1, 1)
Expand Down

0 comments on commit a6b4c0f

Please sign in to comment.