[NLP] Description and implementation of LSTM neural network.

0. Statement 🏫

Today I intend to move from an intuitive understanding of the LSTM to its implementation in PyTorch, and I hope readers can get substantial help from this blog.

1. What is the advantage of LSTM over RNN? 🤔

An RNN can remember less of the surrounding context than an LSTM, so the RNN can be thought of as a short-term memory network, while the LSTM is a Long Short-Term Memory network: the "long" means it can remember more context than the short-term RNN. So you don't have to be confused by the name (Long Short-Term). 🤗

2. Differences in notation between the RNN and LSTM diagrams. 🧐

The RNN and LSTM pictures are from here.
(RNN architecture diagram)
(LSTM architecture diagram)

  1. The output of the RNN is o(t), while the output of the LSTM is h(t).
  2. The contextual information (memory) of the RNN is stored in h(t) (shown above), while the contextual information (memory) of the LSTM is stored in c(t).

3. The composition and intuitive understanding of LSTM. 🤠

When I first saw the architecture diagram of the LSTM, I noticed a schematic in which a sigmoid and an element-wise multiplication appear together.
The sigmoid outputs values between 0 and 1, so the numbers it multiplies are scaled by a factor between 0 (discard) and 1 (keep in full). The significance of this structure is to decide whether, and how much, to use the data from the previous time step. This structure is named a "gate".
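To make this concrete, here is a tiny toy example (my own illustration, not from the original figures): multiplying a remembered vector element-wise by a sigmoid output keeps components where the gate is near 1 and discards components where it is near 0.

import torch

memory = torch.tensor([2.0, -1.0, 3.0])                  # values remembered so far
gate = torch.sigmoid(torch.tensor([10.0, -10.0, 0.0]))   # roughly 1.0, 0.0, 0.5
print(gate)            # approximately tensor([1.0000, 0.0000, 0.5000])
print(gate * memory)   # first value kept, second forgotten, third halved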

3.1. LSTM: Forget gate.

When the value of "ft" is close to 1, it means I want to keep the data remembered at the previous time step, and when the value is close to 0, it means I want to forget it.
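Written as a formula (the standard textbook form; the figure's symbols may differ slightly):

f(t) = sigmoid(Wf · [h(t-1), x(t)] + bf)

Since f(t) lies between 0 and 1, the previous memory c(t-1) is scaled element-wise rather than being kept or dropped outright.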

3.2. LSTM: Input gate and cell state.

When the value of "it" is close to 1, it means I want to use the candidate information computed at the current time step, the candidate cell state c̃(t) ("C wave t"), which is calculated with tanh, "Wc", and "bc" from the current input x(t) and the previous hidden state h(t-1).
⚠️: In summary, the forget gate determines whether the information remembered at the previous time step is still useful, and the input gate determines whether the information seen at the current time step is important enough to remember.
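For reference, the input gate, candidate cell state, and cell-state update can be written as (standard textbook form, not copied from the figure; ⊙ is the element-wise product):

i(t) = sigmoid(Wi · [h(t-1), x(t)] + bi)
c̃(t) = tanh(Wc · [h(t-1), x(t)] + bc)
c(t) = f(t) ⊙ c(t-1) + i(t) ⊙ c̃(t)

so the new memory c(t) is a gated mixture of the old memory and the new candidate.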

3.3. LSTM: Output.

In summary, the output h(t) of the LSTM is the element-wise product of the "output gate" and tanh applied to the cell state c(t).
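In formula form (standard textbook notation, under the same conventions as above):

o(t) = sigmoid(Wo · [h(t-1), x(t)] + bo)
h(t) = o(t) ⊙ tanh(c(t))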

3.4. Summary

(Figure: complete LSTM cell diagram)
The forget gate, input gate, and output gate are each computed with a sigmoid over the current input x(t) and the contextual information h(t-1) from the previous time step. The candidate cell state c̃(t) (the "c wave") fed into the cell is computed with tanh over the same x(t) and h(t-1). The c(t) and h(t) passed on to the next time step are then easy to understand intuitively from the above diagram. 😇
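To tie the pieces together, here is a minimal from-scratch sketch of a single LSTM time step in PyTorch. The weight names (W_xf, W_hf, and so on) and the random initialization are my own illustrative choices; this is a sketch of the standard equations above, not a drop-in replacement for nn.LSTMCell.

import torch

torch.manual_seed(0)
input_size, hidden_size, batch = 100, 20, 3

# Illustrative parameters: one input-to-hidden and one hidden-to-hidden matrix per gate.
def make(rows, cols):
    return torch.randn(rows, cols) * 0.1

W_xf, W_hf, b_f = make(input_size, hidden_size), make(hidden_size, hidden_size), torch.zeros(hidden_size)
W_xi, W_hi, b_i = make(input_size, hidden_size), make(hidden_size, hidden_size), torch.zeros(hidden_size)
W_xc, W_hc, b_c = make(input_size, hidden_size), make(hidden_size, hidden_size), torch.zeros(hidden_size)
W_xo, W_ho, b_o = make(input_size, hidden_size), make(hidden_size, hidden_size), torch.zeros(hidden_size)

def lstm_step(x_t, h_prev, c_prev):
    f_t = torch.sigmoid(x_t @ W_xf + h_prev @ W_hf + b_f)   # forget gate
    i_t = torch.sigmoid(x_t @ W_xi + h_prev @ W_hi + b_i)   # input gate
    c_tilde = torch.tanh(x_t @ W_xc + h_prev @ W_hc + b_c)  # candidate cell state ("c wave")
    c_t = f_t * c_prev + i_t * c_tilde                      # new memory: keep old + add new
    o_t = torch.sigmoid(x_t @ W_xo + h_prev @ W_ho + b_o)   # output gate
    h_t = o_t * torch.tanh(c_t)                             # new hidden state / output
    return h_t, c_t

x_t = torch.randn(batch, input_size)
h_t, c_t = lstm_step(x_t, torch.zeros(batch, hidden_size), torch.zeros(batch, hidden_size))
print(h_t.shape, c_t.shape)  # torch.Size([3, 20]) torch.Size([3, 20])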


3.5. Understanding the role of "gates" intuitively.

3.6. Why can LSTM mitigate gradient vanishing?

Because when we compute the gradient, the cell-state path avoids the k-th power of "Whh" that appears in a plain RNN; and because there are three "gates", the gradient expands into several terms, and the gates constrain each other, so the probability of the product becoming extremely large or extremely small is much lower.
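Sketching this (a standard argument, not spelled out in the figure): from c(t) = f(t) ⊙ c(t-1) + i(t) ⊙ c̃(t), the direct derivative of c(t) with respect to c(t-1) is simply f(t), plus extra terms coming through the gates, so there is no repeated multiplication by the same matrix Whh as in a plain RNN; whenever the network keeps f(t) close to 1, the gradient along the cell-state path can be carried across many time steps without shrinking.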

4. How to implement LSTM with PyTorch? 😎

The PyTorch documentation on LSTM is from here.

import torch
from torch import nn

# 4-layer LSTM: 100-dim inputs, 20-dim hidden state per layer.
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
print(lstm)
# x: (seq_len=10, batch=3, input_size=100); batch_first is False by default.
x = torch.randn(10, 3, 100)
out, (h, c) = lstm(x)
print(out[-1])        # last layer's hidden state at the last time step
print(h[-1])          # same values, taken from h instead of out
print(out[-1].shape)  # torch.Size([3, 20])
print(h[-1].shape)    # torch.Size([3, 20])
print(out[-1] == h[-1])
# out: (seq_len, batch, hidden) = (10, 3, 20); h and c: (num_layers, batch, hidden) = (4, 3, 20)
print(f"out.shape:{out.shape}\nh.shape:{h.shape}\nc.shape:{c.shape}")

⚠️: The "out" of the LSTM holds the hidden state h of the last layer at every time step, while "h" holds the hidden state of every layer at the last time step, so out[-1] equals h[-1]. 🤠

import torch
from torch import nn

# Reuse the same input shape as above: (seq_len=10, batch=3, input_size=100).
x = torch.randn(10, 3, 100)

print('one layer lstm')
# nn.LSTMCell processes one time step at a time, so we loop over the sequence ourselves.
cell = nn.LSTMCell(input_size=100, hidden_size=20)
h = torch.zeros(3, 20)  # (batch, hidden_size)
c = torch.zeros(3, 20)
for xt in x:            # xt: (batch=3, input_size=100)
    h, c = cell(xt, (h, c))
print(f"h.shape:{h.shape}")
print(f"c.shape:{c.shape}")

print("two layer lstm")
cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
h1 = torch.zeros(3, 30)
c1 = torch.zeros(3, 30)
h2 = torch.zeros(3, 20)
c2 = torch.zeros(3, 20)
for xt in x:
    h1, c1 = cell1(xt, [h1, c1])
    h2, c2 = cell2(h1, [h2, c2])
print(f"h2.shape:{h2.shape}")
print(f"c2.shape:{c2.shape}")


Finally 🤩

Thanks to this age of knowledge sharing and to the people willing to share; the knowledge in this blog is what I have learned on this site, thank you for the support! 😇
