简单介绍长短期记忆网络 – LSTM

一、引言

1.1 什么是LSTM

首先看看百科的解释。
长短期记忆(英语:Long Short-Term Memory,LSTM)是一种时间循环神经网络(RNN),论文首次发表于1997年。由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。1

为了更好地理解长短期记忆网络 - LSTM(下文简称LSTM),可以先了解循环神经网络-RNN(下文简称RNN)的相关知识,这里有一些相关的文章。LSTM只是RNN的一个变种,LSTM是为了解决RNN中的梯度消失的问题而提出的。

二、循环神经网络RNN

2.1 为什么需要RNN

人的思想是有记忆延续性。比如当你在阅读这篇文章,你会根据你曾经对每个字的理解来理解这篇文章的字,而不是每次都要思考一个字在这篇文章的语境下到底如何理解(从一个字或词的多种解释来选择一个符合当下语境的解释)。

举个例子:要识别这么一个句子:
The cat, which already ate cakes, () full.2

假设对其中的单词从左到右一个一个地处理,前面已经cat的识别结果是一个单数名词,到后边()里的内容,到底是填were 还是 was,那么就需要根据前边cat的识别结果进行判断。这就是RNN需要做的。

使用神经网络来预测句子中下一个字的解释。传统的神经网络在模型训练好了以后,在输入层给定一个x,通过网络之后就能在输出层得到特定的y。利用这个模型可以通过训练拟合任意函数,但是只能单独的取处理一个个的输入,前一个输出和后一个输出是完全没有关系的

神经网络的结构如下:
Alt
但是,在理解一句话的意思的时候,一个字的意思是跟前面的字相关联的,即前面的输出和后面的输出是有关系的。所以仅仅利用这样的模型是不够的的,为了解决这个问题,有人提出了RNN。
RNN模型构造:
传统RNN模型

RNN神经网络示意图:
RNN模型
蓝色部分的是隐藏层,RNN利用隐藏层将信息向后传递。
我们来看看RNN隐藏层里发生了什么,将上图按时间线展开3

隐藏层

符号 意义
X 一个向量,输入层的值
S 一个向量,隐藏层的值
O 一个向量,输出层的值
U 输入层到隐藏层的权重矩阵
V 隐藏层到输出层的权重矩阵
W 隐藏层上一次的值作为这一次输入的权重

再给出一个更具体的图,给出各层元素的对应关系
具体图
现在看上去就比较清楚了,这个网络在 t 时刻接收到输入

x

t

x_t

xt 之后,隐藏层的值是

s

t

s_t

st ,输出值是

o

t

o_t

ot 。关键一点是,

s

t

s_t

st 的值不仅仅取决于

x

t

x_t

xt ,还取决于

s

t

1

s_{t-1}

st1 我们可以用下面的公式来表示RNN的计算方法:
用公式表示如下:

O

t

=

g

(

V

S

t

)

O_t = g(V·S_t)

Ot=g(VSt)

S

t

=

f

(

U

X

t

+

W

S

t

1

S_t = f(U·X_t + W ·S_{t-1})

St=f(UXt+WSt1
注意:为了简单说明问题,偏置都没有包含在公式里面。

这样,就可以做到的在一个序列中根据前面的输出来影响后面的输出。

三、长短时记忆神经网络LSTM

3.1 为什么需要LSTM

回到我们的例子:
The cat, which already ate …, () full.

这个例子与之前的例子稍微有一些不同,这里的cat 和()之间已经相隔了较长的一段距离,这时候用RNN来处理这样的长期信息就不太合适。

因为RNN在反向传播阶段有梯度消失等问题不能处理长依赖问题,这里的梯度消失是由于RNN在计算过程中使用链式法则

具体来说,RNN使用覆盖的方式来计算状态:

S

t

=

f

(

S

t

1

,

x

t

)

S_t = f(S_{t-1},x_t)

St=f(St1,xt),这类似于复合函数,根据链式求导的法则,复合函数求导:设

f

f

f

g

g

g

x

x

x 的可导函数,则

(

f

g

)

(

x

)

=

f

(

g

(

x

)

)

g

(

x

)

(f circ g)'(x) = f'(g(x))g'(x)

(fg)(x)=f(g(x))g(x),这是一种连乘的方式,如果导数小于或大于1,会发生梯度下降以及梯度爆炸。梯度爆炸可以通过剪枝算法解决,但是梯度消失却没办法解决。

梯度消失可能不太好理解,可以简单理解为RNN中后边输入的数据影响越大,前面的数据的影响小,因此不能处理长期信息。后来,有学者在一篇论文Long Short-Term Memory 4 提出了LSTM,LSTM通过选择性地保留信息,有效地缓解了梯度消失以及梯度下降的问题,可以说LSTM正是为了适合学习长期依赖而产生的。

3.2 LSTM结构分析

回顾一下RNN的模型构造:

RNN模型构造
可以看到,RNN循环网络模型的链式结构非常简单,通常仅含有一个tanh层。

LSTM模型构造:
LSTM
而LSTM的链式结构中,循环单元结构不同,里边有四个神经网络层。

先来解释一下图中符号含义:
符号含义

符号 含义
黄色矩形 神经网络层
粉色圆 结点操作,比如向量相加
箭头 从一个结点的输出到另外的结点的输入
箭头合并 链接
箭头分叉 内容复制后副本流向不同的位置

LSTM结构(图右)和普通RNN的主要输入输出区别如下所示:
LSTM对比RNN
相比RNN只有一个传递状态

h

t

h^t

ht , LSTM有两个传输状态,一个

c

t

c^t

ct (cell state), 和一个

h

t

h^t

ht (hidden state)。(RNN中的

h

t

h^t

ht 对应LSTM中的

C

t

C^t

Ct

3.3 LSTM背后的核心思想

LSTM的核心思想,LSTM的关键是细胞状态(cell state),即下图中上边的水平线。cell state像是一条传送带,它贯穿整条链,其中只发生一些小的线性作用。信息流过这条线而不改变是非常容易的。5 改变cell state需要三个门的相互配合。

如下图所示:
细胞状态
LSTM删除或添加信息到cell state,是由被称为门的结构控制的。LSTM中有三个门,“遗忘门” “输入门” 以及“输出门”,用来保护和更新cell的状态。
门是筛选信息的方法,由一个sigmoid网络层和一个点乘操作组成。
如下图:
门
sigmoid层作为激活函数,将输出控制在(0,1)区间内,Sigmoid的函数图形如下:
Sigmoid
可以看到,绝大多数的值都是接近0或者接近1的。利用这一个性质,0 表示不允许任何通过,1 表示允许一切通过。

3.4 LSTM的运行机制

第一步,需要决定从cell state中丢弃什么样的信息,这个由“遗忘门”的sigmoid层决定。根据输入

h

t

1

h_{t-1}

ht1

x

t

x_t

xt,得到的输出是0和1之间的数。0 代表“完全保留这个值”,1代表“完全丢弃这个值”。

回到开始的例子,原来的主语是"cat",之后遇到了一个新的主语"cats"。这时需要把之前的"cat"给忘掉,以便确定接下来是要使用"were",而不是"was"。如下图:
遗忘门
第二步,需要决定在cell state里存储什么样的信息。这一步划分为两个部分,一是称为“输入门”的sigmoid层决定哪些数据需要更新。然后,tanh层创建一个新的候选值向量

C

~

t

widetilde{C}_t

C

t,这些值能加入state中。第二部分,需要将这两个部分合并以实现对state的更新。

在例子中,这里对应于把新的"cats"加入到"cell state"中,以替代需要遗忘的"cat"。如下图:
input gate
在决定好需要遗忘的以及需要加入的记忆之后,就可以把旧的cell state

C

t

1

C_{t-1}

Ct1更新到新的cell state

C

t

C_t

Ct。 这一步中,把旧的state

C

t

1

C_{t-1}

Ct1

f

t

f_t

ft 相乘,遗忘先前决定遗忘的东西,之后加上新的记忆信息

i

t

C

~

t

i_t ast widetilde{C}_t

itC

t。这里为了体现对状态值的更新度是有限制的,可以把

i

t

i_t

it当成一个权重。如下图:
更新
最后,需要决定输出。这个输出将会基于cell state ,这是一个过滤后的值。首先,使用“输出门”的sigmoid层决定输出cell state的哪些部分的。然后,将cell state放入tanh(将数值限制在-1到1),最后将结果与sigmoid门的输出相乘,这样就可以只输出需要的部分。如下图:
输出门

3.5 LSTM如何避免梯度下降

上边提到了RNN中的梯度下降以及梯度爆炸问题,是是因为在计算过程中使用链式法则,使用了乘积。而在LSTM中,状态是通过累加的方式来计算,

S

t

=

τ

=

1

t

Δ

S

τ

S_t = sum_{tau =1}^t Delta S_{tau}

St=τ=1tΔSτ。这样的计算,就不是复合函数的形式,它的导数也就不是乘积的形式,就不会发生梯度消失的情况。

四、入门例子

下面给出LSTM的一个入门实例-根据前9年的数据预测后3年的客流6,感谢原作者的代码,完整的代码见GithubYonv1943。这里简单说一下这个代码实例的结果,需要了解更加详细的代码细节可以看看原作者的原文详解。

考虑有一组某机场1949年~1960年12年共144个月的客流量数据。使用这个数据中的前9年的客流量来预测后3年的客流量,再和实际的数据进行比对,可以看出LSTM的对这类具有时序关系的拟合效果。

结果图:
结果图

  • 数据:机场1949~1960年12年共144个月的客流量数据。数据具有三个维度[客运量,年份,月份]。其中前75%(前9年)的数据作为训练集,后25%(后3年)的数据作为测试集。
  • 纵坐标:标准化处理:变量值与平均数的差除以标准差,给出数值的相对位置。横坐标为月数。
  • 图解释:竖直黑线左边是训练集(前9年)。右边(后3年)红色的是预测数值,蓝色的是实际数值。

可以看到在这个LSTM对这个数据集的拟合效果是比较好的,在这样的实际场景中,可以利用LSTM这样的工具来对客流量做一个预测,以便对客运高峰等情况做好预备方案。

五、总结

  • RNN的计算中存在多个偏导数连乘,导致梯度消失或梯度爆炸,难以处理长依赖的信息。
  • LSTM通过三个选择性地保留信息,可以选择最近的信息或者很久之前的信息。
  • LSTM更新cell state是采用了线性求和的计算,因此不会出现梯度消失问题,可以处理长期依赖的信息。

六、参考资料


  1. 长短期记忆 ↩︎

  2. 吴恩达深度学习课程 ↩︎

  3. 一文搞懂RNN(循环神经网络)基础篇 ↩︎

  4. Long Short-Term Memory ↩︎

  5. Understanding LSTM Networks ↩︎

  6. LSTM入门例子:根据前9年的数据预测后3年的客流(PyTorch实现) ↩︎

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>