Skip to content

Commit

Permalink
remove time_factor in odetpp torch version
Browse files Browse the repository at this point in the history
  • Loading branch information
alilevy committed Mar 31, 2024
1 parent bb57f5e commit 59ef88f
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions easy_tpp/model/torch_model/torch_ode_tpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def __init__(self, model_config):
hidden_size=[self.hidden_size])

self.ode_num_sample_per_step = model_config.model_specs['ode_num_sample_per_step']
self.time_factor = model_config.model_specs['time_factor']

self.solver = rk4_step_method

Expand Down Expand Up @@ -193,7 +192,6 @@ def forward(self, time_delta_seqs, type_seqs, **kwargs):
last_state = torch.zeros_like(type_seq_emb[:, 0, :], device=self.device)
for type_emb, dt in zip(torch.unbind(type_seq_emb, dim=-2),
torch.unbind(time_delta_seqs_, dim=-2)):
dt = dt / self.time_factor
last_state = self.layer_neural_ode(last_state + type_emb, dt)
total_state_at_event_minus.append(last_state)
total_state_at_event_plus.append(last_state + type_emb)
Expand Down

0 comments on commit 59ef88f

Please sign in to comment.