GithubHelp home page GithubHelp logo

yongyehuang / tensorflow-tutorial Goto Github PK

View Code? Open in Web Editor NEW
902.0 52.0 372.0 477.84 MB

Some interesting TensorFlow tutorials for beginners.

Jupyter Notebook 98.26% Python 1.74% Shell 0.01%
tensorflow tensorflow-tutorials lstm cnn

tensorflow-tutorial's Introduction

Tensorflow-Tutorial

2018-04 更新说明

时间过去一年,TensorFlow 已经从 1.0 版本更新到了 1.8 版本,而且最近更新的非常频繁。最烦的就是每次更新很多 API 都改了,一些老版本的代码就跑不通了。因为本项目关注的人越来越多了,所以自己也感觉到非常有必要更新并更正一些之前的错误,否则误人子弟就不好了。这里不少内容可以直接在官方的教程中找到,官方文档也在不断完善中,我也是把里边的例子跑一下,加深理解而已,更多的还是要自己在具体任务中去搭模型,训模型才能很好地掌握。

这一次更新主要内容如下:

  • 使用较新版本的 tfmaster
  • 所有的代码改成 python3.5
  • 重新整理了基础用例
  • 添加实战例子

因为工作和学习比较忙,所以这些内容也没办法一下子完成。和之前的版本不同,之前我是作为一个入门菜鸟一遍学一边做笔记。虽然现在依然还是理解得不够,但是比之前掌握的知识应该多了不少,希望能够整理成一个更好的教程。

之前的代码我放在了另外一个分支上: https://github.com/yongyehuang/Tensorflow-Tutorial/tree/1.2.1

如果有什么问题或者建议,欢迎开issue或者邮件与我联系:[email protected]

运行环境

  • python 3.5
  • tensorflow master (gpu version)

文件结构

|- Tensorflow-Tutorial
|  |- example-notebook     # 入门教程 notebook 版
|  |- example-python      # 入门教程 .py 版
|  |- utils                # 一些工具函数(logging, tf.flags)
|  |- models               # 一些实战的例子(BN, GAN, 序列标注,seq2seq 等,持续更新)
|  |- data               # 数据
|  |- doc             # 相关文档

1.入门例子

T_01.TensorFlow 的基本用法

介绍 TensorFlow 的变量、常量和基本操作,最后介绍了一个非常简单的回归拟合例子。

T_02.实现一个两层的全连接网络对 MNIST 进行分类

T_03.TensorFlow 变量命名管理机制

  • notebook1 介绍 tf.Variable() 和 tf.get_variable() 创建变量的区别;介绍如何使用 tf.name_scope() 和 tf.variable_scope() 管理命名空间。

  • notebook2 除了使用变量命名来管理变量之外,还经常用到 collection 的方式来聚合一些变量或者操作。

T_04.实现一个两层的卷积神经网络(CNN)对 MNIST 进行分类

构建一个非常简单的 CNN 网络,同时输出中间各个核的可视化来理解 CNN 的原理。

第一层卷积核可视化

在上一个例子的基础上,加入 BN 层。在 CNN 中,使用 BN 层可以加速收敛速度,同时也能够减小初始化方式的影响。在使用 BN 层的时候要注意训练时用的是 mini-batch 的均值方差,测试时用的是指数平均的均值方差。所以在训练的过程中,一定要记得更新并保存均值方差。

在这个小网络中:迭代 10000 步,batch_size=100,大概耗时 45s;添加了 BN 层之后,迭代同样的次数,大概耗时 90s.

T_05.实现多层的 LSTM 和 GRU 网络对 MNIST 进行分类

字符 8

lstm 对字符 8 的识别过程

T_06.tensorboard 的简单用法

简单的 tensorboard 可视化

T_07.使用 tf.train.Saver() 来保存模型

T_08.【迁移学习】往一个已经保存好的 模型添加新的变量

T_09.使用 tfrecord 打包不定长的序列数据

T_10.使用 tf.data.Dataset 和 tfrecord 给 numpy 数据构建数据集

下面是对 MNIST 数据训练集 55000 个样本 读取的一个速度比较,统一 batch_size=128,主要比较 one-shotinitializable 两种迭代方式:

iter_mode buffer_size 100 batch(s)
one-shot 2000 125
one-shot 5000 149
initializable 2000 0.7
initializable 5000 0.7

可以看到,使用 initializable 方式的速度明显要快很多。因为使用 one-shot 方式会把整个矩阵放在图中,计算非常非常慢。

T_11.使用 tf.data.Dataset 和 tfrecord 给 图片数据 构建数据集

对于 png 数据的读取,我尝试了 3 组不同的方式: one-shot 方式, tf 的队列方式(queue), tfrecord 方式. 同样是在机械硬盘上操作, 结果是 tfrecord 方式明显要快一些。(batch_size=128,图片大小为256*256,机械硬盘)

iter_mode buffer_size 100 batch(s)
one-shot 2000 75
one-shot 5000 86
tf.queue 2000 11
tf.queue 5000 11
tfrecord 2000 5.3
tfrecord 5000 5.3

如果是在 SSD 上面的话,tf 的队列方式应该也是比较快的.打包成 tfrecord 格式只是减少了小文件的读取,其实现也是使用队列的。

T_12.TensorFlow 高级API tf.layers 的使用

使用 TensorFlow 原生的 API 能够帮助自己很好的理解网络的细节,但是往往比较低效。 tf.layers 和 tf.keras 一样,是一个封装得比较好的一个高级库,接口用着挺方便的。所以在开发的时候,可以使用高级的接口能够有效的提高工作效率。

2.TensorFlow 实战(持续更新)

下面的每个例子都是相互独立的,每个文件夹下面的代码都是可以单独运行的,不依赖于其他文件夹。

参考:tensorflow中batch normalization的用法

参考:

这里的 notebook 和 .py 文件的内容是一样的。本例子和下面的 GAN 模型用的数据集也是用了GAN学习指南:从原理入门到制作生成Demo 的二次元头像,感觉这里例子比较有意思。如果想使用其他数据集的话,只需要把数据集换一下就行了。

下载链接: https://pan.baidu.com/s/1HBJpfkIFaGh0s2nfNXJsrA 密码: x39r

下载后把所有的图片解压到一个文件夹中,比如本例中是: data_path = '../../data/anime/'

运行: python dcgan.py

这里的生成器和判别器我只实现了 DCGAN,没有实现 MLP. 如果想实现的话可以参考下面的两个例子。 参考:

原版的 wgan: python wgan.py

改进的 wgan-gp: python wgan_gp.py

代码来自:affinelayer/pix2pix-tensorflow

tensorflow-tutorial's People

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

tensorflow-tutorial's Issues

mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 导入数据集mnist报错

http://blog.csdn.net/jerr__y/article/details/57084077
导入数据集的时候报错了: urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:749)>
原因是:from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 源代码里面默认的source_url是:DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' 谷歌无法访问:
换成:mnist = input_data.read_data_sets('MNIST_data', one_hot=True, source_url='http://yann.lecun.com/exdb/mnist/')

[问题]运行至第十五部分报错了,怎么解决?

我是导出成py文件用ipython这个命令执行的文件

acc, _cost, _ = sess.run(fetches, feed_dict)

`---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/Users/finup/gitlab/poll-parrot/demo/a.py in ()
511 X_batch, y_batch = data_train.next_batch(tr_batch_size)
512 feed_dict = {X_inputs:X_batch, y_inputs:y_batch, lr:_lr, batch_size:tr_batch_size, keep_prob:0.5}
--> 513 _acc, _cost, _ = sess.run(fetches, feed_dict) # the cost is the mean cost of one batch
514 _accs += _acc
515 _costs += _cost

/Users/finup/gitlab/poll-parrot/env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
887 try:
888 result = self._run(None, fetches, feed_dict, options_ptr,
--> 889 run_metadata_ptr)
890 if run_metadata:
891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/Users/finup/gitlab/poll-parrot/env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
1094 'Cannot feed value of shape %r for Tensor %r, '
1095 'which has shape %r'
-> 1096 % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
1097 if not self.graph.is_feedable(subfeed_t):
1098 raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (128,) for Tensor u'Inputs/X_input:0', which has shape '(?, 32)`

Bi-directional LSTM Model :in[2] def的函数是什么意思?

def clean(s):
if u'“/s' not in s: # 句子中间的引号不应去掉
return s.replace(u' ”/s', '')
elif u'”/s' not in s:
return s.replace(u'“/s ', '')
elif u'‘/s' not in s:
return s.replace(u' ’/s', '')
elif u'’/s' not in s:
return s.replace(u'‘/s ', '')
else:
return s

仔细琢磨了半天,这个函数是不是没有意义啊?是写错了,还是我脑袋烧了。

OutOfRange Error

OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 100, current size 0)
[[Node: shuffle_batch = QueueDequeueManyV2[component_types=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]

when i run the tfrecord-3-image-reader.py,the error occurred.How to solve it?

Received a label value of -2147483648 which is outside the valid range of [0, 5).

Bi-directional lstm中文分词里,报错tensorflow.python.framework.errors_impl.InvalidArgumentError: Received a label value of -2147483648 which is outside the valid range of [0, 5). Label values: -2147483648 -2147483648 2 3 -2147483648 0 0 0 0 0 0 0 0 0 0 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 2 3 -2147483648 0 0 0 0 0 0 0 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 2 3 -2147483648 0 0 0 0 0 0 0 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 2 3 -2147483648 0 0 0 0 0 -2147483648 -2147483648 -2147483648 -2147483648 ...等等等].我用的是自己的数据集,处理的跟样例数据集一样的形式(今/B 天/M是/M个/M好/E3天/E2气/E),结果报这个错,请问是否是我的数据集中的句子长度过长?该如何解决?

关于LSTM的一点问题

你好,请教下两个问题:
1.hidden_size = 256是不是就是输出状态向量的长度?
2.layer_num = 2这个是什么意思,这个和理解lstm中的哪个概念是对应的?

谢谢 :)

Bi-directional LSTM 这个demo中,没有看明白字向量怎么弄的

如题,看代码好像没有使用到字向量
self.embedding = tf.get_variable("embedding", [vocab_size, self.embedding_size], dtype=tf.float32)
这一步只是是不是可以理解为创建了一个空的字向量?
还有如果字不在集合中的话,查找字直接报错,请问这块如何优化

没有找到wd_1_1_cnn_concat和wd_1_2_cnn_max

你好,当我实现你的模型联合时发现无法导入如下两个文件:
import wd_1_1_cnn_concat.network as network1
import wd_1_2_cnn_max.network as network2
你能告诉我这两个文件在哪地方吗,非常感谢。

关于RNNcell内部的variable sharing

Notebook “Tutorial_05 - An understandable example to implement Multi-LSTM for MNIST”有这样一段代码。

with tf.variable_scope('RNN'):
    for timestep in range(timestep_size):
        if timestep > 0:
            tf.get_variable_scope().reuse_variables()

这个issue是关于tf.get_variable_scope().reuse_variables()合理性的猜测。希望同博主一起讨论。

首先,我发现tensorflow新旧版本在定义RNNcell的__call__方法时有不同的处理。旧版本直接定义__call__方法,新版本则要先继承_LayerRNNCell再定义callbuild 方法(而非直接定义__call__)。

为何这么处理?个人认为,使用RNNcell分为两个步骤:第一,实例化一个RNNcell;第二,调用声明的RNNcell实例进行计算。定义__call__方法就是为了简化用RNNcell的实例进行运算时的API调用。另外,大部分关于variable sharing的考虑和决策都发生在第一步。

但是,在声明RNNcell时,我们只指定了num_units。而将inputs转换为state的运算,涉及到根据input的shape来声明一组tf Variable。根据1.4.0版本的implementation,这组tf Variable的声明并没写在__init__方法中。个人猜测,在第一次使用RNNcell的某个实例进行计算时,先调用该实例的build方法,根据input的shape声明所需的tf Variables,然后再调用该实例的call方法进行计算。而build方法似乎只执行一次。

那么就产生一个问题。假设我们有两个不同shape的inputs,分别传递给同一个RNNcell的实例做计算,会发生什么?当然这是题外话。

关于是否应该使用tf.get_variable_scope().reuse_variables(),个人认为至少在1.4.0中不必。因为代码的for loop中,我们是重复"call"一个已经声明的mlstm_cell,而不是每次循环都声明一个mlstm_cell。另外,将input转换为state所需的tf Variables在第一次call mlstm_cell时得到定义,后续的call应该会自动重复使用这组tf Variables,即build方法只执行一次。

bi_lstm执行报错

y_pred = bi_lstm(X_inputs) 运行这一步时,将X_inputs当成普通参数运行了,导致出错,这该如何解决?

ValueError Traceback (most recent call last)
in ()
121
122
--> 123 y_pred = bi_lstm(X_inputs)
124 # adding extra statistics to monitor
125 # y_inputs.shape = [batch_size, timestep_size]

in bi_lstm(X_inputs)
49 """build the bi-LSTMs network. Return the y_pred"""
50 # ** 0.char embedding
---> 51 embedding = tf.get_variable("embedding", [vocab_size, embedding_size], dtype=tf.float32)
52 # X_inputs.shape = [batchsize, timestep_size] -> inputs.shape = [batchsize, timestep_size, embedding_size]
53 inputs = tf.nn.embedding_lookup(embedding, X_inputs)

d:\set up\python352\lib\site-packages\tensorflow\python\ops\variable_scope.py in get_variable(name, shape, dtype, initializer, regularizer, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter)
1063 collections=collections, caching_device=caching_device,
1064 partitioner=partitioner, validate_shape=validate_shape,
-> 1065 use_resource=use_resource, custom_getter=custom_getter)
1066 get_variable_or_local_docstring = (
1067 """%s

d:\set up\python352\lib\site-packages\tensorflow\python\ops\variable_scope.py in get_variable(self, var_store, name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter)
960 collections=collections, caching_device=caching_device,
961 partitioner=partitioner, validate_shape=validate_shape,
--> 962 use_resource=use_resource, custom_getter=custom_getter)
963
964 def _get_partitioned_variable(self,

d:\set up\python352\lib\site-packages\tensorflow\python\ops\variable_scope.py in get_variable(self, name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter)
365 reuse=reuse, trainable=trainable, collections=collections,
366 caching_device=caching_device, partitioner=partitioner,
--> 367 validate_shape=validate_shape, use_resource=use_resource)
368
369 def _get_partitioned_variable(

d:\set up\python352\lib\site-packages\tensorflow\python\ops\variable_scope.py in _true_getter(name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource)
350 trainable=trainable, collections=collections,
351 caching_device=caching_device, validate_shape=validate_shape,
--> 352 use_resource=use_resource)
353
354 if custom_getter is not None:

d:\set up\python352\lib\site-packages\tensorflow\python\ops\variable_scope.py in _get_single_variable(self, name, shape, dtype, initializer, regularizer, partition_info, reuse, trainable, collections, caching_device, validate_shape, use_resource)
662 " Did you mean to set reuse=True in VarScope? "
663 "Originally defined at:\n\n%s" % (
--> 664 name, "".join(traceback.format_list(tb))))
665 found_var = self._vars[name]
666 if not shape.is_compatible_with(found_var.get_shape()):

ValueError: Variable embedding already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

File "d:\set up\python352\lib\site-packages\tensorflow\python\framework\ops.py", line 1269, in init
self._traceback = _extract_stack()
File "d:\set up\python352\lib\site-packages\tensorflow\python\framework\ops.py", line 2506, in create_op
original_op=self._default_original_op, op_def=op_def)
File "d:\set up\python352\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 767, in apply_op
op_def=op_def)

Multi—LSTM在Android端准确率下降很多

我正在做一个文本分类的项目,用到的是Multi-LSTM,在PC端的准确率大概在88%,而在android端的准确率则只有30%,用CNN的话,PC端和Android端的准确率都有99%,相差不大,想知道您是否遇到过类似的情况?

你好

pandas您用的那个版本?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.