TensorFlow学习笔记(10):读取文件

逻辑寻云使
• 阅读 13543

简介

TensorFlow读取数据共有三种方法:

  • Feeding:当TensorFlow运行每步计算的时候,从Python获取数据。在Graph的设计阶段,用placeholder占住Graph的位置,完成Graph的表达;当Graph传给Session后,在运算时再把需要的数据从Python传过来。

  • Preloaded data:数据直接预加载到TensorFlow的Graph中,再把Graph传入Session运行。只适用于小数据。

  • Reading from file:在Graph中定义好文件读取的运算节点,把Graph传入Session运行时,执行读取文件的运算,这样可以避免在Python和TensorFlow C++执行环境之间反复传递数据。

本文讲解Reading from file的代码。

其他关于TensorFlow的学习笔记,请点击入门教程

实现

#!/usr/bin/env python
# -*- coding=utf-8 -*-
# @author: 陈水平
# @date: 2017-02-19
# @description: modified program to illustrate reading from file based on TF offitial tutorial
# @ref: https://www.tensorflow.org/programmers_guide/reading_data

def read_my_file_format(filename_queue):
  """从文件名队列读取一行数据
  
  输入:
  -----
  filename_queue:文件名队列,举个例子,可以使用`tf.train.string_input_producer(["file0.csv", "file1.csv"])`方法创建一个包含两个CSV文件的队列
  
  输出:
  -----
  一个样本:`[features, label]`
  """
  reader = tf.SomeReader()  # 创建Reader
  key, record_string = reader.read(filename_queue)  # 读取一行记录
  example, label = tf.some_decoder(record_string)  # 解析该行记录
  processed_example = some_processing(example)  # 对特征进行预处理
  return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
  """ 从一组文件中读取一个批次数据
  
  输入:
  -----
  filenames:文件名列表,如`["file0.csv", "file1.csv"]`
  batch_size:每次读取的样本数
  num_epochs:每个文件的读取次数
  
  输出:
  -----
  一批样本,`[[example1, label1], [example2, label2], ...]`
  """
  filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)  # 创建文件名队列
  example, label = read_my_file_format(filename_queue)  # 读取一个样本
  # 将样本放进样本队列,每次输出一个批次样本
  #   - min_after_dequeue:定义输出样本后的队列最小样本数,越大随机性越强,但start up时间和内存占用越多
  #   - capacity:队列大小,必须比min_after_dequeue大
  min_after_dequeue = 10000
  capacity = min_after_dqueue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
    [example, label], batch_size=batch_size, capacity=capacity,
    min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch
  
def main(_):
  x, y = input_pipeline(['file0.csv', 'file1.csv'], 1000, 5)
  train_op = some_func(x, y)
  init_op = tf.global_variables_initializer()
  local_init_op = tf.local_variables_initializer()  # local variables like epoch_num, batch_size
  sess = tf.Session()
  
  sess.run(init_op)
  sess.run(local_init_op)
  
  # `QueueRunner`用于创建一系列线程,反复地执行`enqueue` op
  # `Coordinator`用于让这些线程一起结束
  # 典型应用场景:
  #   - 多线程准备样本数据,执行enqueue将样本放进一个队列
  #   - 一个训练线程从队列执行dequeu获取一批样本,执行training op
  # `tf.train`的许多函数会在graph中添加`QueueRunner`对象,如`tf.train.string_input_producer`
  # 在执行training op之前,需要保证Queue里有数据,因此需要先执行`start_queue_runners`
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  
  try:
    while not coord.should_stop():
      sess.run(train_op)
  except tf.errors.OutOfRangeError:
    print 'Done training -- epoch limit reached'
  finally:
    coord.request_stop()
  
  # Wait for threads to finish  
  coord.join(threads)
  sess.close()
  
if __name__ == '__main__':
  tf.app.run()
点赞
收藏
评论区
推荐文章
blmius blmius
4年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
美凌格栋栋酱 美凌格栋栋酱
7个月前
Oracle 分组与拼接字符串同时使用
SELECTT.,ROWNUMIDFROM(SELECTT.EMPLID,T.NAME,T.BU,T.REALDEPART,T.FORMATDATE,SUM(T.S0)S0,MAX(UPDATETIME)CREATETIME,LISTAGG(TOCHAR(
Jacquelyn38 Jacquelyn38
4年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Dive into TensorFlow系列(3)- 揭开Tensor的神秘面纱
TensorFlow计算图是由op和tensor组成,那么tensor一般都用来代表什么呢?显然,像模型的输入数据、网络权重、输入数据经op处理后的输出结果都需要用张量或特殊张量进行表达。既然tensor在TensorFlow体系架构中如此重要,因此本
Easter79 Easter79
3年前
Tensorflow2.0
Tensorflow2.01.Tensorflow简介1.Tensorflow是什么1.Google开源软件库1.采用数据流图,用于数值计算2.支
Easter79 Easter79
3年前
Tensorflow开篇:环境安装2—tensorflow1.10.0
前提已经安装pyhon环境,具体可以参考Tensorflow开篇环境安装(python3.5.2)(https://my.oschina.net/zhys513/blog/edit/863579)tensorflow官网:http://tensorflow.google.cn/(https://www.oschina.net/action/Go
Stella981 Stella981
3年前
Graphviz图片显示中文乱码问题
1\.报错详情¶现象:graph.view()展示的图形显示中文为乱码。In \40\:fromsklearnimporttreefromsklearn.datasetsimportload_winefromsklearn.model_selectionimporttrain
Easter79 Easter79
3年前
Tensorflow2.0全网最新教程来啦
Tensorflow2.0来啦,废话不多说,直接介绍Tensorflow2.0介绍:tensorflow是GOOGLE在2015年底发布的一款深度学习框架,也是目前全世界用得最多,发展最好的深度学习框架。2019年3月8日,GOOGLE发布最新tensorflow2版本。新版本的tensorflow有很多新特征,更快
Easter79 Easter79
3年前
TensorFlow从1到2(十四)评估器的使用和泰坦尼克号乘客分析
!(http://files.17study.com.cn/201904/tensorFlow2/tflogocard2.png)三种开发模式使用TensorFlow2.0完成机器学习一般有三种方式:使用底层逻辑这种方式使用Python函数自定义学习模型,把数学公式转化为可执行的程序逻辑。接着在训练循环
Easter79 Easter79
3年前
Tensorflow.cifar_数据下载过程(数据输出)
1、环境:Win7x64、python3.7x64、tensorflow1.14、CPUi59400F2、3、 3.1、cifar10,没有数据,全新下载,下到默认目录(C:\\Users\\Administrator\\tensorflow\_datasets),全过程控制台输出:(20190903)"C:\ProgramF
从源代码构建TensorFlow流程记录
通常情况下,直接安装构建好的.whl即可。不过,当需要一些特殊配置(或者闲来无事想体会TensorFlow构建过程到底有多麻烦)的时候,则需要选择从源代码构建TensorFlow。万幸文档混乱的TensorFlow还是好心地为我们提供了一整页的文档供参考https://www.tensorflow.org/install/source?hlzhcn,个人认为其中最需要关注的部分莫过于经过测试供参考的源配置(列于文末)。
逻辑寻云使
逻辑寻云使
Lv1
忽闻歌古调,归思欲沾巾。
文章
4
粉丝
0
获赞
0