通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

前言

本篇文章主要介绍如何来设置一个GAN网络利用MNIST手写数字图片进行训练来生成手写数字图片,代码主要参考github的实现,在原来的基础上做了一些修改和新增了一些功能。

github参考地址:https://github.com/znxlwm/tensorflow-MNIST-GAN-DCGAN

GAN简介

生成对抗式网络(GAN,Generative Adversarial NetWorks):是深度学习中的一种模型,属于无监督学习算法。模型主要包括两个模块,生成模型(Generative Model)判别模型(Adversarial Model),通过两个模型的互相博弈使得生成模型产生接近于真实样本。举个简单一点的例子,方便理解,

例如:制作假钞团伙和真钞的鉴别专家,对于假钞制作团伙来说他们为了能让假钞顺利使用,那么他们就需要让他们制作的假钞像真钞一样,那样他们制作的假钞就不会被鉴别专家发现。所以假钞的制作团伙就需要不断的模仿真钞,不断的使得假钞越来越想真钞,从而骗过钞票的鉴别专家。在这个例子中假钞制作团伙就相对于GAN中的生成模型,钞票鉴别专家就是判别模型。生成模型就是通过不断的和判别模型进行博弈不断学习,最终使得生成模型达到以假乱真的目的。

在GAN中有一个非常重要的公式,如果理解了,就代表你弄懂了GAN,公式如下:

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

GAN论文原文:http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

公式解读:

对于G(生成模型)来说要想模型生成的结果最好,要使得上式取得最小值。对于生成模型我们只需要看后面一部分的式子log(1−D(G(z))) ,z表示的是输入的噪声,生成模型就是通过这个输入噪声来产生一个输出D(判别模型)对于真实样本输出为1,对于虚假样本输出0,D(x)表示判别模型的输出,x表示输入数据(真实样本或虚假样本)。要想使得生成模型G生成的接近真实样本,也就是要让判别模型D的输出为1,也就是要让D(G(z))尽量接近于1,那么1−D(G(z)) 就会接近于0。当x接近于0时,log(x) 会趋于负无穷,所以当上式趋于无穷小时,D(G(z)) 会趋于1,此时的生成模型生成的数据接近于真实数据,生成模型达到最优。

对于D(判别模型)来说要想模型的性能最好,就要使得上式取得最大值。上式中的x~pdata表示真实样本,z~pz表示虚假样本。所以,要想使得D的性能最优,对于真实样本的输入D(x) 应该输出1,对于虚假样本的输入D(G(z))应该输出0,所以此时V(D,G)取得最大值。

GAN生成手写数字

主要利用MNIST数据集进行训练来生成手写数字

软件环境

  • 系统:win10
  • python:3.6.4
  • tensorflow-gpu:1.8.0
  • matplotlib:2.1.2
  • numpy:1.14.0
  • github地址:https://github.com/steelOneself/GAN_tensorflow/tree/master/GANNet

GAN架构

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

上图主要展示了G(生成模型)和D(判别模型)的网络架构,可以发现G和D的结构刚好是一个逆的过程,GAN主要利用多层感知器来实现的,所以最终生成的手写数字图片会比卷积结构的生成模型效果要差一些,后面会介绍使用DCGAN来实现手写数字图片的生成。

代码介绍

  • GAN结构设计

将G和D网络每层隐藏节点数定义在了字典中

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

  • 损失函数的定义

为了避免出现log(0),所以我们需要在后面加一个极小的数。为了便于训练生成网络和判别网络,我们需要通过求loss最小化问题来求解网络的参数。式中的D_real表示判别网络对于真实样本的输出,D_fake表示对虚假样本的输出。当D最优时,对所有的D_real都输出1,对所有的D_fake都输出0,所以此时的D_loss最小,也就是判别网络D的损失达到最小值。当G最优时,D_fake的输出应该为1(因为此时G生成的样本已经接近于真实样本,此时D已经无法判断,认为G生成的样本就是真实样本),此时G_loss最小,也就是生成网络G的损失达到最小值。

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

  • 模型的训练

训练的时候我们先训练D,然后再训练G,训练D和G的顺序关系不大,需要保持在训练时,两个网络的参数都能够更新,从而达到相互博弈的效果。

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

  • 模型生成样本

通过加载训练好的模型,直接生成手写数字的图片。github项目中包含已经训练好的ckpt文件,加载模型可以直接生成。

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

结果展示

  • 训练结果

使用的是GTX1060训练的每个batch大约需要花5.7s左右。

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

  • loss变化

随着epoch的增大,G_loss越来越小。

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

  • 随epoch变化生成的手写数字

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

  • 最终生成的手写数字图片

通过python来生成手写数字(代码详解)(python手写数字识别代码简单)

可以发现迭代100个epoch之后,GAN生成的手写数字图片效果并不是特别好,下一篇文章介绍使用DCGAN来生成手写数字图片,效果会比这个结果好很多。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发表回复

您的电子邮箱地址不会被公开。