
dssm's Introduction

dssm

A BiGRU-Attention DSSM implementation with tensorflow estimator.
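The "BiGRU-Attention" representation layer turns the per-timestep outputs of a bidirectional GRU into one sentence vector. Below is a minimal plain-Python sketch of one common way to do that pooling: additive attention with score_t = v · tanh(W h_t), a softmax over timesteps, and a weighted sum of the h_t. The README does not state the exact attention formula used here, so this form, the function names and the toy sizes are assumptions. The real model's final output is 64-dimensional.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply matrix m (rows x cols) by vector x (length cols)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def softmax(xs: Vector) -> Vector:
    """Numerically stable softmax."""
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(hidden: List[Vector], W: Matrix, v: Vector) -> Vector:
    """Pool per-timestep BiGRU states h_t into a single vector.

    Assumed additive attention: score_t = v . tanh(W h_t),
    alpha = softmax(scores), output = sum_t alpha_t * h_t.
    """
    scores = []
    for h in hidden:
        u = [math.tanh(z) for z in matvec(W, h)]
        scores.append(sum(vi * ui for vi, ui in zip(v, u)))
    alpha = softmax(scores)
    dim = len(hidden[0])
    return [sum(a * h[i] for a, h in zip(alpha, hidden)) for i in range(dim)]


# Toy example: 3 timesteps, hidden size 4 (forward and backward states concatenated).
hidden_states = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.1, 0.0, 0.2], [0.3, 0.3, 0.3, 0.3]]
W = [[0.5, 0.0, 0.0, 0.0], [0.0, 0.5, 0.0, 0.0]]  # hypothetical 2 x 4 projection
v = [1.0, -1.0]  # hypothetical attention context vector
pooled = attention_pool(hidden_states, W, v)
print(pooled)
```

When W is all zeros every score is equal, so the pooling reduces to a plain mean of the hidden states. Learning W and v is what lets the model weight informative characters more heavily.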

Companion blog post: https://blog.csdn.net/cdj0311/article/details/107634795

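I previously implemented the DSSM text-representation model with Keras and PaddlePaddle (https://github.com/cdj0311/keras_bert_classification/blob/master/bert_dssm.py, https://github.com/cdj0311/paddledssm). Because distributed training is cumbersome in Keras, and PaddlePaddle was dropped long ago, this version is a rewrite using TensorFlow's high-level API tf.estimator. The representation layer is a bidirectional GRU with attention, and the final output is a 64-dimensional vector.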
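The DSSM model mentioned above is trained to score a query's matching doc above non-matching ones. In the original DSSM formulation, relevance is the cosine similarity of the two vectors, a softmax with a smoothing factor gamma is taken over one positive and several negative docs, and the loss is the negative log-probability of the positive doc. The sketch below implements that standard objective in plain Python. The README does not say whether this repository uses the same loss, how it samples negatives, or which gamma it uses, so all three are assumptions, as are the function names.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def dssm_loss(query: Vector, pos_doc: Vector, neg_docs: List[Vector],
              gamma: float = 20.0) -> float:
    """-log P(pos_doc | query), where P = softmax(gamma * cosine) over pos + negatives.

    gamma=20.0 is an illustrative value, not taken from this repository.
    """
    logits = [gamma * cosine(query, d) for d in [pos_doc] + neg_docs]
    mx = max(logits)
    log_sum_exp = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return log_sum_exp - logits[0]


q = [1.0, 0.0]
pos = [0.9, 0.1]
negs = [[0.0, 1.0], [-1.0, 0.0]]
loss = dssm_loss(q, pos, negs)
print(loss)
```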

python == 3.6

tensorflow == 1.13.1

Training steps:

  1. Convert the text data to TFRecord format:

    python convert_data.py

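    data/data.txt contains 10,000 training examples: headlines and their corresponding article content from a news website, one pair per line in the format title\tcontent. train.tfrecord is the converted TFRecord data.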

  2. Train the model:

    sh train_local.sh

    After training, separate pb-format models are exported for the query side and the doc side; use whichever one you need.

  3. Prediction:

    python predict.py

    Given a sentence, the script computes its vector and retrieves the N most similar sentences, for example:

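    Input: 赵丽颖冯绍峰在拍女儿国的时候真的超级甜了 (roughly: "Zhao Liying and Feng Shaofeng were really sweet together while filming Kingdom of Women")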

    Output:

       0.801103	女神赵丽颖李沁都爱穿黄毛衣,但差距真的蛮大的
       0.744942	街拍:喜欢第二位俏皮可爱的小姐姐,和她在一起不会觉得无聊!
       0.722599	杜江霍思燕夫妇甜蜜现身 牵手依偎恩爱甜到发腻
       0.719018	还在情侣穿搭烦恼,看街拍情侣都是怎么搭配的
       0.707306	赵丽颖,应是绿肥红瘦,剧照
       0.701783	她的闺蜜则穿了一件白色的蕾丝连衣裙,尽显女人味
       0.70024	国民妖精十元女神可爱撩人瞬间合集!出色的不只是时尚穿搭
       0.691073	图集:#杨幂#赵丽颖暗斗时尚穿同款婚纱谁更美
       0.687201	赵丽颖 路人抓拍下的颖宝,这颜值可以说是完美的纯天然美女了~
    

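    Input: 祝考研的女士们先生们都顺利考进自己理想的学校 (roughly: "Wishing all the ladies and gentlemen taking the postgraduate entrance exam success in getting into their dream schools")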

    Output:

     0.890815	祝考研的女士们先生们都顺利考进自己理想的学校!实在考不上就滚tm的,当代...
     0.758741	硕士研究生招生考试22日开考
     0.701588	加油高考!祝你们顺利考上心仪的大学!
     0.660756	中考,你准备好了吗?
     0.654576	这些考研复试面试小技巧收好,导师的心就抓住了!
     0.63505	高考生作弊被抓飞踹监考老师:你知道我爸是谁?
     0.626651	高考倒计时30天,祝所有今年参加高考的小伙伴们心想事成,高考必胜
     0.590912	各位同学请注意,第一季期末考试现在开始~请认真阅读仔细答题
     0.585147	航班延误艺考生妈妈痛哭 浙传:可提供证明安排考试
     0.575564	当女儿带男同学回家写作业的时候,爸爸都在想什么
    
  4. Distributed training

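    Set run_on_cluster=True and submit the job to your cluster to train. Because every company uses a different command to submit distributed training jobs, it is not included here.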
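Step 3 above embeds the input sentence and returns the N most similar sentences with their scores. Assuming the score is the cosine similarity between sentence vectors (the README does not name the metric), the lookup can be sketched in plain Python with heapq.nlargest. top_n_similar, the corpus dictionary and the 3-dimensional toy vectors are hypothetical stand-ins for predict.py and the 64-dimensional vectors produced by the exported model.

```python
import heapq
import math
from typing import Dict, List, Tuple

Vector = List[float]


def normalize(v: Vector) -> Vector:
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def top_n_similar(query_vec: Vector, corpus: Dict[str, Vector],
                  n: int) -> List[Tuple[float, str]]:
    """Return the n (score, sentence) pairs with the highest cosine similarity."""
    q = normalize(query_vec)
    scored = (
        (sum(a * b for a, b in zip(q, normalize(vec))), sentence)
        for sentence, vec in corpus.items()
    )
    return heapq.nlargest(n, scored)


# Hypothetical precomputed sentence vectors.
corpus = {
    "sentence A": [1.0, 0.0, 0.0],
    "sentence B": [0.7, 0.7, 0.0],
    "sentence C": [0.0, 0.0, 1.0],
}
for score, sentence in top_n_similar([1.0, 0.1, 0.0], corpus, n=2):
    print(f"{score:.6f}\t{sentence}")
```

A linear scan like this is fine for a corpus the size of data.txt; for much larger collections the usual approach is a single matrix product over all vectors or an approximate nearest-neighbour index.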
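For step 4, note that in TensorFlow 1.x an Estimator run through tf.estimator.train_and_evaluate learns its role in a distributed job from the TF_CONFIG environment variable: a JSON object holding a cluster spec (chief, worker and ps addresses) and the current process's task type and index. The sketch below only builds that JSON in plain Python. Whether run_on_cluster=True relies on TF_CONFIG or on some other mechanism is not stated in this README, and the host names, ports and cluster size are made up.

```python
import json
import os
from typing import Dict, List


def make_tf_config(cluster: Dict[str, List[str]], task_type: str, task_index: int) -> str:
    """Serialize a TF_CONFIG value: the cluster spec plus this process's task."""
    # An evaluator task is not listed in the cluster spec, so allow it explicitly.
    if task_type not in cluster and task_type != "evaluator":
        raise ValueError(f"unknown task type: {task_type}")
    return json.dumps({
        "cluster": cluster,
        "task": {"type": task_type, "index": task_index},
    })


# Hypothetical cluster: one chief, two workers, one parameter server.
cluster = {
    "chief": ["host0:2222"],
    "worker": ["host1:2222", "host2:2222"],
    "ps": ["host3:2222"],
}

# Each process sets its own TF_CONFIG before the Estimator is constructed.
os.environ["TF_CONFIG"] = make_tf_config(cluster, "worker", 1)
print(os.environ["TF_CONFIG"])
```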

This project builds its embeddings on characters; in practice we usually feed both characters and words as inputs during training.
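Step 1 and the note above together imply a preprocessing stage: split each data.txt line into title and content, then map the text to integer ids, character by character for the current model and, for a char+word variant, word by word as well. This sketch is not the contents of convert_data.py; the helper names, the reserved PAD/UNK ids, the max_len values and the pre-segmented word list are assumptions, and TFRecord serialization is left out so it runs without TensorFlow.

```python
from typing import Dict, List, Tuple

PAD_ID = 0  # hypothetical reserved id for padding
UNK_ID = 1  # hypothetical reserved id for out-of-vocabulary tokens


def parse_line(line: str) -> Tuple[str, str]:
    """Split one data.txt line of the form title<TAB>content."""
    title, content = line.rstrip("\n").split("\t", 1)
    return title, content


def build_vocab(sequences: List[List[str]]) -> Dict[str, int]:
    """Assign ids to tokens in first-seen order, after the reserved ids."""
    vocab: Dict[str, int] = {}
    for seq in sequences:
        for tok in seq:
            if tok not in vocab:
                vocab[tok] = len(vocab) + 2
    return vocab


def to_ids(tokens: List[str], vocab: Dict[str, int], max_len: int) -> List[int]:
    """Map tokens to ids, then truncate or pad with PAD_ID to max_len."""
    ids = [vocab.get(tok, UNK_ID) for tok in tokens[:max_len]]
    return ids + [PAD_ID] * (max_len - len(ids))


lines = ["标题一\t内容一\n", "标题二\t内容二\n"]
pairs = [parse_line(ln) for ln in lines]
texts = [text for pair in pairs for text in pair]

# Character input: every character is a token.
char_vocab = build_vocab([list(t) for t in texts])
title_char_ids = to_ids(list(pairs[0][0]), char_vocab, max_len=5)

# Word input: assumes the text was segmented beforehand (segmenter not shown).
segmented_title = ["标题", "一"]
word_vocab = build_vocab([segmented_title])
title_word_ids = to_ids(segmented_title, word_vocab, max_len=3)

print(title_char_ids, title_word_ids)
```

How the character and word sequences are then embedded and combined (for example concatenated or added) is a modelling choice this README leaves open.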

dssm's People

Contributors

cdj0311

dssm's Issues

Line 44 of the code: tf.Variable should be replaced with tf.get_variable.

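Line 59:
    char_embed = word_embedding(char_input, None, FLAGS.char_vocab_size, FLAGS.char_embedding_size, "char_embedding")
sets reuse=None, and line 71:
    char_embed = word_embedding(char_input, True, FLAGS.char_vocab_size, FLAGS.char_embedding_size, "char_embedding")
sets reuse=True.
These settings do not make the variables created by tf.Variable shared, so the query and doc embedding tables are different. Yet the reuse arguments at lines 59 and 71 show that the author intended the variables to be shared.
The following test case demonstrates that the query and doc embedding tables are not shared: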
    import tensorflow as tf  # TensorFlow 1.x (the README pins 1.13.1)

    with tf.variable_scope('layertest', reuse=None):
        embedding_matrix = tf.Variable(tf.truncated_normal((3, 2)), name="embedding")
        print("1", embedding_matrix.name)  # output: 1 layertest/embedding:0

    with tf.variable_scope('layertest', reuse=True):
        embedding_matrix = tf.Variable(tf.truncated_normal((3, 2)), name="embedding")
        print("2", embedding_matrix.name)  # output: 2 layertest_1/embedding:0
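The point of the issue can be illustrated without TensorFlow. The class below is a toy model, not TensorFlow code, that mimics the TF 1.x behaviour the issue relies on: tf.Variable always creates a new variable and ignores the scope's reuse flag (re-entering the scope just yields a uniquified name such as layertest_1), whereas tf.get_variable registers a name once and, with reuse=True, returns that same variable; without reuse it raises ValueError if the name already exists. ToyVariableStore, its methods and the variable name "table" are hypothetical.

```python
from typing import Dict, List, Set


class ToyVariableStore:
    """Toy model (NOT TensorFlow) of TF 1.x tf.Variable vs tf.get_variable naming."""

    def __init__(self) -> None:
        self.scope_uses: Dict[str, int] = {}  # times each scope has been entered
        self.shared: Set[str] = set()  # names registered via get_variable
        self.created: List[str] = []  # every variable actually created

    def variable(self, scope: str, name: str) -> str:
        """Mimics tf.Variable: always creates a new variable, whatever reuse says.

        Re-entering a scope yields a uniquified prefix: scope, scope_1, scope_2, ...
        """
        count = self.scope_uses.get(scope, 0)
        self.scope_uses[scope] = count + 1
        prefix = scope if count == 0 else f"{scope}_{count}"
        full_name = f"{prefix}/{name}:0"
        self.created.append(full_name)
        return full_name

    def get_variable(self, scope: str, name: str, reuse: bool) -> str:
        """Mimics tf.get_variable: create once, return the same variable when reuse=True."""
        full_name = f"{scope}/{name}:0"
        if reuse:
            if full_name not in self.shared:
                raise ValueError(f"{full_name} does not exist; cannot reuse it")
        else:
            if full_name in self.shared:
                raise ValueError(f"{full_name} already exists; pass reuse=True")
            self.shared.add(full_name)
            self.created.append(full_name)
        return full_name


toy = ToyVariableStore()
# Mirrors the issue's test: two tf.Variable calls give two different variables.
print(toy.variable("layertest", "embedding"))  # layertest/embedding:0
print(toy.variable("layertest", "embedding"))  # layertest_1/embedding:0

# With get_variable, the query and doc sides resolve to one shared table.
query_table = toy.get_variable("char_embedding", "table", reuse=False)
doc_table = toy.get_variable("char_embedding", "table", reuse=True)
print(query_table == doc_table)  # True
```

In the real code, following the issue's suggestion would presumably mean creating the embedding table with tf.get_variable inside a variable scope that honours the reuse argument passed at lines 59 and 71.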
