diff --git a/model/decoder.py b/model/decoder.py index 7307435..99df84f 100644 --- a/model/decoder.py +++ b/model/decoder.py @@ -168,8 +168,8 @@ def __init__(self, def forward_single_frame(self, x, h): r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1) c = self.hh(torch.cat([x, r * h], dim=1)) - h = (1 - z) * h + z * c - return h, h + h = (1 - z) * c + z * h + return c, h def forward_time_series(self, x, h): o = []