Tensorflow2数据集过大,GPU内存不够

前言:
在我们平时使用tensorflow训练模型时,有时候可能因为数据集太大(比如VOC数据集等等)导致GPU内存不够导致终止,可以自制一个数据生成器来解决此问题。

代码如下:

def train_generator(train_path,train_labels,batch):
    over=len(train_path)%batch
    while True:
        for i in range(0,len(train_path)-over,batch):
            train_data=read_img(train_path[i:i+batch])
            train_label=train_labels[i:i+batch]
            yield (np.array(train_data), np.array(train_label))

方法就是将数据集图片的路径保存到一个列表之中,然后使用while循环在训练时进行不断读取,这里over的作用是防止图片长度不是batch整数倍,导致label的数据长度不等于batch,我在训练时出现了这样的问题,这是我的猜测。然后yield与return的不同是,return是在函数执行到return就会退出函数,而yield则不会退出函数,所以使用yield
最后一句话也可以改成:

yield ({'input':np.array(train_data)}, {'output':np.array(train_label)})

'input’是你网络第一层的名字.。
'output’是你网络最后一层的名字。

接下来是使用代码:

history=model.fit(train_generator(train_data,train_label,batch=Yolo_param.Batch_size),
          batch_size=Yolo_param.Batch_size,
          epochs=10,
          steps_per_epoch=1024,
          validation_steps=32,
          callbacks=[callback],
          validation_data=train_generator(test_data,test_label,batch=Yolo_param.Batch_size))

steps_per_epoch这个参数是每个epoch的数据大小,如果不给进度就能难显示。

最后就是显存设置:

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])
  except RuntimeError as e:
    print(e)

4096就是你限制显卡内存的大小,可以根据自己显卡实际情况来进行设置

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇

)">
下一篇>>