DCGAN(深度卷积对抗网络)案例

反射流沙
• 阅读 13512

介绍

DCGAN(深度卷积对抗网络)案例

如图所示,GAN网络会同时训练两个模型。生成器:负责生成数据(比如:照片);判别器:判别所生成照片的真假。训练过程中,生成器生成的照片会越来越接近真实照片,直到判别器无法区分照片真假。

DCGAN(深度卷积对抗网络)案例

DCGAN(深度卷积对抗生成网络)是GAN的变体,是一种将卷积引入模型的网络。特点是:

  • 判别器使用strided convolutions来替代空间池化,生成器使用反卷积
  • 使用BN稳定学习,有助于处理初始化不良导致的训练问题
  • 生成器输出层使用Tanh激活函数,其它层使用Relu激活函数。判别器上使用Leaky Relu激活函数。

本次案例我们将使用mnist作为数据集训练DCGAN网络,程序最后将使用GIF的方式展示训练效果。

数据导入

import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow.contrib as tcon
import PIL
import time
from IPython import display

# shape:(60000,28,28)
(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
# shape:[batch_size,height,width,channel]
train_images_reshape=tf.reshape(train_images,shape=(train_images.shape[0],28,28,1)).astype(tf.float32)
# 缩放图片[-1,1]
train_images_nor=(train_images-127.5)/127.5

dataset加载数据

BUFFER_SIZE=60000
BATCH_SIZE=256

# 优化输入管道需要从:读取,转换,加载三方面考虑。
train_dataset=tf.data.Dataset.from_tensor_slices(train_images).shuffle(buffer_size=BUFFER_SIZE).batch(BATCH_SIZE)

生成模型

该生成模型将使用反卷积层,我们首先创建全连接层然后通过两次上采样将图片分辨率扩充至28x28x1。我们将逐步提升分辨率降低depth,除最后一层使用tanh激活函数,其它层都使用Leaky Relu激活函数。

def make_generator_model():
    # 反卷积,从后往前
    model=tf.keras.Sequential()
    model.add(
        tf.keras.layers.Dense(
            input_dim=7*7*256,
            
            # 不使用bias的原因是我们使用了BN,BN会抵消掉bias的作用。
            # bias的作用:
            # 提升网络拟合能力,而且计算简单(只要一次加法)。
            # 能力的提升源于调整输出的整体分布
            use_bias=False,
            # noise dim
            input_shape=(100,)
        )
    )
    """
    随着神经网络的训练,网络层的输入分布会发生变动,逐渐向激活函数取值两端靠拢,如:sigmoid激活函数,
    此时会进入饱和状态,梯度更新缓慢,对输入变动不敏感,甚至梯度消失导致模型难以训练。
    BN,在网络层输入激活函数输入值之前加入,可以将分布拉到均值为0,标准差为1的正态分布,从而
    使激活函数处于对输入值敏感的区域,从而加快模型训练。此外,BN还能起到类似dropout的正则化作用,由于我们会有
    ‘强拉’操作,所以对初始化要求没有那么高,可以使用较大的学习率。
    """
    model.add(tf.keras.layers.BatchNormalization())
    """
    relu 激活函数在输入为负值的时候,激活值为0,此时神经元无法学习
    leakyrelu 激活函数在输入为负值的时候,激活值不为0(但值很小),神经元可以继续学习
    """
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Reshape(input_shape=(7,7,256)))
    assert model.output_shape == (None,7,7,256)

    model.add(tf.keras.layers.Conv2DTranspose(
        filters=128,
        kernel_size=5,
        strides=1,
        padding='same',
        use_bias='False'
    ))
    assert model.output_shape == (None,7,7,128)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # 卷积核为奇数:图像两边可以对称padding 00xxxx00
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=64,
        kernel_size=5,
        strides=2,
        padding='same',
        use_bias='False'
    ))
    assert model.output_shape == (None,14,14,64)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Conv2DTranspose(
        filters=1,
        kernel_size=5,
        strides=2,
        padding='same',
        use_bias='False',
        
        # tanh激活函数值区间[-1,1],均值为0关于原点中心对称。、
        # sigmoid激活函数梯度在反向传播过程中会出全正数或全负数,导致权重更新出现Z型下降。
        activation='tanh'
    ))
    assert model.output_shape == (None,28,28,1)

    return model

判别模型

判别器使用strided convolutions来替代空间池化,比如这里strided=2。卷积层使用LeakyReLU替代Relu,并使用Dropout为全连接层提供加噪声的输入。

def make_discriminator_model():
    # 常规卷积操作
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    
    # dropout常见于全连接层,其实卷积层也是可以使用的。
    # 这里简单翻译下dropout论文观点:
    """
    可能很多人认为因为卷积层参数较少,过拟合发生概率较低,所以dropout作用并不大。
    但是,dropout在前面几层依然有帮助,因为它为后面的全连接层提供了加噪声的输入,从而防止过拟合。
    """
    model.add(tf.keras.layers.Dropout(0.3))
      
    model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))
       
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1))
     
    return model

损失函数

获取模型:

generator = make_generator_model()
discriminator = make_discriminator_model()

生成器损失函数:

损失函数使用sigmoid cross entropylabels使用值全为1的数组。

def generator_loss(generator_output):
    return tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(generator_output),
        logits=generator_output
    )

判别器损失函数

判别器损失函数接受两种输入,生成器生成的图像和数据集中的真实图像,损失函数计算方法如下:

  • 使用sigmoid cross entropy损失函数计算数据集中真实图像的损失,labels使用值全为1的数组。
  • 使用sigmoid cross entropy损失函数计算生成器图像的损失,labels使用值全为0的数组。
  • 将以上损失相加得到判别器损失。
def discriminator_loss(real_output, generated_output):
    # real:[1,1,...,1] 
    real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)
    #:generated:[0,0,...,0] 
    generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)
    
    # 总损失为两者相加
    total_loss = real_loss + generated_loss
    return total_loss

模型保存:

# 两种模型同时训练,自然需要使用两种优化器,学习率为:0.0001
generator_optimizer = tf.train.AdamOptimizer(1e-4)
discriminator_optimizer = tf.train.AdamOptimizer(1e-4)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# checkpoint配置
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

模型训练

训练参数配置:

# 数据集迭代次数
EPOCHS = 50
# 生成器噪声维度
noise_dim = 100

# 可视化效果数量设置
num_examples_to_generate = 16
random_vector_for_generation = tf.random_normal([num_examples_to_generate,
                                                 noise_dim])

生成器将我们设定的正态分布的噪声向量作为输入,用来生成图像。判别器将同时显示数据集真实图像和生成器生成的图像用于判别。随后,我们计算生成器和判断器损失函数对参数的梯度,然后使用梯度下降进行更新。

def train_step(images):
      # 正态分布噪声作为生成器输入
      noise = tf.random_normal([BATCH_SIZE, noise_dim])
      
      # tf.GradientTape进行记录
      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        
        # 判别器中真实图像和生成器的假图像
        real_output = discriminator(images, training=True)
        generated_output = discriminator(generated_images, training=True)
        
        gen_loss = generator_loss(generated_output)
        disc_loss = discriminator_loss(real_output, generated_output)
        
      gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)
      gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)
      
      generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))
      discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))

开始训练:

加速计算节约内存,但是不可以使用'pdb','print'。

train_step = tf.contrib.eager.defun(train_step)
def train(dataset, epochs):  
  for epoch in range(epochs):
    start = time.time()
    
    # 迭代数据集
    for images in dataset:
      train_step(images)

    display.clear_output(wait=True)
    
    # 保存图像用于后面的可视化
    generate_and_save_images(generator,
                               epoch + 1,
                               random_vector_for_generation)
    
    # 每迭代15次数据集保存一次模型
    # 如需部署至tensorflow serving需要使用savemodel
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)
    
    print ('Time taken for epoch {} is {} sec'.format(epoch + 1,
                                                      time.time()-start))
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           random_vector_for_generation)

可视化生成器图像:

def generate_and_save_images(model, epoch, test_input):
  # training:False 不训练BN
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))
  
  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')
        
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

train(train_dataset, EPOCHS)

可视化模型训练结果

展示照片:

def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))

动画展示训练结果:

with imageio.get_writer('dcgan.gif', mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  last = -1
  for i,filename in enumerate(filenames):
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
    
os.system('cp dcgan.gif dcgan.gif.png')
display.Image(filename="dcgan.gif.png")

总结

DCGAN中生成器判别器都使用卷积网络来提升生成和判别能力,其中生成器利用反卷积,判别器利用常规卷积。生成器用随机噪声向量作为输入来生成假图像,判别器通过对真实样本的学习判断生成器图像真伪,如果判断为假,生成器重新调校训练,直到判别器无法区分真实样本图像和生成器的图像。

本文代码部分参考Yash Katariya,在此表示感谢。

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
Easter79 Easter79
3年前
swap空间的增减方法
(1)增大swap空间去激活swap交换区:swapoff v /dev/vg00/lvswap扩展交换lv:lvextend L 10G /dev/vg00/lvswap重新生成swap交换区:mkswap /dev/vg00/lvswap激活新生成的交换区:swapon v /dev/vg00/lvswap
Jacquelyn38 Jacquelyn38
4年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Stella981 Stella981
3年前
Google研究人员推出了一种用于生成文本到图像的新框架(TReCS)
!(https://oscimg.oschina.net/oscnet/faedcb264a1c43969f2f5a2e6b9dbd2e.png)基于生成对抗网络(GAN)的深度神经网络促进了端到端可训练的照片级逼真的文本到图像的生成。许多方法还使用中间场景图表示法来改善图像合成。使用基于对话的交互的方法允许用户提供指令,以逐步改进和调整生成
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
1年前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这
生成对抗网络GAN简介
生成对抗网络(GenerativeAdversarialNetworks,GAN)是一种深度敏感词模型,用于生成具有高度逼真度的新数据,如图像、音频、文本等。GAN是由IanGoodfellow等人在2014年提出的,其核心思想是通过两个神经网络,即生成器和判别器,相互竞争和协作来实现数据生成的目的。GAN的基本框架和训练过程如下图所示:
美凌格栋栋酱 美凌格栋栋酱
4个月前
Oracle 分组与拼接字符串同时使用
SELECTT.,ROWNUMIDFROM(SELECTT.EMPLID,T.NAME,T.BU,T.REALDEPART,T.FORMATDATE,SUM(T.S0)S0,MAX(UPDATETIME)CREATETIME,LISTAGG(TOCHAR(