(pytorch-深度学习)实现稠密连接网络(DenseNet)
生活随笔
收集整理的这篇文章主要介绍了
(pytorch-深度学习)实现稠密连接网络(DenseNet)
小编觉得挺不错的,现在分享给大家,帮大家做个参考.
稠密连接网络(DenseNet)
ResNet中的跨层连接设计引申出了数个后续工作。稠密连接网络(DenseNet)与ResNet的主要区别在于在跨层连接上的主要区别:
- ResNet使用相加
- DenseNet使用连结
ResNet(左)与DenseNet(右):
图中将部分前后相邻的运算抽象为模块AAA和模块BBB。
- DenseNet里模块BBB的输出不是像ResNet那样和模块AAA的输出相加,而是在通道维上连结。
- 这样模块AAA的输出可以直接传入模块BBB后面的层。在这个设计里,模块AAA相当于直接跟模块BBB后面的所有层直接连接在了一起。这也是它被称为“稠密连接”的原因。
DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。
- 稠密块定义了输入和输出是如何连结的
- 过渡层用来控制通道数,控制其大小
稠密块
DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构:
import time import torch from torch import nn, optim import torch.nn.functional as Fdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def conv_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))return blk- 稠密块由多个conv_block组成,每块使用相同的输出通道数。
- 在前向计算时,我们将每块的输入和输出在通道维上连结。
定义一个有2个输出通道数为10的卷积块。
- 使用通道数为3的输入时,我们会得到通道数为3+2×10=233+2\times 10=233+2×10=23的输出。
- 卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。
过渡层
- 每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。
- 过渡层用来控制模型复杂度。它通过1×11\times11×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。
对上例中稠密块的输出,使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。
blk = transition_block(23, 10) blk(Y).shape # torch.Size([4, 10, 4, 4])DenseNet模型
DenseNet首先使用和ResNet一样的单卷积层和最大池化层。
net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))- 接着使用4个稠密块。
- 同ResNet一样,我们可以设置每个稠密块使用多少个卷积层(这里设成4)。
- 稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。
ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。DenseNet则使用过渡层来减半高和宽,并减半通道数。
num_channels, growth_rate = 64, 32 # num_channels为当前的通道数 num_convs_in_dense_blocks = [4, 4, 4, 4]for i, num_convs in enumerate(num_convs_in_dense_blocks):DB = DenseBlock(num_convs, num_channels, growth_rate)net.add_module("DenseBlosk_%d" % i, DB)# 上一个稠密块的输出通道数num_channels = DB.out_channels# 在稠密块之间加入通道数减半的过渡层if i != len(num_convs_in_dense_blocks) - 1:net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))num_channels = num_channels // 2- 最后接上全局池化层和全连接层来输出。
- 打印每个子模块的输出维度
- 获取数据
训练模型
def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start)) lr, num_epochs = 0.001, 5 optimizer = torch.optim.Adam(net.parameters(), lr=lr) train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)《动手学深度学习》
总结
以上是生活随笔为你收集整理的(pytorch-深度学习)实现稠密连接网络(DenseNet)的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: 这款堪称完美的PDF编辑器,帮你节省50
- 下一篇: (pytorch-深度学习系列)使用so