# ConvLSTM时空预测实战代码详解

## 二、数据集的选取和下载

``````import numpy as np
from tensorflow import keras
fpath = keras.utils.get_file(
"moving_mnist.npy",
"http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy",
)
print(dataset.shape)
``````

## 三、数据集预处理与数据集划分

``````# 转换数据集的seq和samples维度，便于输入我们的模型
dataset = np.swapaxes(dataset, 0, 1)
# 10000个样本太多，我们只选取1000个
dataset = dataset[:1000, ...]
# 我们此时是二维灰度图片，因此要增加一维，代表单通道，如果是彩色，则为3
dataset = np.expand_dims(dataset, axis=-1)
print(dataset.shape)
``````

``````indexes = np.arange(dataset.shape[0])
np.random.shuffle(indexes)  # 打乱索引顺序
# 训练集：测试集=9：1
train_index = indexes[: int(0.9 * dataset.shape[0])]
val_index = indexes[int(0.9 * dataset.shape[0]):]
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]
print(train_dataset.shape)
print(val_dataset.shape)
``````

``````# 归一化,除255就是把3基色都调到0-1区间，得到绝对色彩信息
train_dataset = train_dataset / 255
val_dataset = val_dataset / 255
``````

``````# 分离x和y,注意，此时的y是下一帧图像，既最后一个片子，我们用前20帧预测后20帧，既序号0-19
def create_shifted_frames(data):
x = data[:, 0: data.shape[1] - 1, :, :]
y = data[:, 1: data.shape[1], :, :]
return x, y
x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)
``````

## 四、模型构建

``````# 模型构建核心代码，这里我们修改超参数与keras官方超参数一致
model = Sequential([
keras.layers.ConvLSTM2D(filters=64, kernel_size=(5, 5),
input_shape=(None, 64, 64, 1),
keras.layers.BatchNormalization(),
keras.layers.ConvLSTM2D(filters=64, kernel_size=(3, 3),
keras.layers.BatchNormalization(),
keras.layers.ConvLSTM2D(filters=64, kernel_size=(1, 1),
keras.layers.Conv3D(filters=1, kernel_size=(3, 3, 3),
activation='sigmoid',
])
model.summary()
``````

## 五、模型训练

``````# 定义回调
early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)

# 设置训练参数
epochs = 50
batch_size = 2

# 拟合模型.
model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_val, y_val),
callbacks=[early_stopping, reduce_lr],
)
model.save('model.h5')
``````

THE END

)">