# How to Estimate the GPU Memory of a Transformer Model

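During training, the total memory is the sum of three parts: the model parameters, the intermediate activations, and the gradients.

```
total_memory = memory_model + memory_activations + memory_gradients
```

This article assumes the gradients take roughly as much memory as the activations. Under that assumption the total simplifies to:

```
total_memory = memory_model + 2 * memory_activations
```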
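The step from three terms to two depends entirely on the assumption `memory_gradients ≈ memory_activations`. The sketch below makes that assumption explicit. The function name and the fallback behaviour are hypothetical and only illustrate the bookkeeping.

```python
def total_memory(memory_model, memory_activations, memory_gradients=None):
    """Sum the three memory components.

    If memory_gradients is not supplied, fall back to the simplifying
    assumption used in this article: gradients take about as much
    memory as activations.
    """
    if memory_gradients is None:
        memory_gradients = memory_activations  # assumption from the text
    return memory_model + memory_activations + memory_gradients


# With the default assumption this equals memory_model + 2 * memory_activations
print(total_memory(100, 30))
```

Passing an explicit `memory_gradients` lets you swap in a different estimate when the assumption does not hold.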

## Estimating the Model Memory

Each transformer block consists of the following layers, in order:

```
multi_headed_attention --> layer_normalization --> MLP --> layer_normalization
```

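The estimate ignores the comparatively small layer normalization parameters. Model memory is then dominated by the attention weights and the MLP weights:

```
memory_model = memory of multi_headed_attention + memory of MLP
             = memory of value + memory of key + memory of query + memory of MLP
```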

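Each of these four weights is treated as a square matrix of shape (n_head * dim) × (n_head * dim), and the model stacks n_tr_blocks blocks. The model memory, counted in elements, is therefore:

```
memory_model = 4 * n_tr_blocks * square_of(n_head * dim)
```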
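The count can be checked by listing the four weight shapes per block and multiplying them out. The sketch below follows the article's simplification: query, key, value and the MLP are each one square (n_head * dim) × (n_head * dim) matrix, and layer norm is ignored. Real MLPs usually have a wider hidden layer, so treat this as the article's model rather than an exact parameter count. The function name and the example sizes (12 blocks, 12 heads, head dim 64) are hypothetical.

```python
def model_elements(n_tr_blocks, n_head, dim):
    """Weight element count under the article's simplification."""
    hidden = n_head * dim
    # one square matrix each for query, key, value and the MLP
    per_block_shapes = {
        "query": (hidden, hidden),
        "key": (hidden, hidden),
        "value": (hidden, hidden),
        "mlp": (hidden, hidden),
    }
    per_block = sum(rows * cols for rows, cols in per_block_shapes.values())
    return n_tr_blocks * per_block


def model_elements_closed_form(n_tr_blocks, n_head, dim):
    """memory_model = 4 * n_tr_blocks * square_of(n_head * dim)"""
    return 4 * n_tr_blocks * (n_head * dim) ** 2


# hypothetical config: 12 blocks, 12 heads, head dim 64 (hidden size 768)
print(model_elements(12, 12, 64), model_elements_closed_form(12, 12, 64))
```

Both functions should agree for any sizes. Multiply the result by the bytes per element to get a memory figure.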

## Estimating the Memory of Intermediate Activations

Multi-head attention computes scaled dot-product attention:

```
multi_headed_attention = softmax(query * transpose(key) / sqrt(dim)) * value
```

The key, query and value tensors each have shape:

```
[batch_size, n_head, sequence_length, dim]
```

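Multiplying query by the transposed key contracts the dim axis. The resulting scores, and therefore the softmax output, have shape:

```
[batch_size, n_head, sequence_length, sequence_length]
```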

The softmax output therefore takes:

```
memory_softmax = batch_size * n_head * square_of(sequence_length)
```
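The sequence_length × sequence_length term comes from the attention weights. The pure-Python sketch below runs single-head attention for one batch item and shows that the softmax weights form an S × S matrix, while the output is S × D. The function names and toy inputs are hypothetical. Multiplying the S × S weight matrix by batch_size and n_head gives `memory_softmax`.

```python
import math


def softmax(row):
    # subtracting the max keeps exp() from overflowing; the result is unchanged
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(query, key, value):
    """query, key, value: S rows of D floats each (one batch item, one head)."""
    d = len(query[0])
    # scores[i][j] = <query_i, key_j> / sqrt(d)  ->  S x S matrix
    scores = [[sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d)
               for k_row in key] for q_row in query]
    weights = [softmax(row) for row in scores]  # still S x S
    # output[i] = sum_j weights[i][j] * value[j]  ->  S x D matrix
    output = [[sum(w * v_row[c] for w, v_row in zip(w_row, value))
               for c in range(len(value[0]))] for w_row in weights]
    return weights, output


def memory_softmax(batch_size, n_head, sequence_length):
    # one S x S weight matrix per (batch item, head)
    return math.prod((batch_size, n_head, sequence_length, sequence_length))


S, D = 4, 3  # toy sizes
query = [[0.1 * (i + c) for c in range(D)] for i in range(S)]
key = [[0.2 * (i - c) for c in range(D)] for i in range(S)]
value = [[float(i * D + c) for c in range(D)] for i in range(S)]
weights, output = single_head_attention(query, key, value)
print(len(weights), len(weights[0]), len(output), len(output[0]))
```

Storing `weights` for the backward pass is what makes activation memory grow quadratically with sequence length.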

Multiplying the softmax output by value gives a tensor of shape [batch_size, n_head, sequence_length, dim]. The MLP output has the same shape:

```
memory of MLP   = batch_size * n_head * sequence_length * dim
memory of value = batch_size * n_head * sequence_length * dim
```

Adding these terms gives the activation memory of a single transformer block:

```
memory_activations = memory_softmax + memory_value + memory_MLP
                   = batch_size * n_head * square_of(sequence_length)
                   + batch_size * n_head * sequence_length * dim
                   + batch_size * n_head * sequence_length * dim
                   = batch_size * n_head * sequence_length * (sequence_length + 2*dim)
```
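The last line factors `batch_size * n_head * sequence_length` out of the three terms. A quick sketch can confirm the factorization numerically. It sums the three terms separately and compares them with the factored form. The function names and the test grid are hypothetical.

```python
def activations_per_block(batch_size, n_head, sequence_length, dim):
    """Sum of the three activation terms (element counts) for one block."""
    memory_softmax = batch_size * n_head * sequence_length ** 2
    memory_value = batch_size * n_head * sequence_length * dim
    memory_mlp = batch_size * n_head * sequence_length * dim  # same shape as value
    return memory_softmax + memory_value + memory_mlp


def activations_per_block_factored(batch_size, n_head, sequence_length, dim):
    """Factored form: batch_size * n_head * S * (S + 2 * dim)."""
    return batch_size * n_head * sequence_length * (sequence_length + 2 * dim)


# the two forms agree on a grid of sizes
for b in (1, 2, 8):
    for n in (1, 4, 12):
        for s in (1, 16, 512):
            for d in (8, 64):
                assert activations_per_block(b, n, s, d) == activations_per_block_factored(b, n, s, d)

print(activations_per_block(8, 12, 512, 64))
```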

Summed over all n_tr_blocks blocks, the activation memory is:

```
n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
```
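Because the expression contains `sequence_length * (sequence_length + 2*dim)`, activation memory grows faster than linearly with sequence length. The sketch below doubles the sequence length for a hypothetical configuration (12 blocks, batch 8, 12 heads, head dim 64) and measures the growth factor. This factor approaches 4 as sequence_length becomes large relative to dim.

```python
def activation_elements(n_tr_blocks, batch_size, n_head, sequence_length, dim):
    """Activation element count summed over all transformer blocks."""
    return n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2 * dim))


# hypothetical config: 12 blocks, batch 8, 12 heads, head dim 64
act_512 = activation_elements(12, 8, 12, 512, 64)
act_1024 = activation_elements(12, 8, 12, 1024, 64)

# (1024 * 1152) / (512 * 640): doubling S multiplies activations by 3.6 here
ratio = act_1024 / act_512
print(act_512, act_1024, ratio)
```

With longer sequences the `2*dim` term matters less and the ratio moves closer to 4.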

## Putting It Together

```
total_memory = memory_model + 2 * memory_activations
```

The model memory is:

```
4*n_tr_blocks*square_of(n_head * dim)
```

The activation memory, summed over all blocks, is:

```
n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
```

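In shorthand notation:

```
R = n_tr_blocks     = number of stacked transformer blocks
N = n_head          = number of attention heads
D = dim             = dimension of each attention head
B = batch_size      = batch size
S = sequence_length = length of the input sequence

memory_model       = 4 * R * N^2 * D^2
memory_activations = RBNS(S + 2D)
```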
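The shorthand is a renaming of the long formulas. `square_of(n_head * dim)` becomes `N^2 * D^2`, and the activation product is reordered as `R * B * N * S`. The sketch below checks that the two notations agree. The helper `square_of` mirrors the pseudocode name used above. The other function names and the sample configurations are hypothetical.

```python
def square_of(x):
    return x * x


def long_form(n_tr_blocks, n_head, dim, batch_size, sequence_length):
    memory_model = 4 * n_tr_blocks * square_of(n_head * dim)
    memory_activations = n_tr_blocks * (
        batch_size * n_head * sequence_length * (sequence_length + 2 * dim)
    )
    return memory_model, memory_activations


def shorthand(R, N, D, B, S):
    memory_model = 4 * R * N**2 * D**2
    memory_activations = R * B * N * S * (S + 2 * D)
    return memory_model, memory_activations


# the two notations agree; arguments are passed as (R, N, D, B, S)
for R, N, D, B, S in [(1, 1, 1, 1, 1), (12, 12, 64, 8, 512), (24, 16, 64, 4, 2048)]:
    assert long_form(R, N, D, B, S) == shorthand(R, N, D, B, S)

print(shorthand(12, 12, 64, 8, 512))
```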

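Substituting into total_memory = memory_model + 2 * memory_activations gives the total element count M:

```
M = (4 * R * N^2 * D^2) + 2 * RBNS(S + 2D)
```

When the sequence length is much larger than the head dimension (S ≫ D), S + 2D ≈ S, so:

```
M ≈ (4 * R * N^2 * D^2) + 2 * RBNS(S) = 4*R*N^2*D^2 + 2*RBNS^2
```

M counts tensor elements. To get bytes, multiply M by the size of one element, for example 8 bytes for float64:

```
total_memory = ((4 * R * N^2 * D^2) + 2 * RBNS(S + 2D)) * bytes_per_element   # 8 bytes for float64
```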
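The final step multiplies the element count by the element size. The sketch below does this for a few common dtypes and compares the exact formula with the S ≫ D approximation. It follows total_memory = memory_model + 2 * memory_activations from above. The dtype table lists the standard element sizes. The function name and the configuration (R=12, N=12, D=64, B=8, S=512) are hypothetical.

```python
# standard element sizes in bytes
BYTES_PER_ELEMENT = {"float64": 8, "float32": 4, "float16": 2}


def total_memory_bytes(R, N, D, B, S, dtype="float64", approximate=False):
    """(4*R*N^2*D^2 + 2*R*B*N*S*(S + 2D)) * bytes per element.

    approximate=True replaces S + 2D with S (the S >> D approximation).
    """
    seq_term = S if approximate else S + 2 * D
    elements = 4 * R * N**2 * D**2 + 2 * R * B * N * S * seq_term
    return elements * BYTES_PER_ELEMENT[dtype]


# hypothetical config: R=12, N=12, D=64, B=8, S=512
exact = total_memory_bytes(12, 12, 64, 8, 512)
approx = total_memory_bytes(12, 12, 64, 8, 512, approximate=True)
print(f"exact:  {exact / 2**30:.2f} GiB")
print(f"approx: {approx / 2**30:.2f} GiB")
```

For this configuration the exact formula gives about 5.84 GiB in float64, and the approximation gives about 4.71 GiB. At S = 8D the dropped `2D` term is not yet negligible, so the approximation underestimates noticeably.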

https://avoid.overfit.cn/post/6724eec842b740d482f73386b1b8b012
