基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

一 修改背景

基于yoloV5系列越来越强大,适用面越来越广泛,主要是由于训练简单,模型适配性好,推理速度快等优点,yoloV5系列适用非常广泛。
但随着越发强大的系统,导致模型堆叠问题越发严重,输入相同的图片检测的内容不同,或者输入不同的图片检测类似的内容。这些都需要使用多个模型来完成,导致设备负载大,推理堆叠。实际运用场景可能有:多国车牌,使用不同的国家字符,需要用多个对应国家的模型来完成车牌文字检测识别,又比如:ADAS系统,输入相同的图像,不仅仅要检测前方的车辆类型,交通标志,车道线(YOLOP)等等。诸如需求比比皆是,故此在官方的模型上使其共用backbone,使用不同的检测头来完成相对于效果。

二 修改思路

共用backbone,使用多个检测头来分别检测不同国家的车牌。
比如我们定义第一个头是:大陆车牌,第二个头是:港澳车牌,第三个头是:老挝车牌等等。
重点 : 我们创建了多头,但是每次我们输入的图片只是其中一个头的,如果每个头都运行,会很浪费时间,所以我们只运行对应的一个头,这里就需要后期建立一个多头的列表,选择我们数据输入的对应头就OK了。

# 网络结构如下:
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters

nc1 : 20
nc2 : 30
nc3 : 40

nc: [nc1,nc2,nc3]  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[9, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head1:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc1, anchors]],  # Detect(P3, P4, P5)   24

  ]

head2:
  [[9, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 28

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 32 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 29], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 35 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 25], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 38 (P5/32-large)

   [[32, 35, 38], 1, Detect, [nc2, anchors]],  # Detect(P3, P4, P5)  39
  ]

head3:
  [ [ 9, 1, Conv, [ 512, 1, 1 ] ],
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 6 ], 1, Concat, [ 1 ] ],  # cat backbone P4
    [ -1, 3, C3, [ 512, False ] ],  # 43

    [ -1, 1, Conv, [ 256, 1, 1 ] ],
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 4 ], 1, Concat, [ 1 ] ],  # cat backbone P3
    [ -1, 3, C3, [ 256, False ] ],  # 47 (P3/8-small)

    [ -1, 1, Conv, [ 256, 3, 2 ] ],
    [ [ -1, 44 ], 1, Concat, [ 1 ] ],  # cat head P4
    [ -1, 3, C3, [ 512, False ] ],  # 50 (P4/16-medium)

    [ -1, 1, Conv, [ 512, 3, 2 ] ],
    [ [ -1, 40 ], 1, Concat, [ 1 ] ],  # cat head P5
    [ -1, 3, C3, [ 1024, False ] ],  # 53 (P5/32-large)

    [ [ 47, 50, 53 ], 1, Detect, [ nc3, anchors ] ],  # Detect(P3, P4, P5)
  ]

注意: 每一层的连接方式需要修正,需要看是层的索引值。
在这里插入图片描述

# 网络各层的参数值如下:


                 from  n    params  module                                  arguments                     
  0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]              
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                
  2                -1  1     18816  models.common.C3                        [64, 64, 1]                   
  3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
  4                -1  2    115712  models.common.C3                        [128, 128, 2]                 
  5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              
  6                -1  3    625152  models.common.C3                        [256, 256, 3]                 
  7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]              
  8                -1  1   1182720  models.common.C3                        [512, 512, 1]                 
  9                -1  1    656896  models.common.SPPF                      [512, 512, 5]                 
 10                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]              
 11                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1    361984  models.common.C3                        [512, 256, 1, False]          
 14                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              
 15                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     90880  models.common.C3                        [256, 128, 1, False]          
 18                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1    296448  models.common.C3                        [256, 256, 1, False]          
 21                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]              
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1   1182720  models.common.C3                        [512, 512, 1, False]          
 24      [17, 20, 23]  1     67425  Detect                                  [20, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
 25                 9  1    131584  models.common.Conv                      [512, 256, 1, 1]              
 26                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 27           [-1, 6]  1         0  models.common.Concat                    [1]                           
 28                -1  1    361984  models.common.C3                        [512, 256, 1, False]          
 29                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              
 30                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 31           [-1, 4]  1         0  models.common.Concat                    [1]                           
 32                -1  1     90880  models.common.C3                        [256, 128, 1, False]          
 33                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              
 34          [-1, 29]  1         0  models.common.Concat                    [1]                           
 35                -1  1    296448  models.common.C3                        [256, 256, 1, False]          
 36                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]              
 37          [-1, 25]  1         0  models.common.Concat                    [1]                           
 38                -1  1   1182720  models.common.C3                        [512, 512, 1, False]          
 39      [32, 35, 38]  1     94395  Detect                                  [30, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
 40                 9  1    131584  models.common.Conv                      [512, 256, 1, 1]              
 41                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 42           [-1, 6]  1         0  models.common.Concat                    [1]                           
 43                -1  1    361984  models.common.C3                        [512, 256, 1, False]          
 44                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              
 45                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 46           [-1, 4]  1         0  models.common.Concat                    [1]                           
 47                -1  1     90880  models.common.C3                        [256, 128, 1, False]          
 48                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              
 49          [-1, 44]  1         0  models.common.Concat                    [1]                           
 50                -1  1    296448  models.common.C3                        [256, 256, 1, False]          
 51                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]              
 52          [-1, 40]  1         0  models.common.Concat                    [1]                           
 53                -1  1   1182720  models.common.C3                        [512, 512, 1, False]          
 54      [47, 50, 53]  1    121365  Detect                                  [40, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
self.headi 24
self.headi 39
self.headi 54
Model Summary: 508 layers, 12958705 parameters, 12958705 gradients

三 模型修改

1 网络结构修改

修改结构的时候主要需要注意,这里是多个头,我们创建一个多头列表,输入对应头数据来完成模型训练即可。

  • 修改模型初始化
    在这里插入图片描述
    主要是需要记录头的数量,骨干网络的层数,不同头的层数(列表)
    头的数量可以根据索引来进行输入数据,训练对应的头,推理的时候也是对应输入头的索引即可,
    骨干是共用的,所以记录数量,后期好用于网络结构拼接。
    不同的头可以使用不同的层数,针对难度大的数据可以使用较多的卷积,默认是15层。

  • 初始化detect层
    detect层的m.stride值,默认是[8,16,32]。由于都有不同的头,anchor对应的下采样比例可能出现不一样,可能需要使用不同的anchor来进行初始化,所以这里每个头的m.stride 都需要进行初始化。用一个循环完成。
    在这里插入图片描述

  • 网络拼接
    网络拼接的时候需要主要,共用主干后,对应的值会有一些变话,都可以更具传入的头和对应头的层数进行查询,这里的计算大家可以自己算一下,需要注意的是P4,P5拼接的层数是头数量15的倍数,_forward_once函数中,新加代码乘以15的由来。
    在这里插入图片描述

在这里插入图片描述

还有一些小的修改,大家可以自己查看yolo_plate.py文件。基本都是和输入头索引对应的detect层的位置,也就是前面计算的self.headi_forward

2数据读取修改

  • 修改数据读取配置文件
    添加 headnum 头的数量,用于数据读取的循环值。
    依次写个头的对应的数据路径,类别,以及类别名称即可。

  • 数据读取成dataloader
    这里是多个头的数据,所以创建的时候使用列表来进行存储
    在这里插入图片描述
    修改create_dataloader方法,返回列表值即可
    在这里插入图片描述
    这里需要记录类别数量,名字等等对应即可,修改较为简单,省去,不清楚的可以去查看源码。

  • 数据训练数据读取
    这里我们创建了多个头的数据dataloader,我们训练的时候是同时进行训练的,所以每次从一个dataloader中读取相同张数的数据,进行一个batch训练,然后将loss进行相加然后回传。
    由于数据长短不同,所以我们按照最长的数据进行设置一个epoch的长度,如果短的读取完了,再次创建train_loader来进行重复读取训练。
    在这里插入图片描述
    数据运行逻辑:
    在这里插入图片描述

3 训练工程常见问题修改

  • 根据检测头的数量修改读取数据的路径:
    在这里插入图片描述

  • general.py文件,修改读取数量路径,修改为列表形式。
    在这里插入图片描述

  • 根据数量dataloader 读取对应的bar数据读取器,列表形式
    在这里插入图片描述

四 模型训练

我使用416大小训练了2个头的内容,map涨点很快,训练速度和之前的训练过程相当,稍微慢一丢丢

# 训练指令
python train_plate.py --data data/mydata.yaml --batch 256 --epochs 400 --weights weights/yolov5s.pt   --imgsz 416  --device '0,1'  --cfg models/yolov5s_plate.yaml  --hyp data/hyps/palte_head.yaml --name car_plate_head_size416

在这里插入图片描述
模型收敛的比单个头训练的更快一些。

五 模型开源

目前还有一些内容没有更新完成,完成后上传github

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>