diff --git a/GRUD.py b/GRUD.py index 263ecc8..31c5f80 100644 --- a/GRUD.py +++ b/GRUD.py @@ -123,7 +123,7 @@ def step(self, x, h, mask, delta, x_mean): combined = torch.cat((x, h, mask), 1) z = F.sigmoid(self.zl(combined)) r = F.sigmoid(self.rl(combined)) - combined_r = torch.cat((x, r * Hidden_State, mask), 1) + combined_r = torch.cat((x, r * Hidden_State, mask), 1) h_tilde = F.tanh(self.hl(combined_r)) h = (1 - z) * h + z * h_tilde