GithubHelp home page GithubHelp logo

yuki-lin / tag-suggestion Goto Github PK

View Code? Open in Web Editor NEW

This project forked from hit-computer/tag-suggestion

0.0 1.0 0.0 279 KB

Implement CNN-GRU hierarchical model in Python/Theano for text tag suggestion

Python 100.00%

tag-suggestion's Introduction

Tag-Suggestion

这部分实验模型代码参考了hed-dlg-truncated。不过他们的代码非常复杂,实现的功能也很多,我最终所借鉴的部分其实并不多。由于刚开始学习Theano时就接触到这份代码,当时也是一边参考theano官方文档一边研究这份代码的,因而最终自己的模型代码中借鉴一些他们已有的功能模块。

========================================================

模型

我所做的任务是社会媒体上的文本标签注释(简单的说就是给文本标注语义标签),比如知乎,对于每个问题用户都会给出1~5个话题标签,而我希望通过构建一个模型能够自动根据用户的问题描述给出话题标签。知乎上大部分问题都会有问题的详细描述,我们把它视为是一个文本,由于是对文本进行建模,所以采用了层次化的深度神经网络(即先利用CNN模型从词学习得到句子的向量表示,再通过GRU模型从句子学习得到文本表示)。具体做法是,输入词向量利用CNN模型学习得到一个句子向量,然后再将句子向量输入到GRU中,将GRU的最后一个状态向量作为文本向量表示,最后接入一个sigmoid层(相当于构建多个sigmoid分类器),预测每个标签词的置信度。模型结构如下图所示,

model

事实上,除了上述用CNN学习句子向量表示以外,我们还尝试了用RNN(GRU)、char-rnn(可以避免分词带来的错误以及解决未登录词的问题)学习句子表示,以及NAACL2016年的一篇论文中的模型Hierarchical Attention

实验

我们总共爬取了50W知乎数据,随机抽取了10万数据作为训练语料,5千作为测试语料,采用了两种评价指标PRF和R-Precision。目前我们评价的是模型预测生成的Top-3的结果,以下是实验结果,

模型 P R F R-Precision
CNN+GRU(this code) 0.2828 0.3247 0.2828 0.2998
GRU+GRU 0.2762 0.3173 0.2766 0.2957
char-rnn+GRU 0.2763 0.3190 0.2772 0.2976
Hierarchical Attention 0.2732 0.3142 0.2732 0.2880

实验结论:

    1. 实验结果中F并不高,这是该任务普遍存在的问题,主要有以下两个原因:(数据方面)由于数据来源于网络,所有数据难免会有噪声;(评价方法方面)对于标签,可能存在多种合理的情况,既模型给出的推荐是合理但和原始数据的标签不一致,这也会PRF平价值很低
    1. CNN+GRU模型的效果最好,这说明CNN在文本标签注释任务上能更好的学到句子表示,分析原因可能是CNN能够更好的捕获局部信息,而这些信息对于模型给出正确的语义标签是有帮助的
    1. Char-rnn+GRU模型和GRU+GRU模型差不多,没有显著提高的原因可能是任务是标签注释,即对文本的语义理解粒度并不要求非常细,所以即使存在一些未登录词以及分词错误,对这个层次(或者说这个粒度)的语义理解影响不大。但Char-rnn的优势还是有的,比如在一个无法提供分词工具的环境下,基于词的RNN就没法做了。不过Char-rnn也有劣势,就是训练时间更长,因为循环次数更多了

豆瓣数据集上的实验

为了更好的比较模型的性能,我们与之前做标签推荐任务的模型进行比较,主要选了TAM(X Si, 2010)和WTM(Z Liu, 2011)这两个模型。TAM和WTM两个模型的作者都是豆瓣语料数据上做的实验,所以为了和这两个模型进行比较,我也同样在豆瓣数据上进行了相应的实验。我们用49050个豆瓣语料作为训练数据,用12132规模的语料作为测试数据。Top-3实验结果如下

模型PRF
包含标题TAM0.30760.34050.2809
WTM0.36830.45130.3541
不包含标题TAM0.29710.32300.2676
WTM0.34980.41820.3311
GRU+GRU0.36800.40520.3337
CNN+GRU0.38350.42130.3480
Hierarchical Attention0.37660.41130.3406

由于我们的模型没有对标题进行建模,所以只在不包含标题的豆瓣数据进行了实验。从实验结果来看,基于深度神经网络的三个模型均优于WTM,并且远远优于TAM,证明了基于深度神经网络的模型能够更好的学习到文本语义信息。

数据预处理

Data文件夹下面有两个py文件处理数据。首先需要将数据处理成以下格式:文本需要分词(用空格隔开),每个句子用‘</ss>’分割,文本结尾用‘</d>’分割,最后是文本标签(用空格隔开),例如

</ss> 众所周知 , 由于 ** 在 相当 长 的 时间 里 掌握 着 造丝 技术 与 造瓷 技术 , 因此 ** 古代 的 丝绸 与 瓷 在 国际 上 有 很 大 的 影响力 。 </ss> 而 近代 以后 国外 也 拥有 了 这样 的 技术 , ** 肯定 不 是 一 家 独 大 了 , 我 想 问 的 是 , 从 实用性 与 艺术性 的 角度 上 分开 谈谈 ** 丝绸 与 瓷器 现在 在 国际 上 的 地位 。 </ss> </d> 艺术 瓷器 </ss>

运行category.py文件对原始数据类别标签进行过滤,根据设定的类别数目(程序中的NofC值)筛选出重要的NofC个类别标签。由于筛选策略中会用到词向量,所以需要预先训练词向量(替换程序中的vec100.txt文件)

preprocess.py文件将训练数据和测试数据处理成符合模型输入的格式。将数据按照实验设置划分成trainingdata.txt和testdata.txt两部分。运行preprocess.py会得到4个文件:MT_WordEmb.pkl、ttrain.dialogues.pkl、tvalid.dialogues.pkl和ttrain.dict.pkl,这四个文件分别是词向量,训练数据,测试数据和词表,用于设置state.py中的对应参数

运行说明

在命令行中输入(使用GPU进行训练时)

THEANO_FLAGS='mode=FAST_RUN,device=gpu,floatX=float32' python train.py --prototype prototype_zhifu

如果没有GPU,则输入

THEANO_FLAGS='mode=FAST_RUN,device=cpu,floatX=float32' python train.py --prototype prototype_zhifu

tag-suggestion's People

Contributors

hit-computer avatar

Watchers

James Cloos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.