欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

PyTorch实战GANs

发布时间:2025/3/15 编程问答 71 豆豆
生活随笔 收集整理的这篇文章主要介绍了 PyTorch实战GANs 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

GANs简介

GANs(Generative Adversarial Networks ),全名又叫做生成式对抗网络,设计者使用的是一种类似于“左右手互博”的思想,所以GANs的作者周伯通(英文名:lan Goodfellow)在设计的时候遵循的就是这个原则。“左右手”分别指代的是GANs中的生成器(Generator)和判别器(Discriminator)。

图片来源于网络

生成器的主要作用就是随机生成一个指定格式的图片,判别器的主要作用是能够对输入的图片真假进行判断,下图就是GANs最原始的网络架构。

图片来源于网络

所以在GANs中重点需要实现的就是生成器和判别器,下面我们通过两种不同的方式对GANs进行实现,方法一中的生成器和判别器由简单的神经网络构成,方法二中生成器和判别器由卷积神经网络构成。

简单神经网络

这里我们重点介绍生成器、判别器的实现以及如何定义模型的损失和优化,完整代码会在最后贴出来。首先是判别器,这里使用的网络架构比较简单,是输入层-隐藏层-输出层的三层结构。输入图像我们都知道MINST数据集的图片是28*28的,激活函数使用的LeakyReLU。

class Discriminator(torch.nn.Module):def __init__(self):super(Discriminator,self).__init__()self.discriminator = torch.nn.Sequential(torch.nn.Linear(28*28,128),torch.nn.LeakyReLU(),torch.nn.Linear(128,1))def forward(self, input):output = self.discriminator(input)return output

然后是生成器,生成器通过输入一个指定大小的随机数生成出28*28的图片,最后我们生成器生成的图片越接近真实图片说明生成器的效果越好。

class Generator(torch.nn.Module):def __init__(self):super(Generator,self).__init__() self.generator = torch.nn.Sequential(torch.nn.Linear(100,128),torch.nn.LeakyReLU(),torch.nn.Linear(128,28*28),torch.nn.Tanh())def forward(self,input):output = self.generator(input)return output

生成我们生成器需要用到的随机数我们使用一个函数来定义。

def rand_img(batchsize,output_size):Z = np.random.uniform(-1.,1., size=(batchsize, output_size))Z = np.float32(Z)Z = torch.from_numpy(Z) Z = Variable(Z.cuda())return Z

接下来是损失的定义,我们只要把握住两个原则,我们希望判别器对输入的真实图片全部判断为1,输入的虚假图片全部判断为0,同时对于生成器我们要求生产的图片输入到判别器后能够被判断为1。这就是GAN是的精髓,具体实现如下。

model_discriminator = Discriminator_conv().cuda() model_generator = Generator_conv().cuda()X_gen = model_generator(Z) X_gen = X_gen.view(-1,1,28,28) X_train = X_train.view(-1,1,28,28)logits_real = model_discriminator(X_train) logits_fake = model_discriminator(X_gen)d_loss = loss_f(logits_real, torch.ones_like(logits_real))+loss_f(logits_fake, torch.zeros_like(logits_fake))Z = rand_img(batchsize=batchsize, output_size=100) X_gen = model_generator(Z) X_gen = X_gen.view(-1,1,28,28) logits_fake = model_discriminator(X_gen) g_loss = loss_f(logits_fake,torch.ones_like(logits_fake))

我们通过训练减小d_loss来提升判别器的能力,同时又在训练减小g_loss来提升生产器的能力,这两个看似矛盾的方向却可以让整个模型取得非常好的效果。

卷积神经网络

使用卷积方式实现的GANs也被称作为DCGANs,卷积的实现最大的不同就是在模型的结构中加入了卷积的成分,当然最后效果相对前者会更加理想。

判别器,使用的是非常常用的卷积神经网络结构。

class Discriminator_conv(torch.nn.Module):def __init__(self):super(Discriminator_conv,self).__init__()self.conv = torch.nn.Sequential(torch.nn.Conv2d(1,32,kernel_size=5,stride=1),torch.nn.LeakyReLU(),torch.nn.MaxPool2d(kernel_size=2,stride=2),torch.nn.Conv2d(32,64,kernel_size=5,stride=1),torch.nn.LeakyReLU(),torch.nn.MaxPool2d(kernel_size=2,stride=2))self.dense = torch.nn.Sequential(torch.nn.Linear(64*4*4,64*4*4),torch.nn.LeakyReLU(),torch.nn.Linear(64*4*4,1))def forward(self, input):output = self.conv(input)output = output.view(-1,64*4*4)output = self.dense(output)return output

生成器,其中用到的一个逆向卷积的方法,公式如下:

class Generator_conv(torch.nn.Module):def __init__(self):super(Generator_conv,self).__init__()self.conv_dense = torch.nn.Sequential(torch.nn.Linear(100,1024),torch.nn.LeakyReLU(),torch.nn.BatchNorm1d(num_features=1024),torch.nn.Linear(1024,7*7*128),torch.nn.BatchNorm1d(num_features=7*7*128))self.transpose_conv = torch.nn.Sequential(torch.nn.ConvTranspose2d(128,64,kernel_size=4,stride=2,padding=1),torch.nn.ReLU(),torch.nn.BatchNorm2d(num_features=64),torch.nn.ConvTranspose2d(64,1,kernel_size=4,stride=2,padding=1),torch.nn.Tanh())def forward(self, input):output = self.conv_dense(input)output = output.view(-1,128,7,7)output = self.transpose_conv(output)return output

最后我把模型训练1个epoch、10个epoch和20个epoch后得到的结果贴出来,可以看出我们的生成器已经可以生成同MINIST数据类似的图片了。

1个epoch

10个epoch

20个epoch

总结

最后说几点小的诀窍。

1、我们可以将原来的d_loss改成如下形式。

d_loss = loss_f(logits_real, torch.ones_like(logits_real)*(1-smooth))+loss_f(logits_fake, torch.zeros_like(logits_fake))

通过乘上一个(1-smooth)的参数(其中smooth可以设为0.1-0.9)来防止判别器模型的过拟合。

2、通过改变降低优化函数的初始学习速率来降低生成器的g_loss。

optimizer_dis = torch.optim.Adam(model_discriminator.parameters(),lr=0.0001) optimizer_gen = torch.optim.Adam(model_generator.parameters(),lr=0.0001)

3、构建更加深度的网络结构能够取得更好的结果,当然也会开销更多的训练时间。

资源

非常全的GANs衍生模型

完整代码

import torch import torchvision from torch.autograd import Variable from torchvision import datasets,models,transforms import matplotlib.pyplot as plt import numpy as np%matplotlib inline %config InlineBackend.figure_format="retina"epoch_n =20 batchsize = 128 smooth = 0.1train_transform=transforms.ToTensor()train_data = datasets.MNIST(root="data",download=True,train=True,transform=train_transform) train_load = torch.utils.data.DataLoader(dataset=train_data,shuffle=True,batch_size=batchsize)def plot_img(img):img = torchvision.utils.make_grid(img)img = img.numpy().transpose(1,2,0)plt.figure(figsize=(12,9))plt.imshow(img)class Discriminator_conv(torch.nn.Module):def __init__(self):super(Discriminator_conv,self).__init__()self.conv = torch.nn.Sequential(torch.nn.Conv2d(1,32,kernel_size=5,stride=1),torch.nn.LeakyReLU(),torch.nn.MaxPool2d(kernel_size=2,stride=2),torch.nn.Conv2d(32,64,kernel_size=5,stride=1),torch.nn.LeakyReLU(),torch.nn.MaxPool2d(kernel_size=2,stride=2))self.dense = torch.nn.Sequential(torch.nn.Linear(64*4*4,64*4*4),torch.nn.LeakyReLU(),torch.nn.Linear(64*4*4,1))def forward(self, input):output = self.conv(input)output = output.view(-1,64*4*4)output = self.dense(output)return outputclass Generator_conv(torch.nn.Module):def __init__(self):super(Generator_conv,self).__init__()self.conv_dense = torch.nn.Sequential(torch.nn.Linear(100,1024),torch.nn.LeakyReLU(),torch.nn.BatchNorm1d(num_features=1024),torch.nn.Linear(1024,7*7*128),torch.nn.BatchNorm1d(num_features=7*7*128))self.transpose_conv = torch.nn.Sequential(torch.nn.ConvTranspose2d(128,64,kernel_size=4,stride=2,padding=1),torch.nn.ReLU(),torch.nn.BatchNorm2d(num_features=64),torch.nn.ConvTranspose2d(64,1,kernel_size=4,stride=2,padding=1),torch.nn.Tanh())def forward(self, input):output = self.conv_dense(input)output = output.view(-1,128,7,7)output = self.transpose_conv(output)return outputdef initialize_weights(m):if isinstance(m,torch.nn.Linear) or isinstance(m,torch.nn.Conv2d):torch.nn.init.xavier_uniform_(m.weight.data)model_discriminator = Discriminator_conv().cuda() model_discriminator.apply(initialize_weights) model_generator = Generator_conv().cuda() model_generator.apply(initialize_weights)loss_f = torch.nn.BCEWithLogitsLoss()optimizer_dis = torch.optim.Adam(model_discriminator.parameters(),lr=0.0001) optimizer_gen = torch.optim.Adam(model_generator.parameters(),lr=0.0001)samples = [] losses = []def rand_img(batchsize,output_size):Z = np.random.uniform(-1.,1., size=(batchsize, output_size))Z = np.float32(Z)Z = torch.from_numpy(Z) Z = Variable(Z.cuda())return Zfor epoch in range(epoch_n):for batch in train_load:X_train,y_train = batchX_train,y_train = Variable(X_train.cuda()),Variable(y_train.cuda())#X_train,y_train = Variable(X_train),Variable(y_train)Z = rand_img(batchsize=batchsize, output_size=100)optimizer_dis.zero_grad() X_gen = model_generator(Z)X_gen = X_gen.view(-1,1,28,28)X_train = X_train.view(-1,1,28,28)logits_real = model_discriminator(X_train)logits_fake = model_discriminator(X_gen)d_loss = loss_f(logits_real, torch.ones_like(logits_real)*(1-smooth))+loss_f(logits_fake, torch.zeros_like(logits_fake))d_loss.backward(retain_graph=True)optimizer_dis.step()optimizer_gen.zero_grad() Z = rand_img(batchsize=batchsize, output_size=100)X_gen = model_generator(Z)X_gen = X_gen.view(-1,1,28,28)logits_fake = model_discriminator(X_gen)g_loss = loss_f(logits_fake,torch.ones_like(logits_fake)) g_loss.backward() optimizer_gen.step()print("Epoch{}/{}...".format(epoch+1, epoch_n),"Discriminator Loss:{:.4f}...".format(d_loss),"Generator Loss:{:.4f}...".format(g_loss))losses.append((d_loss, g_loss))fake_img = model_generator(Z)samples.append(fake_img)fig, ax = plt.subplots() losses = np.array(losses) plt.plot(losses.T[0], label='Discriminator') plt.plot(losses.T[1], label='Generator') plt.title("Training Losses") plt.legend()def to_img(img):img = img.detach().cpu().dataimg = img.clamp(0,1)img = img.view(-1,1,28,28)return imgfor i in range(len(samples)):img = to_img(samples[i])plot_img(img)

https://zhuanlan.zhihu.com/p/40393929

总结

以上是生活随笔为你收集整理的PyTorch实战GANs的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。