VGG16学习笔记

Wesley13
• 阅读 523

转载自:http://deanhan.com/2018/07/26/vgg16/

摘要

本文对图片分类任务中经典的深度学习模型VGG16进行了简要介绍,分析了其结构,并讨论了其优缺点。调用Keras中已有的VGG16模型测试其分类性能,结果表明VGG16对三幅测试图片均能正确分类。

前言

VGG是由Simonyan 和Zisserman在文献《Very Deep Convolutional Networks for Large Scale Image Recognition》中提出卷积神经网络模型,其名称来源于作者所在的牛津大学视觉几何组(Visual Geometry Group)的缩写。

该模型参加2014年的 ImageNet图像分类与定位挑战赛,取得了优异成绩:在分类任务上排名第二,在定位任务上排名第一。

结构

VGG中根据卷积核大小卷积层数目的不同,可分为AA-LRN,B,C,D,E共6个配置(ConvNet Configuration),其中以D,E两种配置较为常用,分别称为VGG16VGG19

下图给出了VGG的六种结构配置:

VGG16学习笔记

上图中,每一列对应一种结构配置。例如,图中绿色部分即指明了VGG16所采用的结构。

我们针对VGG16进行具体分析发现,VGG16共包含:

  • 13个卷积层(Convolutional Layer),分别用conv3-XXX表示
  • 3个全连接层(Fully connected Layer),分别用FC-XXXX表示
  • 5个池化层(Pool layer),分别用maxpool表示

其中,卷积层和全连接层具有权重系数,因此也被称为权重层,总数目为13+3=16,这即是

VGG16中16的来源。(池化层不涉及权重,因此不属于权重层,不被计数)。

特点

VGG16的突出特点是简单,体现在:

  1. 卷积层均采用相同的卷积核参数

    卷积层均表示为conv3-XXX,其中conv3说明该卷积层采用的卷积核的尺寸(kernel size)是3,即宽(width)和高(height)均为3,3*3很小的卷积核尺寸,结合其它参数(步幅stride=1,填充方式padding=same),这样就能够使得每一个卷积层(张量)与前一层(张量)保持相同的宽和高。XXX代表卷积层的通道数。

  2. 池化层均采用相同的池化核参数

    池化层的参数均为2××2,步幅stride=2,max的池化方式,这样就能够使得每一个池化层(张量)的宽和高是前一层(张量)的1212。

  3. 模型是由若干卷积层和池化层堆叠(stack)的方式构成,比较容易形成较深的网络结构(在2014年,16层已经被认为很深了)。

综合上述分析,可以概括VGG的优点为: Small filters, Deeper networks.

VGG16学习笔记

块结构

我们注意图1右侧,VGG16的卷积层和池化层可以划分为不同的块(Block),从前到后依次编号为Block1~block5。每一个块内包含若干卷积层一个池化层。例如:Block4包含:

  • 3个卷积层,conv3-512
  • 1个池化层,maxpool

并且同一块内,卷积层的通道(channel)数是相同的,例如:

  • block2中包含2个卷积层,每个卷积层用conv3-128表示,即卷积核为:3x3x3,通道数都是128

  • block3中包含3个卷积层,每个卷积层用conv3-256表示,即卷积核为:3x3x3,通道数都是256

下面给出按照块划分的VGG16的结构图,可以结合图2进行理解:

VGG16学习笔记

VGG的输入图像是 224x224x3 的图像张量(tensor),随着层数的增加,后一个块内的张量相比于前一个块内的张量:

  • 通道数翻倍,由64依次增加到128,再到256,直至512保持不变,不再翻倍
  • 高和宽变减半,由 $224 \rightarrow 112\rightarrow 56\rightarrow 28\rightarrow 14\rightarrow 7$

权重参数

尽管VGG的结构简单,但是所包含的权重数目却很大,达到了惊人的139,357,544个参数。这些参数包括卷积核权重全连接层权重

  • 例如,对于第一层卷积,由于输入图的通道数是3,网络必须学习大小为3x3,通道数为3的的卷积核,这样的卷积核有64个,因此总共有(3x3x3)x64 = 1728个参数

  • 计算全连接层的权重参数数目的方法为:前一层节点数×本层的节点数前一层节点数×本层的节点数。因此,全连接层的参数分别为:

    • 7x7x512x4096 = 1027,645,444
    • 4096x4096 = 16,781,321
    • 4096x1000 = 4096000

FeiFei Li在CS231的课件中给出了整个网络的全部参数的计算过程(不考虑偏置),如下图所示:

VGG16学习笔记

图中蓝色是计算权重参数数量的部分;红色是计算所需存储容量的部分。

VGG16具有如此之大的参数数目,可以预期它具有很高的拟合能力;但同时缺点也很明显:

  • 即训练时间过长,调参难度大。
  • 需要的存储容量大,不利于部署。例如存储VGG16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

实践

下面,我们应用Keras对VGG16的图像分类能力进行试验。

Keras是一个高层神经网络API,_Keras_由纯Python编写 ,是tensorflow和Theano等底层深度学习库的高级封装 。使用Keras时,我们不需要直接调用底层API构建深度学习网络,仅调用keras已经封装好的函数即可。

本次试验平台:python 3.6 + tensorflow 1.8 + keras 2.2,Google Colab

源代码如下:

 1 # -*- coding: utf-8 -*-
 2 """
 3 Spyder Editor
 4 
 5 This is a temporary script file.
 6 """
 7 import matplotlib.pyplot as plt
 8 
 9 from keras.applications.vgg16 import VGG16
10 from keras.preprocessing import image
11 from keras.applications.vgg16 import preprocess_input, decode_predictions
12 import numpy as np
13 
14 def percent(value):
15     return '%.2f%%' % (value * 100)
16 
17 # include_top=True,表示會載入完整的 VGG16 模型,包括加在最後3層的卷積層
18 # include_top=False,表示會載入 VGG16 的模型,不包括加在最後3層的卷積層,通常是取得 Features
19 # 若下載失敗,請先刪除 c:\<使用者>\.keras\models\vgg16_weights_tf_dim_ordering_tf_kernels.h5
20 model = VGG16(weights='imagenet', include_top=True)
21 
22 
23 # Input:要辨識的影像
24 img_path = 'frog.jpg'
25 
26 #img_path = 'tiger.jpg' 并转化为224*224的标准尺寸
27 img = image.load_img(img_path, target_size=(224, 224))
28 
29 
30 x = image.img_to_array(img) #转化为浮点型
31 x = np.expand_dims(x, axis=0)#转化为张量size为(1, 224, 224, 3)
32 x = preprocess_input(x)
33 
34 # 預測,取得features,維度為 (1,1000)
35 features = model.predict(x)
36 
37 # 取得前五個最可能的類別及機率
38 pred=decode_predictions(features, top=5)[0]
39 
40 
41 #整理预测结果,value
42 values = []
43 bar_label = []
44 for element in pred:
45     values.append(element[2])
46     bar_label.append(element[1])
47 
48 #绘图并保存
49 fig=plt.figure(u"Top-5 预测结果")
50 ax = fig.add_subplot(111) 
51 ax.bar(range(len(values)), values, tick_label=bar_label, width=0.5, fc='g')
52 ax.set_ylabel(u'probability') 
53 ax.set_title(u'Top-5') 
54 for a,b in zip(range(len(values)), values):
55     ax.text(a, b+0.0005, percent(b), ha='center', va = 'bottom', fontsize=7)
56 
57 fig = plt.gcf()
58 plt.show()
59 
60 name=img_path[0:-4]+'_pred'
61 fig.savefig(name, dpi=200)

上述程序的基本流程是:

  1. 载入相关模块,keras ,matplotlib,numpy
  2. 下载已经训练好的模型文件:
  3. 导入测试图像
  4. 应用模型文件对图像分类

需要额外说明的是:

  • 程序运行过程中,语句model = VGG16(weights='imagenet', include_top=True)会下载已经训练好的文件到c:\<使用者>\.keras\models文件夹下,模型的文件名为vgg16_weights_tf_dim_ordering_tf_kernels.h5,大小为527MB

  • 语句pred=decode_predictions(features, top=5)[0]会下载分类信息文件到c:\<使用者>\.keras\models文件夹下,模型的文件名为imagenet_class_index.json,该文件指明了ImageNet大赛所用的1000个图像类的信息。(由于下载地址在aws上,梯子请自备)

  • 程序运行结束,会在工作目录下生成测试图片的预测图,给出了最有可能的前5个类列。名称为:测试文件名_pred.png

  • 在程序中还可以查看模型的结构,语句为:model.summary(),命令行输出模型的结构配置为:

    1 _________________________________________________________________ 2 Layer (type) Output Shape Param #
    3 ================================================================= 4 input_12 (InputLayer) (None, 224, 224, 3) 0
    5 _________________________________________________________________ 6 block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
    7 _________________________________________________________________ 8 block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
    9 _________________________________________________________________ 10 block1_pool (MaxPooling2D) (None, 112, 112, 64) 0
    11 _________________________________________________________________ 12 block2_conv1 (Conv2D) (None, 112, 112, 128) 73856
    13 _________________________________________________________________ 14 block2_conv2 (Conv2D) (None, 112, 112, 128) 147584
    15 _________________________________________________________________ 16 block2_pool (MaxPooling2D) (None, 56, 56, 128) 0
    17 _________________________________________________________________ 18 block3_conv1 (Conv2D) (None, 56, 56, 256) 295168
    19 _________________________________________________________________ 20 block3_conv2 (Conv2D) (None, 56, 56, 256) 590080
    21 _________________________________________________________________ 22 block3_conv3 (Conv2D) (None, 56, 56, 256) 590080
    23 _________________________________________________________________ 24 block3_pool (MaxPooling2D) (None, 28, 28, 256) 0
    25 _________________________________________________________________ 26 block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160
    27 _________________________________________________________________ 28 block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808
    29 _________________________________________________________________ 30 block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808
    31 _________________________________________________________________ 32 block4_pool (MaxPooling2D) (None, 14, 14, 512) 0
    33 _________________________________________________________________ 34 block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808
    35 _________________________________________________________________ 36 block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808
    37 _________________________________________________________________ 38 block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808
    39 _________________________________________________________________ 40 block5_pool (MaxPooling2D) (None, 7, 7, 512) 0
    41 _________________________________________________________________ 42 flatten (Flatten) (None, 25088) 0
    43 _________________________________________________________________ 44 fc1 (Dense) (None, 4096) 102764544 45 _________________________________________________________________ 46 fc2 (Dense) (None, 4096) 16781312
    47 _________________________________________________________________ 48 predictions (Dense) (None, 1000) 4097000
    49 ================================================================= 50 Total params: 138,357,544 51 Trainable params: 138,357,544 52 Non-trainable params: 0 53 _________________________________________________________________

可以看到总的训练参数为 $138,357,544$。

代码及图片文件全部放在我的github

结果

分别对虎(tiger),猫(cat),卷纸(paper_towel)三张图片进行分类:

VGG16学习笔记

VGG16学习笔记


VGG16学习笔记

VGG16学习笔记


VGG16学习笔记
VGG16学习笔记 网上随便下的图,效果还行.

点赞
收藏
评论区
推荐文章
blmius blmius
2年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
Jacquelyn38 Jacquelyn38
2年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
Wesley13 Wesley13
2年前
PPDB:今晚老齐直播
【今晚老齐直播】今晚(本周三晚)20:0021:00小白开始“用”飞桨(https://www.oschina.net/action/visit/ad?id1185)由PPDE(飞桨(https://www.oschina.net/action/visit/ad?id1185)开发者专家计划)成员老齐,为深度学习小白指点迷津。
Wesley13 Wesley13
2年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
2年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
2年前
Android蓝牙连接汽车OBD设备
//设备连接public class BluetoothConnect implements Runnable {    private static final UUID CONNECT_UUID  UUID.fromString("0000110100001000800000805F9B34FB");
Stella981 Stella981
2年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
2年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
3个月前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这