在PyTorch中输入LSTM网络?

PyTorch的这个版本似乎为循环神经网络提供了可变长度输入的PackedSequence 。 但是,我发现正确使用它有点困难。

使用pad_packed_sequence恢复由pack_padded_sequence馈送的RNN层的输出,我们得到了T x B x N张量outputs ,其中T是最大时间步长, B是批量大小, N是隐藏大小。 我发现对于批次中的短序列,后续输出将全部为零。

这是我的问题。

  • 对于需要所有序列的最后输出的单个输出任务,简单outputs[-1]将给出错误的结果,因为该张量对于短序列包含大量的零。 人们需要按序列长度构建索引以获取所有序列的最后一个输出。 有更简单的方法来做到这一点吗?
  • 对于多输出任务(例如seq2seq),通常会添加一个线性层N x O并将批处理输出T x B x O重塑为TB x O并计算真实目标TB的交叉熵损失(通常是语言中的整数模型)。 在这种情况下,批量输出中的这些零是重要的吗?

  • 问题1 - 最后时间步

    这是我用来获取最后时间步的输出的代码。 我不知道是否有更简单的解决方案。 如果是这样,我想知道它。 我遵循这个讨论,并抓住了我的last_timestep方法的相关代码片段。 这是我的前锋。

    class BaselineRNN(nn.Module):
        def __init__(self, **kwargs):
            ...
    
        def last_timestep(self, unpacked, lengths):
            # Index of the last output for each sequence.
            idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                                   unpacked.size(2)).unsqueeze(1)
            return unpacked.gather(1, idx).squeeze()
    
        def forward(self, x, lengths):
            embs = self.embedding(x)
    
            # pack the batch
            packed = pack_padded_sequence(embs, list(lengths.data),
                                          batch_first=True)
    
            out_packed, (h, c) = self.rnn(packed)
    
            out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)
    
            # get the outputs from the last *non-masked* timestep for each sentence
            last_outputs = self.last_timestep(out_unpacked, lengths)
    
            # project to the classes using a linear layer
            logits = self.linear(last_outputs)
    
            return logits
    

    问题2 - 掩蔽交叉熵损失

    是的,默认情况下,零填充时间步(目标)很重要。 但是,掩盖它们非常容易。 你有两个选择,取决于你使用的PyTorch版本。

  • PyTorch 0.2.0:现在pytorch支持直接在CrossEntropyLoss中使用ignore_index参数进行掩码。 例如,在语言建模或seq2seq中,我添加零填充,我掩饰零填充单词(目标)就像这样:

    loss_function = nn.CrossEntropyLoss(ignore_index = 0)

  • PyTorch 0.1.12及更早版本:在较旧版本的PyTorch中,不支持掩膜,因此您必须实施自己的解决方法。 我使用的解决方案是jihunchoi的masked_cross_entropy.py。 您可能也对此讨论感兴趣。

  • 链接地址: http://www.djcxy.com/p/40599.html

    上一篇: input LSTM network in PyTorch?

    下一篇: Three.js: Transition 2 Textures with Zoom and Blend Effects