Tensorflow2数据集过大,GPU内存不够
前言:
在我们平时使用tensorflow训练模型时,有时候可能因为数据集太大(比如VOC数据集等等)导致GPU内存不够导致终止,可以自制一个数据生成器来解决此问题。
代码如下:
def train_generator(train_path,train_labels,batch):
over=len(train_path)%batch
while True:
for i in range(0,len(train_path)-over,batch):
train_data=read_img(train_path[i:i+batch])
train_label=train_labels[i:i+batch]
yield (np.array(train_data), np.array(train_label))
方法就是将数据集图片的路径保存到一个列表之中,然后使用while循环在训练时进行不断读取,这里over的作用是防止图片长度不是batch整数倍,导致label的数据长度不等于batch,我在训练时出现了这样的问题,这是我的猜测。然后yield与return的不同是,return是在函数执行到return就会退出函数,而yield则不会退出函数,所以使用yield。
最后一句话也可以改成:
yield ({'input':np.array(train_data)}, {'output':np.array(train_label)})
'input’是你网络第一层的名字.。
'output’是你网络最后一层的名字。
接下来是使用代码:
history=model.fit(train_generator(train_data,train_label,batch=Yolo_param.Batch_size),
batch_size=Yolo_param.Batch_size,
epochs=10,
steps_per_epoch=1024,
validation_steps=32,
callbacks=[callback],
validation_data=train_generator(test_data,test_label,batch=Yolo_param.Batch_size))
steps_per_epoch这个参数是每个epoch的数据大小,如果不给进度就能难显示。
最后就是显存设置:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
try:
tf.config.set_logical_device_configuration(
gpus[0],
[tf.config.LogicalDeviceConfiguration(memory_limit=4096)])
except RuntimeError as e:
print(e)
4096就是你限制显卡内存的大小,可以根据自己显卡实际情况来进行设置
本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
二维码