PyTorch手写实现卷积神经网络

卷积神经网络

深度学习处理的数据一般是[0,1],而图片数据的取值范围是[0,255],所以一般导入图片数据的时候做一次预处理/255。

彩色图片有3个通道,每个通道上都是以[0,255]的数值进行图片对应通道上的存储。

全连接神经网络参数众多,难以训练,而且容易过拟合,因因此引入了卷积神经网络。

特征图上的每个数值都是通过上一层的输入通过卷积核计算得到的。

将卷积层进行层层堆叠,就形成了深度卷积网络。在目前的硬件条件下,可以实现几百上千层的卷积层结构。卷积层越多,对复杂结构的特征识别能力越强。

激活函数一般是ReLU(把小于零的值剔除)。

池化层可以放大主要特征,忽略掉几个像素的偏差,抓住主要矛盾,忽略次要矛盾。可以减少数据维度,减少训练参数,避免过拟合。

批归一化的好处:①收敛速度更快(避免的了梯度离散现象,能更好的的更新参数)②更容易得到更好的解 ③训练更稳定

Lenet5_main.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# 主程序 1, 3, 4
# 1 导入数据 2 创建网络 3 训练网络 4 测试验证
import torch
# 导入数据集
from torchvision import datasets
# 对数据做变形,变换
from torchvision import transforms
# 加载数据
from torch.utils.data import DataLoader
# 常见的神经网络相关模块
from torch import nn, optim
# 导入自己构建的模型
from Lenet5 import Lenet5
# 损失函数
from torch.nn import functional as F

batch_size = 128
# 导入数据
def main():
cifar_train = datasets.CIFAR10(
root="/media/D/dataset/CIFAR10",
train=True,
transform=transforms.Compose([
# 将照片转化成32*32的特征图
transforms.Resize((32, 32)),
# 将数据转化成tensor
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=False)
cifar_train = DataLoader(
dataset=cifar_train,
batch_size=batch_size,
shuffle=True,
)
cifar_test = datasets.CIFAR10(
root="/media/D/dataset/CIFAR10",
train=False,
transform=transforms.Compose([
# 将照片转化成32*32的特征图
transforms.Resize((32, 32)),
# 将数据转化成tensor
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=False)
cifar_test = DataLoader(
dataset=cifar_test,
batch_size=batch_size,
shuffle=True,
)
# 对数据预览 iter()方法得到数据集的迭代器,next()生成取出数据
x, label = iter(cifar_train).next()
print(x.shape, label.shape)

# 训练网络
device = torch.device("cuda")
# 实例化网络
model = Lenet5().to(device)
# 构建CrossEntropyLoss损失函数(包含了Softmax), 需要构建函数名, 不能直接使用
criterion = nn.CrossEntropyLoss().to(device)
# 构建优化器:优化器自己选择需要梯度的参数
optimizer = optim.SGD(model.parameters(), lr=0.001)
for epoch in range(10000):
model.train()
for batch_idx, (x, label) in enumerate(cifar_train):
# 将训练集放入模型
x, label = x.to(device), label.to(device)
logits = model(x)
# 计算loss:交叉熵损失函数
loss = criterion(logits, label)
# 优化迭代老三样
# 梯度清零
optimizer.zero_grad()
# 梯度计算
loss.backward()
# 梯度迭代更新
optimizer.step()
print(epoch, loss.item())
# 测试数据, 测试集的数据拿来作验证
model.eval()
with torch.no_grad():
total_correct = 0
total_num = 0
for x, label in cifar_test:
x, label = x.to(device), label.to(device)
logits = model(x)
# logits = F.softmax(logits, dim=1)
pred = logits.argmax(dim=1)
correct = torch.eq(pred, label).float().sum().item()

total_correct += correct
total_num += x.size(0)
acc = total_correct / total_num
print(epoch, acc)


if __name__ == '__main__':
main()

Lenet5.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# 构建Lenet5神经网络
# 1986年设计得到的, 最简单的卷积神经网络的结构
import torch
from torch import nn


class Lenet5(nn.Module):
# 网络结构的初始化
def __init__(self):
super(Lenet5, self).__init__()
# 神经网络结构构造
# 卷积单元 Sequential可以帮助我们快速构建神经网络的结构
self.conv_unit = nn.Sequential(
# 卷积层C1
# 输入的是彩图:in = 3, 输出:out = 6
# kernel_size:卷积核大小, stride:卷积核移动步长, padding:补0的层数
nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
# 池化层(平均池化), 池化不改变通道的数量, 改变图片的尺寸大小
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
# 卷积层C2
nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
# 池化层(平均池化)
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
)
# 线性结构单元
self.fc_unit = nn.Sequential(
nn.Linear(400, 120), # 输入是多少,可以通过测试决定
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10)
)

def forward(self, x):
x = self.conv_unit(x)
x = x.view(x.size(0), -1) # x.size(0)就是batch_size view()方法展平多维张量
output = self.fc_unit(x)
return output

def main():
# 实例化神经网络
net = Lenet5()
# 放入测试图片
tmp = torch.randn(8, 3, 32, 32)
out = net(tmp)
print(out.shape)

if __name__ == '__main__':
main()