基于torch.nn.functional.conv2d实现CNN
在我们之前的实验中,我们一直用torch.nn.Conv2D来实现卷积神经网络,但是torch.nn.Conv2D在实现中是以torch.nn.functional.conv2d为基础的,这两者的区别是什么呢?
torch.nn.Conv2D
源码如下:
torch.nn.Conv2dCLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')可以发现,函数参数包括输入的通道数、输出的通道数、卷积核大小等。在输入中,我们不需要输入卷积核的权重,但是如果在实验中,我们需要用自己的卷积核,那么这种方式就不适用了。
torch.nn.functional.conv2d
源码如下:
torch.nn.functionaltorch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor参数的具体意义:
input代表输入图像的大小(minibatch,in_channels,H,W),是一个四维tensor
filters代表卷积核的大小(out_channels,in_channe/groups,H,W),是一个四维tensor
bias代表每一个channel的bias,是一个维数等于out_channels的tensor
stride是一个数或者一个二元组(SH,SW),代表纵向和横向的步长
padding是一个数或者一个二元组(PH,PW ),代表纵向和横向的填充值
dilation是一个数,代表卷积核内部每个元素之间间隔元素的数目(不常用,默认为0)
groups是一个数,代表分组卷积时分的组数,特别的当groups = in_channel时,就是在做逐层卷积(depth-wise conv).
二者区别
torch.nn.Conv2D是一个类,而torch.nn.functional.conv2d是一个函数,在Sequential里面只能放nn.xxx,而nn.functional.xxx是不能放入Sequential里面的。
nn.Module 实现的 layer 是由 class Layer(nn.Module) 定义的特殊类,nn.functional 中的函数是纯函数,由 def function(input) 定义。
nn.functional.xxx 需要自己定义 weight,每次调用时都需要手动传入 weight,而 nn.xxx 则不用。
如果需要自己定义卷积核,那么就只能使用nn.functional.conv2d。但是在使用时,需要注意BatchNormalization和Dropout的使用方式。参考以下链接
接下来我们使用torch.nn.functional.conv2d来定义CNN实现Mnist数据集的识别,CNN定义如下所示:
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1_weight = nn.Parameter(torch.randn(16,1,3,3))self.bias_1_weight = nn.Parameter(torch.randn(16))self.bn1 = nn.BatchNorm2d(16)self.conv_2_weight = nn.Parameter(torch.randn(32,16,3,3))self.bias_2_weight = nn.Parameter(torch.randn(32))self.bn2 = nn.BatchNorm2d(32)self.Linear_weight = nn.Parameter(torch.randn(10,32*32*32))self.bias_weight = nn.Parameter(torch.randn(10))def forward(self,x):x = F.conv2d(x,self.conv_1_weight,self.bias_1_weight,stride=1,padding=1)x = F.relu(self.bn1(x),inplace=True)x = F.conv2d(x,self.conv_2_weight,self.bias_2_weight,stride=1,padding=1)x = F.relu(self.bn2(x),inplace=True)x = x.view(-1,32*32*32)x = F.linear(x,self.Linear_weight,self.bias_weight)return x实验结果如下所示,最终的模型准确率为97%:
Epoch: 29 | Train Loss: 0.1965 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.1152 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.0702 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.0971 | Test Accuracy: 0.97 Epoch: 29 | Train Loss: 0.1620 | Test Accuracy: 0.97全部代码如下所示:
import torch import torch.nn as nn import torch.nn.functional as F import torchvision from data import Getdata from torch import optimdata_train_loader,data_test_loader = Getdata() class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1_weight = nn.Parameter(torch.randn(16,1,3,3))self.bias_1_weight = nn.Parameter(torch.randn(16))self.bn1 = nn.BatchNorm2d(16)self.conv_2_weight = nn.Parameter(torch.randn(32,16,3,3))self.bias_2_weight = nn.Parameter(torch.randn(32))self.bn2 = nn.BatchNorm2d(32)self.Linear_weight = nn.Parameter(torch.randn(10,32*32*32))self.bias_weight = nn.Parameter(torch.randn(10))def forward(self,x):x = F.conv2d(x,self.conv_1_weight,self.bias_1_weight,stride=1,padding=1)x = F.relu(self.bn1(x),inplace=True)x = F.conv2d(x,self.conv_2_weight,self.bias_2_weight,stride=1,padding=1)x = F.relu(self.bn2(x),inplace=True)x = x.view(-1,32*32*32)x = F.linear(x,self.Linear_weight,self.bias_weight)return x model = CNN()optimizer = torch.optim.Adam(model.parameters(),lr=1e-3) loss_func = nn.CrossEntropyLoss() epoch = 30for i in range(epoch):for step,(train_x,train_y) in enumerate(data_train_loader):model.train()output = model(train_x)loss = loss_func(output,train_y)optimizer.zero_grad()loss.backward()optimizer.step()if step % 50 == 0:model.eval()with torch.no_grad():test_acc = 0num = 0for s,(test_x,test_y) in enumerate(data_test_loader):output = model(test_x)output = output.int()pred_y = torch.max((output),dim=1)[1]test_acc += test_y.eq_(pred_y).sum().item()num += test_y.size(0)print('Epoch: ',i,'| Train Loss: %.4f'% loss.item(),'| Test Accuracy: %.2f' % float(test_acc / num))努力加油a啊
创作挑战赛新人创作奖励来咯,坚持创作打卡瓜分现金大奖总结
以上是生活随笔为你收集整理的基于torch.nn.functional.conv2d实现CNN的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: 蓝牙耳机有爆炸风险吗(蓝牙无线技术)
- 下一篇: 统计学、数据分析、机器学习常用数据特征汇