使用 LeNet-5 识别手写数字(含简单 GUI 测试,简单详细版)

写在前面:欢迎来到「湫歌」的博客。我是秋秋,一名普通的在校大学生。在学习之余,用博客来记录我学习过程中的点点滴滴,也希望我的博客能够给同样热爱学习、热爱技术的你们带来收获!希望大家多多关照,我们一起成长一起进步。也希望大家多多支持我鸭,喜欢我就给我一个关注吧!

1、LeNet-5的搭建与训练

from tensorflow.keras.datasets import mnist
from tensorflow.keras import models
from tensorflow.keras import layers
import tensorflow as tf
import cv2
import numpy as np
import pandas as pd

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Add the single grayscale channel that Conv2D expects: (N, 28, 28) -> (N, 28, 28, 1).
# -1 lets NumPy infer the sample count instead of hard-coding 60000 / 10000.
train_images = train_images.reshape(-1, 28, 28, 1)
test_images = test_images.reshape(-1, 28, 28, 1)

# Normalize: uint8 pixel values in [0, 255] -> float32 in [0, 1].
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255

# One-hot encode the digit labels for categorical_crossentropy.
# np.eye always yields exactly NUM_CLASSES float32 columns; pd.get_dummies
# returns bool columns on pandas >= 2.0 and only creates columns for the
# classes actually present in the array.
NUM_CLASSES = 10
train_labels = np.eye(NUM_CLASSES, dtype='float32')[train_labels]
test_labels = np.eye(NUM_CLASSES, dtype='float32')[test_labels]

# LeNet-5 style network (swish activations, average pooling).
model = models.Sequential()
# C1: 6 filters of 5x5; 'same' padding keeps the 28x28 size.
model.add(layers.Conv2D(filters=6, kernel_size=(5, 5), input_shape=(28, 28, 1), padding='same', activation='swish'))
# S2: 28x28 -> 14x14
model.add(layers.AveragePooling2D(pool_size=(2, 2)))
# C3: 16 filters of 5x5, 'valid' padding: 14x14 -> 10x10
model.add(layers.Conv2D(filters=16, kernel_size=(5, 5), activation='swish'))
# S4: 10x10 -> 5x5
model.add(layers.AveragePooling2D(pool_size=(2, 2)))
# C5: 120 filters of 5x5 over the 5x5 map -> 1x1x120
model.add(layers.Conv2D(filters=120, kernel_size=(5, 5), activation='swish'))
model.add(layers.Flatten())
# F6: fully connected, 84 units
model.add(layers.Dense(84, activation='swish'))
# Output: probability for each of the 10 digits
model.add(layers.Dense(NUM_CLASSES, activation='softmax'))
model.summary()

model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
# NOTE: the test split doubles as validation data here, so the validation
# curves are not an independent estimate. Hold out part of the training set
# (e.g. validation_split=0.1) if hyperparameters are tuned on them.
history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
model.save('model_mnist.h5')

网络模型结构(`model.summary()` 的输出):

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 6)         156       
_________________________________________________________________
average_pooling2d (AveragePo (None, 14, 14, 6)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 10, 10, 16)        2416      
_________________________________________________________________
average_pooling2d_1 (Average (None, 5, 5, 16)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 1, 1, 120)         48120     
_________________________________________________________________
flatten (Flatten)            (None, 120)               0         
_________________________________________________________________
dense (Dense)                (None, 84)                10164     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                850       
=================================================================

2、测试:用单张图片检验模型

import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras import layers
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Load the model trained and saved in step 1.
newmodel = models.load_model('model_mnist.h5')

# Read the test image as single-channel grayscale (flag 0).
img = cv2.imread('2.png', 0)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising, which would otherwise surface later as an obscure cv2.resize error.
    raise FileNotFoundError("cannot read image '2.png'")
plt.imshow(img)  # preview; shown inline in a notebook, needs plt.show() in a plain script

# Match the training input: 28x28, one channel, batch of one, scaled to [0, 1].
# NOTE(review): MNIST digits are light strokes on a dark background; if 2.png is
# dark-on-light, invert it first (img = 255 - img) — confirm against the image used.
img = cv2.resize(img, (28, 28))
img = img.reshape(1, 28, 28, 1)
img = img / 255  # normalize
print(img.shape)

predict = newmodel.predict(img)
print(predict)
digit = np.argmax(predict)  # index of the highest class probability
print("预测图像中的数字为:" + str(digit))

3、GUI页面设计

import platform
import sys
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtWidgets import QApplication
from PyQt5.QtWidgets import QWidget
from PyQt5.Qt import QPixmap, QPainter, QPoint, QPaintEvent, QMouseEvent, QPen, QColor, QSize
from PyQt5.QtCore import Qt
from PyQt5.Qt import QWidget, QColor, QPixmap, QIcon, QSize, QCheckBox
from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QPushButton, QSplitter, QComboBox, QLabel, QSpinBox, QFileDialog
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras import layers
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def main():
    """Start the Qt application, show the main window and block until it closes."""
    app = QApplication(sys.argv)

    mainWidget = MainWidget()  # create the main window
    mainWidget.show()  # display it

    # Enter the Qt event loop and use its return code as the process exit status.
    # sys.exit instead of the builtin exit(): the latter is injected by the
    # `site` module and is missing under `python -S` and in frozen executables.
    sys.exit(app.exec_())


class PaintBoard(QWidget):
    """Fixed-size 280x280 drawing canvas: white strokes on a black background.

    Strokes are drawn into an off-screen QPixmap that is copied onto the widget
    on every repaint, so the drawing can also be exported with GetContentAsQImage().
    """

    def __init__(self, Parent=None):
        """Create the canvas; ``Parent`` is the optional parent widget."""
        super().__init__(Parent)

        self.__InitData()  # initialize data first, then the view
        self.__InitView()
        self.setWindowTitle("画笔")

    def __InitData(self):
        """Initialize the backing pixmap, pen settings and mouse positions."""
        self.__size = QSize(280, 280)

        # Off-screen pixmap of size __size that holds the drawing.
        self.__board = QPixmap(self.__size)
        self.__board.fill(Qt.black)  # start with a black board

        self.__IsEmpty = True  # nothing drawn yet

        self.__lastPos = QPoint(0, 0)  # previous mouse position
        self.__currentPos = QPoint(0, 0)  # current mouse position

        # One QPainter reused for both the board and the widget (begin/end per use).
        self.__painter = QPainter()

        self.__thickness = 10  # pen width in px
        self.__penColor = QColor("white")  # white pen

    def __InitView(self):
        """Fix the widget size to the board size."""
        self.setFixedSize(self.__size)

    def Clear(self):
        """Erase the board back to black and mark it empty."""
        self.__board.fill(Qt.black)
        self.update()
        self.__IsEmpty = True

    def ChangePenColor(self, color="white"):
        """Set the pen color; ``color`` is any value QColor accepts, e.g. a name like "red"."""
        self.__penColor = QColor(color)

    def ChangePenThickness(self, thickness=10):
        """Set the pen width in pixels."""
        self.__thickness = thickness

    def IsEmpty(self):
        """Return True if nothing has been drawn since creation or the last Clear()."""
        return self.__IsEmpty

    def GetContentAsQImage(self):
        """Return the current board content as a QImage."""
        return self.__board.toImage()

    def paintEvent(self, paintEvent):
        """Repaint the widget by copying the board pixmap onto it at (0, 0)."""
        self.__painter.begin(self)
        self.__painter.drawPixmap(0, 0, self.__board)
        self.__painter.end()

    def mousePressEvent(self, mouseEvent):
        """Start a stroke at the cursor; a zero-length segment leaves a dot for a plain click."""
        self.__currentPos = mouseEvent.pos()
        self.__lastPos = self.__currentPos
        self.__DrawSegment(self.__lastPos, self.__currentPos)

    def mouseMoveEvent(self, mouseEvent):
        """Extend the stroke from the previous position to the current one."""
        self.__currentPos = mouseEvent.pos()
        self.__DrawSegment(self.__lastPos, self.__currentPos)
        self.__lastPos = self.__currentPos

    def mouseReleaseEvent(self, mouseEvent):
        """Mark the board as no longer empty once a stroke ends."""
        self.__IsEmpty = False

    def __DrawSegment(self, start, end):
        """Draw one line segment onto the board with the current pen and schedule a repaint."""
        pen = QPen(self.__penColor, self.__thickness)
        # Round caps/joins so consecutive short segments form a smooth stroke
        # (the default square cap leaves jagged corners on fast curves).
        pen.setCapStyle(Qt.RoundCap)
        pen.setJoinStyle(Qt.RoundJoin)
        self.__painter.begin(self.__board)
        self.__painter.setPen(pen)
        self.__painter.drawLine(start, end)
        self.__painter.end()

        self.update()  # refresh the widget


class MainWidget(QWidget):
    """Main window: the drawing board on the left, buttons and a result log on the right."""

    def __init__(self, Parent=None):
        """Build the window; ``Parent`` is the optional parent widget."""
        super().__init__(Parent)

        self.__InitData()  # initialize data first, then the view
        self.__InitView()

    def __InitData(self):
        """Initialize member variables."""
        self.__paintBoard = PaintBoard(self)
        # Keras model, loaded on the first prediction and then reused so that
        # model_mnist.h5 is not re-read from disk on every click.
        self.__model = None

    def __InitView(self):
        """Create the layouts, buttons and result browser."""
        self.setFixedSize(550, 300)
        self.setWindowTitle("Predictive handwritten digits")

        # Horizontal main layout for this window.
        main_layout = QHBoxLayout(self)
        # 10 px spacing between the items of the main layout.
        main_layout.setSpacing(10)

        # Drawing board on the left.
        main_layout.addWidget(self.__paintBoard)

        # Vertical sub-layout on the right for the buttons and the result log.
        sub_layout = QVBoxLayout()

        # 10 px margins around the contents of the sub-layout.
        sub_layout.setContentsMargins(10, 10, 10, 10)

        self.__btn_Clear = QPushButton("清空画板")
        self.__btn_Clear.setParent(self)
        # Clicking the button clears the board.
        self.__btn_Clear.clicked.connect(self.__paintBoard.Clear)
        sub_layout.addWidget(self.__btn_Clear)

        self.__btn_Save = QPushButton("保存作品")
        self.__btn_Save.setParent(self)
        self.__btn_Save.clicked.connect(self.on_btn_Save_Clicked)
        sub_layout.addWidget(self.__btn_Save)

        self.__btn_Predict = QPushButton("预测")
        self.__btn_Predict.setParent(self)
        self.__btn_Predict.clicked.connect(self.Predict)
        sub_layout.addWidget(self.__btn_Predict)

        self.__btn_Quit = QPushButton("退出")
        self.__btn_Quit.setParent(self)
        self.__btn_Quit.clicked.connect(self.Quit)
        sub_layout.addWidget(self.__btn_Quit)

        # Text area where prediction results are appended.
        self.__text_browser = QTextBrowser(self)
        self.__text_browser.setParent(self)
        sub_layout.addWidget(self.__text_browser)

        splitter = QSplitter(self)  # placeholder at the bottom of the column
        sub_layout.addWidget(splitter)

        main_layout.addLayout(sub_layout)  # add the sub-layout to the main layout

    # NOTE(review): __fillColorList, on_PenColorChange and on_PenThicknessChange are
    # not connected to any signal in __InitView, and they read __colorList,
    # __comboBox_penColor and __spinBox_penThickness, which this class never
    # creates. They are kept for compatibility; calling them as-is raises AttributeError.
    def __fillColorList(self, comboBox):
        """Fill ``comboBox`` with a color swatch per entry of __colorList and select "black"."""
        index_black = 0
        index = 0
        for color in self.__colorList:
            if color == "black":
                index_black = index
            index += 1
            pix = QPixmap(70, 20)
            pix.fill(QColor(color))
            comboBox.addItem(QIcon(pix), None)
            comboBox.setIconSize(QSize(70, 20))
            comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)

        comboBox.setCurrentIndex(index_black)

    def on_PenColorChange(self):
        """Forward the color selected in the pen-color combo box to the board."""
        color_index = self.__comboBox_penColor.currentIndex()
        color_str = self.__colorList[color_index]
        self.__paintBoard.ChangePenColor(color_str)

    def on_PenThicknessChange(self):
        """Forward the value of the pen-thickness spin box to the board."""
        penThickness = self.__spinBox_penThickness.value()
        self.__paintBoard.ChangePenThickness(penThickness)

    def on_btn_Save_Clicked(self):
        """Save the current board content to 1.png in the working directory."""
        image = self.__paintBoard.GetContentAsQImage()
        image.save('1.png')

    def Predict(self):
        """Classify the digit currently on the board and append the result to the log.

        The board is saved to 1.png first (the same file the save button writes),
        so the prediction always reflects what is on screen rather than a stale
        or missing earlier save.
        """
        if self.__paintBoard.IsEmpty():
            self.__AppendResult("画板为空,请先书写数字")
            return

        self.on_btn_Save_Clicked()
        img = cv2.imread('1.png', 0)  # flag 0: read as single-channel grayscale
        if img is None:  # cv2.imread returns None instead of raising
            self.__AppendResult("无法读取图片 1.png")
            return

        if self.__model is None:
            try:
                self.__model = models.load_model('model_mnist.h5')
            except OSError as err:
                # An exception escaping a Qt slot aborts PyQt5 (>= 5.5) apps; report instead.
                self.__AppendResult("无法加载模型 model_mnist.h5: " + str(err))
                return

        img = cv2.resize(img, (28, 28))
        # Binarize: pixels > 150 become 255, all others 0 (vectorized per-pixel threshold).
        img = np.where(img > 150, 255, 0).astype(np.uint8)

        img = img.reshape(1, 28, 28, 1)  # batch of one, single channel
        img = img / 255  # normalize to [0, 1] like the training data

        predict = self.__model.predict(img)
        self.__AppendResult("预测图像中的数字为:" + str(np.argmax(predict)))

    def __AppendResult(self, text):
        """Append ``text`` to the result browser and scroll to the end."""
        self.__text_browser.append(text)
        self.__text_browser.moveCursor(QTextCursor.End)

    def Quit(self):
        """Close the main window."""
        self.close()


if __name__ == "__main__":
    # Launch the GUI only when this file is run directly, not when it is imported.
    main()

本次 LeNet-5 识别手写数字的学习就结束啦!喜欢的小伙伴可以关注秋秋,一起探讨、一起学习!

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码

)">
< <上一篇
下一篇>>