多分类-手写识别体

1.分析数据集

数据集:

链接:https://pan.baidu.com/s/1YY9HuDqCSr3-CHWON3NdKg
提取码:15eq

mnist_train.csv 数据集一共 (60000, 785) 行列 数据。 已知 28 * 28 = 784

  • 第一列的值为标签值。范围(0, 9), 我们希望神经网络能够预测得到正确的标签值。
  • 剩下的 784 = 28*28 列数据 是手写识别体的数字的像素值。

因此 我们可以把第一列作为标签值,剩下的 28*28 列 作为 变量。

import pandas as pd
import numpy as np

path = r'datamnist_train.csv'
df = pd.read_csv(path, header=None)
df.head()
0 1 2 3 4 5 6 7 8 9 ... 775 776 777 778 779 780 781 782 783 784
0 5 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 4 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
3 1 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
4 9 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

5 rows × 785 columns

df.shape
(60000, 785)
df.describe()
0 1 2 3 4 5 6 7 8 9 ... 775 776 777 778 779 780 781 782 783 784
count 60000.000000 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 ... 60000.000000 60000.000000 60000.000000 60000.000000 60000.000000 60000.0000 60000.0 60000.0 60000.0 60000.0
mean 4.453933 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.200433 0.088867 0.045633 0.019283 0.015117 0.0020 0.0 0.0 0.0 0.0
std 2.889270 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 6.042472 3.956189 2.839845 1.686770 1.678283 0.3466 0.0 0.0 0.0 0.0
min 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.0000 0.0 0.0 0.0 0.0
25% 2.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.0000 0.0 0.0 0.0 0.0
50% 4.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.0000 0.0 0.0 0.0 0.0
75% 7.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.0000 0.0 0.0 0.0 0.0
max 9.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 254.000000 254.000000 253.000000 253.000000 254.000000 62.0000 0.0 0.0 0.0 0.0

8 rows × 785 columns

读取数据的另外一种方法:

  • open(路径, 读取方式)

注:这种方法不常用,直接使用 pd.read_csv() 即可,非常方便。

data_file = open(r"datamnist_train.csv", 'r') # _io.TextIOWrapper
data_list = data_file.readlines()              # .csv 文件的数据 存放到 list 对象里
print(len(data_list))
data_file.close()
# data_list  # 每个 list 对应 csv 文件的每一行, 类型是str 
60000
data_list[0]     # csv 文件 中 的第一行, 列表返回的是 一串 str 。
'5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,18,18,18,126,136,175,26,166,255,247,127,0,0,0,0,0,0,0,0,0,0,0,0,30,36,94,154,170,253,253,253,253,253,225,172,253,242,195,64,0,0,0,0,0,0,0,0,0,0,0,49,238,253,253,253,253,253,253,253,253,251,93,82,82,56,39,0,0,0,0,0,0,0,0,0,0,0,0,18,219,253,253,253,253,253,198,182,247,241,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,80,156,107,253,253,205,11,0,43,154,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,1,154,253,90,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,139,253,190,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,11,190,253,70,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,35,241,225,160,108,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,240,253,253,119,25,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,45,186,253,253,150,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,93,252,253,187,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,249,253,249,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,46,130,183,253,253,207,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,39,148,229,253,253,253,250,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,114,221,253,253,253,253,201,78,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,23,66,213,253,253,253,253,198,81,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,18,171,219,253,253,253,253,195,80,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,55,172,226,253,253,253,253,244,133,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,253,253,253,212,135,132,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0n'

提取训练集的前100条数据

测试集的前10条数据

df.iloc[:100].to_csv(r"datamnist_train_100.csv", index=False, header=False)
test = pd.read_csv(r'datamnist_test.csv', header=None)
test.iloc[:10].to_csv(r'datamnist_test_10.csv', index=False, header=False)

观察数据

train_path = r"datamnist_train_100.csv"
test_path = r"datamnist_test_10.csv"
data_file = open(train_path, 'r')
data_list = data_file.readlines()
data_file.close()
len(data_list)
100
data_list[0]
'5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,18,18,18,126,136,175,26,166,255,247,127,0,0,0,0,0,0,0,0,0,0,0,0,30,36,94,154,170,253,253,253,253,253,225,172,253,242,195,64,0,0,0,0,0,0,0,0,0,0,0,49,238,253,253,253,253,253,253,253,253,251,93,82,82,56,39,0,0,0,0,0,0,0,0,0,0,0,0,18,219,253,253,253,253,253,198,182,247,241,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,80,156,107,253,253,205,11,0,43,154,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,1,154,253,90,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,139,253,190,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,11,190,253,70,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,35,241,225,160,108,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,240,253,253,119,25,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,45,186,253,253,150,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,93,252,253,187,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,249,253,249,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,46,130,183,253,253,207,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,39,148,229,253,253,253,250,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,114,221,253,253,253,253,201,78,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,23,66,213,253,253,253,253,198,81,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,18,171,219,253,253,253,253,195,80,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,55,172,226,253,253,253,253,244,133,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,253,253,253,212,135,132,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0n'

处理数据

  • 第一列的标签提取出来
  • 其它像素值,以 逗号 分割开
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# data_list[0] 是一串字符
all_values = data_list[0].split(',')         # 字符串 以 逗号 分割,生成 列表
len(all_values)
# all_values                                   # 一共785, 第一个为标签,其余为 28*28的像素值
image_array = np.asfarray(all_values[1:]).reshape(28, 28)     # asfarray():把 列表 转化成 浮点型 numpy 数组
# print(image_array)
plt.imshow(image_array, cmap='Greys', interpolation=None)
<matplotlib.image.AxesImage at 0x12dcd8c0ca0>

png

inputs = (np.asfarray(all_values[1:]) / 255 * 0.99) + 0.01
inputs.shape
(784,)
# 转换 inputs 为 二维数组
inputs = np.array(inputs, ndmin=2).T
print(inputs)
inputs.shape  # (784, 1)
[[0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.02164706]
 [0.07988235]
 [0.07988235]
 [0.07988235]
 [0.49917647]
 [0.538     ]
 [0.68941176]
 [0.11094118]
 [0.65447059]
 [1.        ]
 [0.96894118]
 [0.50305882]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.12647059]
 [0.14976471]
 [0.37494118]
 [0.60788235]
 [0.67      ]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.88352941]
 [0.67776471]
 [0.99223529]
 [0.94952941]
 [0.76705882]
 [0.25847059]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.20023529]
 [0.934     ]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.98447059]
 [0.37105882]
 [0.32835294]
 [0.32835294]
 [0.22741176]
 [0.16141176]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.07988235]
 [0.86023529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.77870588]
 [0.71658824]
 [0.96894118]
 [0.94564706]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.32058824]
 [0.61564706]
 [0.42541176]
 [0.99223529]
 [0.99223529]
 [0.80588235]
 [0.05270588]
 [0.01      ]
 [0.17694118]
 [0.60788235]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.06435294]
 [0.01388235]
 [0.60788235]
 [0.99223529]
 [0.35941176]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.54964706]
 [0.99223529]
 [0.74764706]
 [0.01776471]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.05270588]
 [0.74764706]
 [0.99223529]
 [0.28176471]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.14588235]
 [0.94564706]
 [0.88352941]
 [0.63117647]
 [0.42929412]
 [0.01388235]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.32447059]
 [0.94176471]
 [0.99223529]
 [0.99223529]
 [0.472     ]
 [0.10705882]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.18470588]
 [0.73211765]
 [0.99223529]
 [0.99223529]
 [0.59235294]
 [0.11482353]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.07211765]
 [0.37105882]
 [0.98835294]
 [0.99223529]
 [0.736     ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.97670588]
 [0.99223529]
 [0.97670588]
 [0.25847059]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.18858824]
 [0.51470588]
 [0.72047059]
 [0.99223529]
 [0.99223529]
 [0.81364706]
 [0.01776471]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.16141176]
 [0.58458824]
 [0.89905882]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.98058824]
 [0.71658824]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.10317647]
 [0.45258824]
 [0.868     ]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.79035294]
 [0.31282353]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.09929412]
 [0.26623529]
 [0.83694118]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.77870588]
 [0.32447059]
 [0.01776471]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.07988235]
 [0.67388235]
 [0.86023529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.76705882]
 [0.32058824]
 [0.04494118]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.22352941]
 [0.67776471]
 [0.88741176]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.95729412]
 [0.52635294]
 [0.05270588]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.538     ]
 [0.99223529]
 [0.99223529]
 [0.99223529]
 [0.83305882]
 [0.53411765]
 [0.52247059]
 [0.07211765]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]
 [0.01      ]]





(784, 1)
output_nodes = 10
targets = np.zeros(output_nodes) + 0.01
targets[int(all_values[0])] = 0.99       # 标签值刚好为 输出值 targets 的下标
print(targets.shape, targets)

targets = np.array(targets, ndmin=2).T
print(targets.shape, 'n', targets)
(10,) [0.01 0.01 0.01 0.01 0.01 0.99 0.01 0.01 0.01 0.01]
(10, 1) 
 [[0.01]
 [0.01]
 [0.01]
 [0.01]
 [0.01]
 [0.99]
 [0.01]
 [0.01]
 [0.01]
 [0.01]]

下一条数据

# data_list[0] 是一串字符
all_values = data_list[1].split(',')         # 字符串 以 逗号 分割,生成 列表
len(all_values)
# all_values                                   # 一共785, 第一个为标签,其余为 28*28的像素值
image_array = np.asfarray(all_values[1:]).reshape(28, 28)     # asfarray():把 列表 转化成 浮点型 numpy 数组
# print(image_array)
plt.imshow(image_array, cmap='Greys', interpolation=None)
<matplotlib.image.AxesImage at 0x12dce0cccd0>


png

以上功能可以使用 numpy 简化

# 训练集路径
train_path = r"datamnist_train_100.csv"
# 使用 read_csv() 读取文件
data = pd.read_csv(train_path, header=None)

# 标签
label = data[0].to_numpy(dtype=np.float64)

# 28*28 = 784 像素值 --> 二维 numpy 对象 
all_variabels = data.iloc[:, 1:]
all_variabels = all_variabels.to_numpy(dtype=np.float64)

plt.imshow(image_array, cmap='Greys', interpolation=None)
<matplotlib.image.AxesImage at 0x12dce122310>


png

把 训练数据 映射到指定的 区间

观察数据:我们知道 像素值 在 0~255 之间,在使用 神经网络 训练之前,我们把 该值 缩放到 0.01~1之间

注意:最小值为0.01,而不是为0,防止 像素值为0 后期 权重 更新失败。

# 缩放:这一步 相当于 归一化 或者 标准化
# np.asfarray(all_values[1:] 与 all_variabels 等价
scaled_input = (np.asfarray(all_values[1:])) / 255.5 * 0.99 + 0.01     # (0 ~ 255) --> (0.01 ~ 1) 
print(all_variabels)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]

输出层

我们如何设置 输出层哪?

首先,输出结果是一个数字,数字的范围是 0~9,因此,该问题归纳为 多分类问题,输出神经元的个数设置为10.

# 输出层 神经元个数 为 10个
onodes = 10
# 生成 (10,)的 一维numpy数组
targets = np.zeros(onodes)+0.01
targets[int(all_values[0])] = 0.99
print(targets)
type(targets), targets.shape
[0.99 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01]





(numpy.ndarray, (10,))
all_values[0]
'0'

正态分布随机生成矩阵

hnodes, inodes = 3, 3
np.random.normal(0.0, pow(hnodes, -0.5), (hnodes, inodes))
array([[ 0.5685311 , -0.17127778,  0.67140503],
       [ 0.61448826,  0.29478324,  0.38356441],
       [-0.26157523, -0.43210937, -0.76723949]])
pow(16, 2)
256
mu, sigma = 0, 0.1 # mean and standard deviation
s = np.random.normal(mu, sigma, 3)  # 根据 平均值 和 标准差 生成 正态分布的数据
s
array([-0.1519728 ,  0.1522495 , -0.08011677])

激活函数

import scipy.special as ss

activation_function = lambda x: ss.expit(x)      # expit(x) = 1/(1+exp(-x))

x = np.arange(-10, 10, 0.1)
y = activation_function(x)

plt.plot(x, y)
plt.show()         # 值域:[0, 1], x(0) = 0.5
activation_function(0)


png

0.5

2.框架代码

  • 初始化函数 :设定 输入层、隐藏层、输出层
  • 训练 : 学习给定训练集样本后,优化权重
  • 查询 : 给定输入,从输出节点给出答案
# 整体框架如下:
class NeuralNetwork:
#     初始化神经网络
    def __init__():
        pass
    
#     训练
    def train():
        pass
    
#     查询
    def query():
        pass

初始化方法如下:

    # 初始化
    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        # set number of nodes in each input, hidden, output layer
        """:arg
            inodes : 输入层 神经元个数
            hnodes : 隐藏层 神经元个数
            onodes : 输出层 神经元个数
            lr     : 学习率
        """
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # learning rate
        self.lr = learningrate
        pass

权重

权重,刚开始进行 随机生成,然后根据每次训练的结果,我们计算 损失值,进而更新权重,以便下次训练时,损失值更小。

import numpy as np

a = np.zeros([3, 2])
a[0, 0] = 1
a[0, 1] = 2 
a[1, 0] = 9
a[2, 1] = 12
a.shape
(3, 2)
%matplotlib inline
import matplotlib.pyplot as plt

# a = np.random.rand(4*4).reshape(4, 4)
plt.imshow(a, interpolation="nearest")   # 把二维数组(m, n) 转换成 (m, n) 个正方形表示
<matplotlib.image.AxesImage at 0x12dce5c76a0>


png

我们设计的网络结构:3层神经元。

包含 : 输入层、隐藏层、输出层
注:hidden_nodes :表示隐藏层神经元个数, input_nodes :表示输入层 神经元个数, output_nodes : 输出层神经元个数

  • 设置输入层和隐藏层之间的连接权重矩阵为 $ W_{input_hidden}$, 大小:(hidden_nodes, input_nodes)
  • 设隐藏层和输出层之间的连接权重矩阵 为

    W

    h

    i

    d

    d

    e

    n

    o

    u

    t

    p

    u

    t

    W_{hidden_output}

    Whiddenoutput, 大小:{output_nodes, hidden_nodes}

初始化权重值

我们设置初始的权重的要求:

  • 权重值较小
  • 随机

方法一

# 生成 rows 行,columns列 的范围在(0, 1) 的随机值
# np.random.rand(rows, columns)
np.random.rand(3, 3) # 每个值都在(0, 1)之间
array([[0.56577557, 0.4548674 , 0.37397099],
       [0.7145576 , 0.37282089, 0.39927602],
       [0.25294715, 0.93189228, 0.19266301]])

一般权重的范围在(-1.0, 1.0) 之间,我了简单起见,我们设置 权重范围(-0.5, 0.5)

np.random.rand(3, 3) - 0.5
array([[ 0.18886128,  0.14154925,  0.25791405],
       [-0.49226421, -0.35166701, -0.20517272],
       [-0.32607314,  0.41795557,  0.01489836]])
inodes, hnodes = 3, 3
wih = np.random.rand(hnodes, inodes) - 0.5
wih
array([[ 0.30452585,  0.49108575,  0.13111538],
       [ 0.46734034,  0.1032404 ,  0.02777094],
       [ 0.04864065, -0.07211676, -0.48808873]])
onodes = 2
who = np.random.rand(onodes, hnodes) - 0.5
who
array([[0.18630746, 0.17767593, 0.47013643],
       [0.4316961 , 0.22181227, 0.44532566]])

这里我们初始化权重时,使用正态分布函数,在 神经网络 类 实例化时进行初始化:

注意:

  • 权重的维度
  • 权重的值

方法二

mu, sigma = 0, 0.1 # mean and standard deviation 分别表示平均值和标准差
s = np.random.normal(mu, sigma, size=(2, 2))
s
# abs(sigma - np.std(s, ddof=1))
# abs(mu - np.mean(s))
array([[ 0.07580048, -0.16289481],
       [-0.03706102,  0.01746931]])

正态分布函数的 平均值mu 设为0, 标准差sigma 设置为 传入链接输入的开方,即 $ frac{1}{sqrt(传入链接数目)} $

# 初始化输入层和隐藏层之间的权重
inodes, hnodes = 3, 3
onodes = 3
wih = np.random.normal(0.0, pow(hnodes, -0.5), (hnodes, inodes))
wih
who = np.random.normal(0.0, pow(onodes, -0.5), (onodes, hnodes))
who
array([[ 0.90693664,  0.45453867, -0.42490091],
       [-0.15520574, -0.11698295,  0.23764649],
       [-0.21942348,  0.1452418 ,  0.06733478]])

train() 训练函数的编写

a = np.arange(10).reshape(2, 5)
print(a)
np.transpose(a)                # 转秩
[[0 1 2 3 4]
 [5 6 7 8 9]]





array([[0, 5],
       [1, 6],
       [2, 7],
       [3, 8],
       [4, 9]])
# 提取出一条样本数据 : all_values[0] 是标签; all_values[1:] 是训练数据的像素数据 
print(all_values)
['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '51', '159', '253', '159', '50', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '48', '238', '252', '252', '252', '237', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '54', '227', '253', '252', '239', '233', '252', '57', '6', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '10', '60', '224', '252', '253', '252', '202', '84', '252', '253', '122', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '163', '252', '252', '252', '253', '252', '252', '96', '189', '253', '167', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '51', '238', '253', '253', '190', '114', '253', '228', '47', '79', '255', '168', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '48', '238', '252', '252', '179', '12', '75', '121', '21', '0', '0', '253', '243', '50', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '38', '165', '253', '233', '208', '84', '0', '0', '0', '0', '0', '0', '253', '252', '165', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '7', '178', '252', '240', '71', '19', '28', '0', '0', '0', '0', '0', '0', '253', '252', '195', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '57', '252', '252', '63', '0', '0', '0', '0', '0', '0', '0', '0', '0', '253', '252', '195', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '198', '253', '190', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '255', '253', '196', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '76', '246', '252', '112', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '253', '252', '148', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '230', '25', '0', '0', '0', '0', '0', '0', '0', '0', '7', '135', '253', '186', '12', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '223', '0', '0', '0', '0', '0', '0', '0', '0', '7', '131', '252', '225', '71', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '145', '0', '0', '0', '0', '0', '0', '0', '48', '165', '252', '173', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '86', '253', '225', '0', '0', '0', '0', '0', '0', '114', '238', '253', '162', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '249', '146', '48', '29', '85', '178', '225', '253', '223', '167', '56', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '252', '252', '229', '215', '252', '252', '252', '196', '130', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '28', '199', '252', '252', '253', '252', '252', '233', '145', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '25', '128', '252', '253', '252', '141', '37', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0n']
# 输入数据 0-255 --> 0-0.99 --> 0.01~1
input_list = np.asfarray(all_values[1:])/255*0.99 + 0.01
input_list.shape    # (784,)
(784,)
input_list
array([0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.208     , 0.62729412, 0.99223529,
       0.62729412, 0.20411765, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.19635294,
       0.934     , 0.98835294, 0.98835294, 0.98835294, 0.93011765,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.21964706, 0.89129412, 0.99223529, 0.98835294,
       0.93788235, 0.91458824, 0.98835294, 0.23129412, 0.03329412,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.04882353, 0.24294118, 0.87964706,
       0.98835294, 0.99223529, 0.98835294, 0.79423529, 0.33611765,
       0.98835294, 0.99223529, 0.48364706, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.64282353, 0.98835294, 0.98835294, 0.98835294, 0.99223529,
       0.98835294, 0.98835294, 0.38270588, 0.74376471, 0.99223529,
       0.65835294, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.208     , 0.934     , 0.99223529,
       0.99223529, 0.74764706, 0.45258824, 0.99223529, 0.89517647,
       0.19247059, 0.31670588, 1.        , 0.66223529, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.19635294,
       0.934     , 0.98835294, 0.98835294, 0.70494118, 0.05658824,
       0.30117647, 0.47976471, 0.09152941, 0.01      , 0.01      ,
       0.99223529, 0.95341176, 0.20411765, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.15752941, 0.65058824, 0.99223529, 0.91458824,
       0.81752941, 0.33611765, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.99223529, 0.98835294,
       0.65058824, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.03717647, 0.70105882,
       0.98835294, 0.94176471, 0.28564706, 0.08376471, 0.11870588,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.99223529, 0.98835294, 0.76705882, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.23129412, 0.98835294, 0.98835294, 0.25458824,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.99223529,
       0.98835294, 0.76705882, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.77870588,
       0.99223529, 0.74764706, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 1.        , 0.99223529, 0.77094118,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.30505882, 0.96505882, 0.98835294, 0.44482353,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.99223529, 0.98835294, 0.58458824, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.34      ,
       0.98835294, 0.90294118, 0.10705882, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.03717647, 0.53411765, 0.99223529, 0.73211765,
       0.05658824, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.34      , 0.98835294, 0.87576471,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.03717647, 0.51858824,
       0.98835294, 0.88352941, 0.28564706, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.34      , 0.98835294, 0.57294118, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.19635294, 0.65058824, 0.98835294, 0.68164706, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.34388235, 0.99223529,
       0.88352941, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.45258824, 0.934     , 0.99223529,
       0.63894118, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.34      , 0.98835294, 0.97670588, 0.57682353,
       0.19635294, 0.12258824, 0.34      , 0.70105882, 0.88352941,
       0.99223529, 0.87576471, 0.65835294, 0.22741176, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.34      ,
       0.98835294, 0.98835294, 0.98835294, 0.89905882, 0.84470588,
       0.98835294, 0.98835294, 0.98835294, 0.77094118, 0.51470588,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.11870588, 0.78258824, 0.98835294,
       0.98835294, 0.99223529, 0.98835294, 0.98835294, 0.91458824,
       0.57294118, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.10705882, 0.50694118, 0.98835294, 0.99223529,
       0.98835294, 0.55741176, 0.15364706, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      ])

定义查询函数

接受神经网络的输入,返回网络的输出。

  1. 输入层和隐藏层的关系
  • 链接权重和输入层相乘得到隐藏层的输入信号X

公式:

X

h

i

d

d

e

n

=

W

i

n

p

u

t

_

h

i

d

d

e

n

I

X_{hidden} = W_{input_hidden} cdot I

Xhidden=Winput_hiddenI

代码:

hidden_inputs = np.dot(self.wih, inputs)
  • 输入信号X通过激活函数得到隐藏层的输出O

公式:

O

h

i

d

d

e

n

=

s

i

g

m

o

i

d

(

X

h

i

d

d

e

n

)

O_{hidden} = sigmoid(X_{hidden})

Ohidden=sigmoid(Xhidden)

代码:

hidden_outputs = self.activation_function(hidden_inputs)
  1. 隐藏层和输出层之间权重和输入的处理
  • 链接权重 X 隐藏层的输出值得到 输出层的输入信号 X

X

o

u

t

p

u

t

=

W

h

i

d

d

e

n

_

o

u

t

p

u

t

O

h

i

d

d

e

n

X_{output} = W_{hidden_output} cdot O_{hidden}

Xoutput=Whidden_outputOhidden

  • 输出层的结果通过激活函数得到 输出层的 输出

O

o

u

t

p

u

t

=

s

i

g

m

o

i

d

(

X

o

u

t

p

u

t

)

O_{output} = sigmoid(X_{out_put})

Ooutput=sigmoid(Xoutput)

final_inputs = np.dot(self.who, hidden_outputs)
final_outputs = self.activation_function(final_inputs)
import numpy as np

inputs = np.array(input_list, ndmin=2).T 
inputs.shape
(784, 1)
np.zeros(10) + 0.01
array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])

定义 train 函数

训练函数的功能:

  • 前向传播 : 使用 权重 X 输入神经元的值
  • 反向传播 : 根据 前向传播的 预测值,计算 误差,并使用梯度下降法反向更新权重

query(input_list) 函数已经把前向传播实现了,这里我们主要关注反向传播(backpropagation) 的实现。

反向传播稍微复杂,推导需要用到求导、链式法则等

def train(inputs_list, targets_list)
    pass

我们需要传入:

  1. 要训练的样本 inputs_list
  2. 标签值 targets_list

标签值用来求误差,进而反向传播,更新权重,再前向传播得到优化后的值。

  • 求误差 : 实际值 - 预测值
output_errors = targets - final_outputs

难点

  • 隐藏层节点反向传播的误差:

e

r

r

o

r

s

h

i

d

d

e

n

=

w

e

i

g

h

t

s

h

i

d

d

e

n

_

o

u

t

p

u

t

T

e

r

r

o

r

s

o

u

t

p

u

t

errors_{hidden} = weights^{T}_{hidden_output} cdot errors_{output}

errorshidden=weightshidden_outputTerrorsoutput

hidden_errors = np.dot(self.who.T, output_errors)

因此, 对于 隐藏层输出层 之间的权重,我们使用 output_errors 进行优化,

对于 输入层隐藏层 之间的权重,我们使用 计算得到的 hidden_errors 进行优化。

我们根据 梯度下降算法 得到 更新节点 j 与 下一个 节点 k 之间 链接权重 的矩阵形式的表达式如下:

Δ

W

j

,

k

=

α

E

k

s

i

g

m

o

i

d

(

j

w

j

,

k

O

j

)

(

1

s

i

g

m

o

i

d

(

j

w

j

,

k

O

j

)

)

O

j

T

Delta W_{j, k} = alpha * E_{k} sigmoid(sum limits _{j}w_{j,k} cdot O_j) * (1 - sigmoid(sum limits _{j}w_{j,k} cdot O_j) ) cdot O_j^T

ΔWj,k=αEksigmoid(jwj,kOj)(1sigmoid(jwj,kOj))OjT

注:

α

alpha

α 是学习率, sigmoid 是 激活函数 ,注意 * 表示正常的乘法,

cdot

表示的是 矩阵点积。

Python 代码实现:

# 更新 隐藏层-输出层 权重
self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)), np.transpose(hidden_outputs)) 
# 更新 输入层-隐藏层权重
self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))

梯度下降公式:

n

e

w

W

j

,

k

=

o

l

d

W

j

,

k

α

E

w

j

,

k

new W_{j, k} = old W_{j, k} - alpha cdot frac{partial E}{partial w_{j,k}}

newWj,k=oldWj,kαwj,kE

前面的公式实际上求的就是 偏导数的值。

a = np.arange(4).reshape(-1, 2)
b = a
a, b
(array([[0, 1],
        [2, 3]]),
 array([[0, 1],
        [2, 3]]))
np.dot(a, b)
array([[ 2,  3],
       [ 6, 11]])
a = np.arange(10).reshape(-1, 2)

print(a)
# 求 ndarray 中的最大值(只有参数数组a)

max_value = np.argmax(a) 
max_value
[[0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]]





9

需要全部代码可以关注微信公众号哈。(学长杨小杨)。
在这里插入图片描述
全部代码原文

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>