A deep recurrent neural network is a variant of the [[循环神经网络(RNN)]] with multiple stacked hidden layers. [Pasted image 20230626173211.png] Compared with a single-layer RNN, the hidden state of each layer at the current time step is produced from that layer's hidden state at the previous time step together with the hidden state of the layer below at the current time step.
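More concretely (a sketch in the standard textbook notation, which is assumed here rather than taken from this note), the hidden state of layer $l$ at time step $t$ is

$$
H_t^{(l)} = \phi\left(H_t^{(l-1)} W_{xh}^{(l)} + H_{t-1}^{(l)} W_{hh}^{(l)} + b_h^{(l)}\right),
\qquad H_t^{(0)} = X_t,
$$

so each layer consumes the hidden-state sequence produced by the layer below as its input, and only the top layer's states are passed to the output layer.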

Using a deep RNN in PyTorch is very simple:

rnn_layer = nn.RNN(num_inputs, num_hiddens, num_layers=3)

Passing num_layers greater than 1 turns this into a deep RNN (the default, num_layers=1, is the ordinary single-layer RNN). The perplexity curve of a network built on this deep RNN layer is shown below. [Pasted image 20230626174318.png] Compared with [[循环神经网络(RNN)#^5d833f|the perplexity of the plain RNN]], the deep RNN drives the perplexity down faster.
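As a minimal shape check (the batch size, step count and feature sizes below are assumed for illustration, not taken from this note), a 3-layer nn.RNN behaves as follows:

```python
import torch
from torch import nn

num_inputs, num_hiddens, num_layers = 28, 256, 3          # assumed sizes
rnn_layer = nn.RNN(num_inputs, num_hiddens, num_layers=num_layers)

num_steps, batch_size = 35, 32                             # assumed sizes
X = torch.rand(num_steps, batch_size, num_inputs)          # (steps, batch, features)
state = torch.zeros(num_layers, batch_size, num_hiddens)   # one initial state per layer

Y, new_state = rnn_layer(X, state)
print(Y.shape)          # torch.Size([35, 32, 256]) – top-layer outputs for every step
print(new_state.shape)  # torch.Size([3, 32, 256])  – final hidden state of every layer
```

Note that Y only contains the hidden states of the last layer, while new_state stacks the final hidden state of all num_layers layers.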

The idea can be seen as: run one RNN pass over the input, then use the resulting sequence of hidden states as a new input and run further passes to compute the hidden states of the next layer. A from-scratch implementation of this idea:

import torch
from torch import nn


class RNNLayer(nn.Module):

    def __init__(self, vocab_size, hidden_size, num_layers=1):
        super().__init__()
        # not used here; kept as a placeholder so wrappers that read `bidirectional` still work
        self.bidirectional = None
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        # first layer: maps [input, previous hidden state] to a new hidden state
        self.net = nn.Linear(self.vocab_size + self.hidden_size, self.hidden_size)
        if self.num_layers != 1:
            # deeper layers: map [hidden state of the layer below, own previous hidden state]
            self.deep = nn.Linear(self.hidden_size * 2, self.hidden_size)
  
    def x2h_layer(self, inputs, state):
        """First-layer recurrence: inputs (num_steps, batch, vocab_size), state (1, batch, hidden_size)."""
        states = []
        for X in inputs:                            # X: (batch, vocab_size)
            data = torch.unsqueeze(X, dim=0)        # (1, batch, vocab_size)
            data = torch.cat((data, state), dim=2)  # (1, batch, vocab_size + hidden_size)
            state = torch.tanh(self.net(data))      # (1, batch, hidden_size)
            states.append(state)
        states = torch.cat(states, dim=0)           # (num_steps, batch, hidden_size)
        return states, state
  
    def forward(self, inputs, H):
        # H: (num_layers, batch, hidden_size) – one initial state per layer
        all_states, last_state = self.x2h_layer(inputs, torch.unsqueeze(H[0], dim=0))
        # with a single layer, return its states directly
        if self.num_layers == 1:
            return all_states, last_state

        # deep RNN: feed the computed hidden states into each further layer as its input
        last_states = [last_state]
        for state in H[1:]:
            states = []
            state = torch.unsqueeze(state, dim=0)
            for X in all_states:
                data = torch.unsqueeze(X, dim=0)
                data = torch.cat((data, state), dim=2)
                state = torch.tanh(self.deep(data))
                states.append(state)
            last_states.append(states[-1])
            all_states = torch.cat(states, dim=0)
        last_states = torch.cat(last_states, dim=0)  # (num_layers, batch, hidden_size)
        return all_states, last_states
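
A quick sanity check of this implementation (the sizes below are assumed for illustration) shows that its output shapes match those of nn.RNN:

```python
import torch

vocab_size, hidden_size, num_layers = 28, 256, 3    # assumed sizes
layer = RNNLayer(vocab_size, hidden_size, num_layers)

num_steps, batch_size = 35, 32                       # assumed sizes
X = torch.rand(num_steps, batch_size, vocab_size)    # one-hot-like inputs
H = torch.zeros(num_layers, batch_size, hidden_size)

Y, new_H = layer(X, H)
print(Y.shape)      # torch.Size([35, 32, 256]) – top-layer states for every step
print(new_H.shape)  # torch.Size([3, 32, 256])  – last-step state of every layer
```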