计算复杂度

提示:计算复杂度的简单理解(第一次写博客)

计算复杂度

计算复杂度

我们以Vicinity Vision Transformer论文中的图为例。
在这里插入图片描述图注:标准自注意力(左)和线性化自注意力(右)的图示。

N

N

N表示输入图像的

p

a

t

c

h

patch

patch数,

d

d

d是特征维度。使

N

d

Ngg d

Nd,线性化自注意力的计算复杂度相对于输入长度线性增长,而标准自注意力的计算复杂度是二次的。

从输入到输出可以这样计算:

(

N

×

d

)

×

(

d

×

N

)

=

N

×

N

×

(

d

×

N

)

×

(

N

×

N

)

=

d

×

N

(Ntimes d)times (dtimes N)=Ntimes Ntimes (dtimes N)times (Ntimes N)=dtimes N

(N×d)×(d×N)=N×N×(d×N)×(N×N)=d×N

(

d

×

N

)

×

(

N

×

d

)

=

d

×

d

×

(

d

×

d

)

×

(

d

×

N

)

=

d

×

N

(dtimes N)times (Ntimes d)=dtimes dtimes (dtimes d)times (dtimes N)=dtimes N

(d×N)×(N×d)=d×d×(d×d)×(d×N)=d×N

关于计算复杂度:其实可以认为是乘法次数。我们给出最直观的解释。

假设有两个矩阵做乘法,如下:

[

1

2

3

4

5

6

]

×

[

1

2

3

4

5

6

]

=

[

1

2

3

4

5

6

7

8

9

]

left[begin{matrix}1&2\3&4\5&6\end{matrix}right]timesleft[begin{matrix}1&2&3\4&5&6\end{matrix}right]=left[begin{matrix}1&2&3\4&5&6\7&8&9\end{matrix}right]

135246×[142536]=147258369,其中行数为

N

N

N,列数为

d

d

d

(

3

×

2

)

×

(

2

×

3

)

=

(

3

×

3

)

×

(

N

×

d

)

×

(

d

×

N

)

=

(

N

×

N

)

(3times 2)times (2times 3)=(3times 3)times (Ntimes d)times (dtimes N)=(Ntimes N)

(3×2)×(2×3)=(3×3)×(N×d)×(d×N)=(N×N)

3

×

3

3times 3

3×3矩阵第一个元素涉及的乘法次数:

1

×

1

+

2

×

4

=

9

1times 1+2times 4=9

1×1+2×4=9 共2次乘法;其它元素是一样的。最后可以得到

2

×

9

=

2

×

3

×

3

=

d

×

N

×

N

=

N

2

d

2times 9=2times 3times 3=dtimes Ntimes N=N^{2}d

2×9=2×3×3=d×N×N=N2d.

假设又有两个矩阵做乘法,如下:

[

1

2

3

4

5

6

]

×

[

1

2

3

4

5

6

]

=

[

1

2

3

4

]

left[begin{matrix}1&2&3\4&5&6\end{matrix}right]timesleft[begin{matrix}1&2\3&4\5&6\end{matrix}right]=left[begin{matrix}1&2\3&4\end{matrix}right]

[142536]×135246=[1324],其中行数为

d

d

d,列数为

N

N

N

(

2

×

3

)

×

(

3

×

2

)

=

(

2

×

2

)

×

(

d

×

N

)

×

(

N

×

d

)

=

(

d

×

d

)

(2times 3)times (3times 2)=(2times 2)times (dtimes N)times (Ntimes d)=(dtimes d)

(2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d)

2

×

2

2times 2

2×2矩阵第一个元素涉及的乘法次数:

1

×

1

+

2

×

3

+

2

×

5

=

17

1times 1+2times 3+2times 5=17

1×1+2×3+2×5=17 共3次乘法;其它元素是一样的。最后可以得到

3

×

4

=

3

×

2

×

2

=

N

×

d

×

d

=

N

d

2

3times 4=3times 2times 2=Ntimes dtimes d=Nd^2

3×4=3×2×2=N×d×d=Nd2 .

为什么会有这种情况呢?以第二个例子为例,可以观察到,所得结果的一个元素的乘法数量和消失的维度大小有关,也就是列数

N

N

N,或者说,列数

N

N

N就是所得结果一个元素的乘法次数。那么多少个元素呢?元素个数就要看你是如何进行的乘法操作,其实就是矩阵大小。比如

(

2

×

3

)

×

(

3

×

2

)

=

(

2

×

2

)

×

(

d

×

N

)

×

(

N

×

d

)

=

(

d

×

d

)

(2times 3)times (3times 2)=(2times 2)times (dtimes N)times (Ntimes d)=(dtimes d)

(2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d),那么就是

d

2

d^2

d2个元素,最后乘法次数就是

N

d

2

Nd^2

Nd2

乘法次数=消失的维度 × 所得矩阵大小

那么计算复杂度呢?我们不要去管

O

(

)

O(bullet)

O()具体代表什么,这不重要。
以第一个图为例,乘法次数1:

(

N

×

d

)

×

(

d

×

N

)

=

N

2

d

(Ntimes d)times (dtimes N)=N^{2}d

(N×d)×(d×N)=N2d;乘法次数

2

2

2

(

N

×

d

)

×

(

d

×

N

)

=

N

2

d

(Ntimes d)times (dtimes N)=N^{2}d

(N×d)×(d×N)=N2d

O

(

N

2

d

+

N

2

d

)

=

O

(

N

2

)

O(N^{2}d+N^{2}d)=O(N^2)

O(N2d+N2d)=O(N2)。因为

N

d

Ngg d

Nd,所以

d

d

d(还有常数

2

2

2)被省略了,即

O

(

N

2

)

O(N^2)

O(N2)
以第二个图为例,乘法次数1:

(

d

×

N

)

×

(

N

×

d

)

=

N

d

2

(dtimes N)times (Ntimes d)=Nd^2

(d×N)×(N×d)=Nd2;乘法次数2:

(

d

×

d

)

×

(

d

×

N

)

=

N

d

2

(dtimes d)times (dtimes N)=Nd^2

(d×d)×(d×N)=Nd2

O

(

N

d

2

+

N

d

2

)

=

O

(

N

)

O(Nd^2+Nd^2)=O(N)

O(Nd2+Nd2)=O(N)。因为

N

d

Ngg d

Nd,所以

d

d

d(还有常数2)被省略了,即

O

(

N

)

O(N)

O(N)

事实告诉我们,我们两个的结果一样,但是我们可以通过控制中间过程减少计算复杂度。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>