
capsule_text_classification's Introduction

Text Classification with Capsule Network

Implementation of our paper "Investigating Capsule Networks with Dynamic Routing for Text Classification", which was accepted at EMNLP 2018.

Requirements: the code is written in Python 2.7 and requires TensorFlow 1.4.1.

Link to our recent capsule project: https://github.com/andyweizhao/NLP-Capsule

ACL19 preprint: "Towards Scalable and Reliable Capsule Networks for Challenging NLP Applications"

Data Preparation

reuters_process.py provides functions to clean the raw data and generate the Reuters-Multilabel and Reuters-Full datasets. For a quick start, see Reuters for the Reuters-Multilabel dataset; the other datasets are available under Others.

More explanation

utils.py includes several fundamental wrapper functions, such as _conv2d_wrapper, _separable_conv2d_wrapper and _get_variable_wrapper.

layers.py implements the capsule network layers: the Primary Capsule Layer, the Convolutional Capsule Layer, the Capsule Flatten Layer and the FC Capsule Layer.
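These capsule layers are connected by dynamic routing. As a rough illustration only, the following NumPy sketch shows the standard routing-by-agreement procedure and squash nonlinearity from Sabour et al. (2017). It is not the code in layers.py, which operates on batched TensorFlow tensors and may differ (for example, in how coupling coefficients are normalized or amended). The function names and the [num_in, num_out, dim] layout are assumptions made for this example.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-9):
    # v = (|s|^2 / (1 + |s|^2)) * (s / |s|): short vectors shrink toward 0,
    # long vectors approach (but never reach) unit length.
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / np.sqrt(sq_norm + eps)

def dynamic_routing(u_hat, num_iters=3):
    """Routing-by-agreement (illustrative sketch, not the repository's code).

    u_hat: prediction vectors, shape [num_in, num_out, dim]; u_hat[i, j] is
           child capsule i's prediction for parent capsule j.
    Returns parent capsules v [num_out, dim] and coupling coefficients
    c [num_in, num_out] from the final iteration.
    """
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))  # routing logits start at zero (uniform c)
    for _ in range(num_iters):
        # Coupling coefficients: softmax over parents for each child capsule.
        e = np.exp(b - b.max(axis=1, keepdims=True))
        c = e / e.sum(axis=1, keepdims=True)
        # Each parent is the squashed, coupling-weighted sum of predictions.
        s = np.einsum('ij,ijd->jd', c, u_hat)
        v = squash(s)
        # Raise the logit where a child's prediction agrees with the parent.
        b = b + np.einsum('ijd,jd->ij', u_hat, v)
    return v, c
```

Because squash keeps every parent vector's length below 1, the length can be read as the probability that the corresponding class is present.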

network.py implements two kinds of capsule network, as well as baseline models for comparison.

loss.py implements three loss functions: cross entropy, margin loss and spread loss.
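For orientation, here is a minimal NumPy sketch of the textbook forms of these three losses: binary cross entropy over sigmoid outputs, the margin loss of Sabour et al. (2017), and the spread loss of Hinton et al. (2018). It is not taken from loss.py; the function names, argument shapes and default hyperparameters (m+ = 0.9, m- = 0.1, lambda = 0.5, spread margin 0.2) are illustrative assumptions, and the repository may use different values or reductions.

```python
import numpy as np

def binary_cross_entropy(probs, targets, eps=1e-7):
    # probs, targets: [batch, num_classes]; targets are multi-hot (0/1).
    probs = np.clip(probs, eps, 1 - eps)
    per_class = targets * np.log(probs) + (1 - targets) * np.log(1 - probs)
    return np.mean(-np.sum(per_class, axis=1))

def margin_loss(lengths, targets, m_plus=0.9, m_minus=0.1, lam=0.5):
    # lengths: [batch, num_classes] capsule lengths in [0, 1).
    # Present classes are pushed above m_plus, absent ones below m_minus.
    present = targets * np.maximum(0.0, m_plus - lengths) ** 2
    absent = lam * (1 - targets) * np.maximum(0.0, lengths - m_minus) ** 2
    return np.mean(np.sum(present + absent, axis=1))

def spread_loss(activations, target_idx, margin=0.2):
    # activations: [batch, num_classes]; target_idx: [batch] single target class.
    rows = np.arange(len(target_idx))
    a_t = activations[rows, target_idx][:, None]
    # Penalize every wrong class whose activation is within `margin` of the target.
    per_class = np.maximum(0.0, margin - (a_t - activations)) ** 2
    per_class[rows, target_idx] = 0.0  # the target class itself is excluded
    return np.mean(np.sum(per_class, axis=1))
```

Margin loss treats each class independently, which makes it usable for multi-label targets; spread loss as written assumes a single target class per example.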

Quick start

python ./main.py --model_type CNN --learning_rate 0.0005

python ./main.py --model_type capsule-A --learning_rate 0.001

Performance on Reuters-Multilabel dataset

Capsule-A
Epoch: 1  Val accuracy: 82.9%  Loss: 0.1149
ER: 0.015 Precision: 0.236 Recall: 0.362 F1: 0.255
Epoch: 2  Val accuracy: 88.8%  Loss: 0.0748
ER: 0.172 Precision: 0.466 Recall: 0.500 F1: 0.459
Epoch: 3  Val accuracy: 89.3%  Loss: 0.0601
ER: 0.495 Precision: 0.765 Recall: 0.751 F1: 0.734
Epoch: 4  Val accuracy: 90.5%  Loss: 0.0560
ER: 0.578 Precision: 0.829 Recall: 0.817 F1: 0.802
Epoch: 5  Val accuracy: 90.1%  Loss: 0.0530
ER: 0.609 Precision: 0.841 Recall: 0.838 F1: 0.822
Epoch: 6  Val accuracy: 90.9%  Loss: 0.0505
ER: 0.600 Precision: 0.850 Recall: 0.854 F1: 0.831
Epoch: 7  Val accuracy: 92.0%  Loss: 0.0474
ER: 0.600 Precision: 0.873 Recall: 0.837 F1: 0.833

Capsule-B
Epoch: 1  Val accuracy: 82.7%  Loss: 0.0867
ER: 0.031 Precision: 0.257 Recall: 0.226 F1: 0.235
Epoch: 2  Val accuracy: 90.9%  Loss: 0.0586
ER: 0.458 Precision: 0.752 Recall: 0.663 F1: 0.692
Epoch: 3  Val accuracy: 93.9%  Loss: 0.0431
ER: 0.612 Precision: 0.943 Recall: 0.792 F1: 0.841

CNN
ER: 0.028 Precision: 0.307 Recall: 0.199 F1: 0.234
Epoch: 2  Val accuracy: 92.0%  Loss: 0.0462
ER: 0.200 Precision: 0.687 Recall: 0.492 F1: 0.555
Epoch: 3  Val accuracy: 94.7%  Loss: 0.0346
ER: 0.265 Precision: 0.876 Recall: 0.589 F1: 0.683
Epoch: 4  Val accuracy: 95.2%  Loss: 0.0310
ER: 0.255 Precision: 0.890 Recall: 0.581 F1: 0.683
Epoch: 5  Val accuracy: 95.4%  Loss: 0.0298
ER: 0.262 Precision: 0.887 Recall: 0.581 F1: 0.682
Epoch: 6  Val accuracy: 95.2%  Loss: 0.0295
ER: 0.262 Precision: 0.884 Recall: 0.577 F1: 0.679
Epoch: 7  Val accuracy: 95.8%  Loss: 0.0294
ER: 0.246 Precision: 0.881 Recall: 0.566 F1: 0.671

Note: validation accuracy and loss are evaluated on the dev set (single-label), whereas metrics such as ER, Precision, Recall and F1 are evaluated on the test set (multi-label).

The main functions are already in this repository. If you have any questions, please open an issue here.
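For readers trying to reproduce the multi-label numbers, here is a minimal NumPy sketch of how such metrics can be computed. It assumes that ER means exact match ratio, that a label is predicted when its score exceeds 0.5, and that precision, recall and F1 are averaged per example; all three are assumptions, and the repository's evaluation code may threshold or average differently.

```python
import numpy as np

def multilabel_metrics(scores, targets, threshold=0.5):
    # scores: [batch, num_classes] capsule lengths or sigmoid outputs.
    # targets: [batch, num_classes] multi-hot ground-truth labels.
    preds = (scores > threshold).astype(int)
    targets = targets.astype(int)
    # Exact match ratio: the entire predicted label set must be correct.
    er = np.mean(np.all(preds == targets, axis=1))
    tp = np.sum(preds * targets, axis=1)
    n_pred = preds.sum(axis=1)
    n_true = targets.sum(axis=1)
    zeros = np.zeros(len(tp))
    # Per-example scores; an empty denominator counts as 0.
    precision = np.divide(tp, n_pred, out=zeros.copy(), where=n_pred > 0)
    recall = np.divide(tp, n_true, out=zeros.copy(), where=n_true > 0)
    f1 = np.divide(2 * tp, n_pred + n_true, out=zeros.copy(),
                   where=(n_pred + n_true) > 0)
    return er, precision.mean(), recall.mean(), f1.mean()
```

Because ER only credits fully correct label sets, it can stay low even when precision is high, as in the early epochs logged above.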

Reference

If you find our source code useful, please consider citing our work.

@inproceedings{zhao2018capsule,
  year = {2018},
  author = {Wei Zhao and Jianbo Ye and Min Yang and Zeyang Lei and Suofei Zhang and Zhou Zhao},
  month = {September},
  title = {Investigating Capsule Networks with Dynamic Routing for Text Classification},
  booktitle = {Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing},
  url = {https://www.aclweb.org/anthology/D18-1350}
}

@inproceedings{zhao2019capsule,
    title = "Towards Scalable and Reliable Capsule Networks for Challenging {NLP} Applications",
    author = "Zhao, Wei and Peng, Haiyun and Eger, Steffen and Cambria, Erik and Yang, Min",
    booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2019",
    address = "Florence, Italy",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/P19-1150",
    doi = "10.18653/v1/P19-1150",
    pages = "1549--1559"
}

@article{DBLP:Zhang2018capsule,
  author    = {Suofei Zhang and Wei Zhao and Xiaofu Wu and Quan Zhou},
  title     = {Fast Dynamic Routing Based on Weighted Kernel Density Estimation},
  journal   = {CoRR},
  volume    = {abs/1805.10807},
  year      = {2018},
  url       = {http://arxiv.org/abs/1805.10807},
  archivePrefix = {arXiv},
  eprint    = {1805.10807},
}


capsule_text_classification's Issues

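Question about the code

Hello,
Could you explain what this part of your code (shown in an attached image) means? What do hk_offsets and wk_offsets represent? Thanks!

Next update

May I ask when the experiments will be updated, please?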

About the CapsuleConv and FullyConnected layers

Hello,
I read the paper with interest, and I would like to clarify the following. Although, according to the paper, vec_transformationByMat is supposed to be used in the layers, the code applies vec_transformationByConv instead. vec_transformationByMat never seems to be used in the code.
Thanks in advance.

Got an error

Thank you for sharing. When I run your code, I get an error: ValueError: Dimensions must be equal, but are 84840 and 16 for 'capsule_3/conv2/add_2' (op: 'Add') with input shapes: [84840,16,48], [84840,16,16,16,48]. I changed only the input data and nothing else. Could you give me some suggestions?

Datasets

Hello, could you share the other datasets?

Library requirements

Could you list the requirements of your working environment (for example, the output of pip freeze)? I tried the following, but I get theano.tensor.var.AsTensorError: ('Cannot convert Tensor("capsule_3/primary/Reshape:0", shape=(25, 99, 1, 16, 16), dtype=float32) to TensorType', <class 'tensorflow.python.framework.ops.Tensor'>):

absl-py==0.6.0
astor==0.7.1
backports.weakref==1.0.post1
bleach==1.5.0
enum34==1.1.6
funcsigs==1.0.2
futures==3.2.0
gast==0.2.0
grpcio==1.16.0
h5py==2.8.0
html5lib==0.9999999
Keras==2.2.4
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
mock==2.0.0
numpy==1.15.3
pbr==5.1.0
protobuf==3.6.1
PyYAML==3.13
scikit-learn==0.17.1
scipy==1.1.0
six==1.11.0
tensorboard==1.11.0
tensorflow==1.4.1
tensorflow-tensorboard==0.4.0
termcolor==1.1.0
Theano==0.8.0
Werkzeug==0.14.1

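Some parameters in the code do not match the paper

For the Capsule-A model:
1. N-gram convolutional layer: the paper says the convolution stride is 1, but the code uses 2.
2. Primary layer: the code sets C to 32, so why is my output prim poses dimension: (25, 99, 1, 16, 16)?
Is C actually 16?

Why can't I get the results in your paper?

Why can't I get the results reported in your paper? The final results also never converge. Do you take the best value as the final result, or did you change some hyperparameters? Could we discuss this? Thanks.

capsule-B F1 85.8?

In the paper, the Capsule-B F1 score on the Reuters-Multilabel dataset is 85.8, but the best score I can get is 83.7.

python ./main.py --model_type capsule-A --learning_rate 0.001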

Epoch: 2 Val accuracy: 89.9% Loss: 0.0612
ER: 0.095 Precision: 0.635 Recall: 0.575 F1: 0.566
Epoch: 3 Val accuracy: 93.3% Loss: 0.0391
ER: 0.594 Precision: 0.912 Recall: 0.770 F1: 0.816
Epoch: 4 Val accuracy: 94.7% Loss: 0.0326
ER: 0.615 Precision: 0.939 Recall: 0.788 F1: 0.837
Epoch: 5 Val accuracy: 95.8% Loss: 0.0299
ER: 0.428 Precision: 0.948 Recall: 0.692 F1: 0.777
Epoch: 6 Val accuracy: 96.0% Loss: 0.0272
ER: 0.348 Precision: 0.958 Recall: 0.661 F1: 0.759

Orphan Category

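On some tasks, running CapsNet directly performs somewhat worse than TextCNN. To account for background noise, you proposed three strategies: Orphan Category, Leaky-Softmax and Coefficients Amendment. The code seems to contain only the Coefficients Amendment part. Will you upload the code for the other two methods?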

What is leaky-softmax

Hello, I have read your paper, but I do not understand leaky-softmax.
Could you give me the equation? Thanks!
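The thread contains no answer. For reference, one widely used formulation of leaky-softmax (the one in the reference TensorFlow implementation accompanying Sabour et al., 2017) prepends a fixed zero logit as an extra "leak" dimension, applies softmax, and discards that dimension, so that c_j = exp(b_j) / (1 + sum_k exp(b_k)). The coefficients then sum to less than 1, letting noisy child capsules route part of their mass nowhere. Whether the paper uses exactly this form is an assumption; the function name below is illustrative.

```python
import numpy as np

def leaky_softmax(logits, axis=-1):
    # Prepend a zero logit (the "leak" dimension) along `axis`.
    leak = np.zeros_like(np.take(logits, [0], axis=axis))
    full = np.concatenate([leak, logits], axis=axis)
    # Standard numerically stable softmax over all entries, leak included.
    full = full - full.max(axis=axis, keepdims=True)
    e = np.exp(full)
    probs = e / e.sum(axis=axis, keepdims=True)
    # Drop the leak column: the remaining coefficients sum to less than 1.
    return np.delete(probs, 0, axis=axis)
```

For example, three equal logits of 0 each receive a coefficient of 1/4 rather than the 1/3 that an ordinary softmax would assign.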

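Questions about text classification

Hello, I ran into some problems. When running your code I got a dimension error (I did not change any part of your model code). In the routing part, when the coupling coefficients are computed with b=b+K.batch_dot(outputs,u_hat_vecs,[2,3]), outputs has shape [224,16,16,16] and u_hat_vecs has shape [226,16,48,16], and the error is: Dimensions must be equal, but are 224 and 16 for 'capsule_3/conv2/add_2' (op: 'Add') with input shapes: [224,16,48], [224,16,16,16,48]. Second: dynamic routing in a capsule network is an iterative process, but you do not seem to truncate backpropagation through it. Does backpropagation need to be truncated here?
Using TF matrix multiplication, [224,16,16,16] x [224,16,16,48] gives [224,16,16,48]. I then summed over the second dimension to get [224,16,48] and added that to b, which fixed the routing part, but a related problem then appeared in the later reshape of poses.
My environment is Python 3.6 and TF 1.14.0. Is this an environment configuration problem? I would like to use your model as a baseline and cite it. Your code shouldn't contain any typos, right? Or is this a Python 3.6 vs. Python 2 issue?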

ValueError: num_outputs should be int or long, got 9.

Hello, I need help.

Traceback (most recent call last):
  File "./main.py", line 167, in <module>
    poses, activations = baseline_model_cnn(X_embedding, args.num_classes)
  File "D:\PycharmProjects\pythonProject\capsule_text_classification-master\network.py", line 18, in baseline_model_cnn
    activations = tf.sigmoid(slim.fully_connected(nets, num_classes, scope='final_layer', activation_fn=None))
  File "D:\Anaconda\envs\py27\lib\site-packages\tensorflow\contrib\framework\python\ops\arg_scope.py", line 183, in func_with_args
    return func(*args, **current_args)
  File "D:\Anaconda\envs\py27\lib\site-packages\tensorflow\contrib\layers\python\layers\layers.py", line 1822, in fully_connected
    (num_outputs,))
ValueError: num_outputs should be int or long, got 9.
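The thread has no accepted answer. In TF 1.x, fully_connected raises this error when num_outputs is not one of six.integer_types, and the message shows the value 9, so a likely cause (an assumption, not confirmed in the thread) is that num_classes arrives as a NumPy integer such as numpy.int64, which does not pass that check on every platform. A hypothetical helper that casts the value to a built-in int before building the layer could look like this:

```python
import numbers

import numpy as np

def as_builtin_int(value):
    """Hypothetical helper: convert NumPy (or other) integral values to int."""
    # NumPy registers its integer types as numbers.Integral, so np.int64 passes.
    if not isinstance(value, numbers.Integral):
        raise TypeError("expected an integral value, got %r" % (value,))
    return int(value)
```

It would be called as, for example, baseline_model_cnn(X_embedding, as_builtin_int(args.num_classes)); a non-integral value such as 9.0 still fails loudly instead of being silently truncated.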

About MR dataset

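When I run your code on the MR dataset, I cannot get the results reported in your paper. Please tell me how you set up the experiment and how you processed the original data.

How should the activation at line 101 of layers.py be understood?

Thank you for sharing. While studying the code, I found one part I do not understand: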
beta_a = _get_weights_wrapper(
name='beta_a', shape=[1, shape[-1]]
)
activations = K.sqrt(K.sum(K.square(poses), axis=-1)) + beta_a
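My understanding is that activations represents the strength (length) of the vector. Since beta_a is a randomly generated variable, why is it added to activations?
I would appreciate your guidance when you have time.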
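Read literally, the quoted line computes the L2 norm of each pose vector and then adds beta_a, a variable of shape [1, shape[-1]] that is randomly initialized by _get_weights_wrapper and, presumably, trained afterwards, so it acts as a learned per-capsule offset broadcast over the batch. The NumPy sketch below only reproduces that arithmetic; it assumes that poses has shape [batch, num_caps, pose_dim] and that shape[-1] in the snippet equals num_caps, neither of which the thread confirms.

```python
import numpy as np

def capsule_activations(poses, beta_a):
    # poses: [batch, num_caps, pose_dim] (assumed layout).
    # beta_a: [1, num_caps], a per-capsule offset (assumed to be trainable).
    lengths = np.sqrt(np.sum(np.square(poses), axis=-1))  # [batch, num_caps]
    # beta_a broadcasts across the batch dimension, shifting each capsule's
    # activation by its own offset.
    return lengths + beta_a
```

Note that after the offset is added, the activation is no longer guaranteed to stay within [0, 1) the way a squashed vector length is.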

For a single label task, how do you handle the output of the model ?

Hello, in your code for the multi-label experiments, a label is set to 1 when the norm of the model's output capsule vector exceeds 0.5. This setting clearly does not apply to single-label tasks. How do you set the output for single-label tasks? I look forward to your answer, thanks!

The loss doesn't change

Hello, after I changed weight_sharing from true to false, the loss did not change for the first 100 iterations. After that the model worked properly, but the loss decreased much more slowly. Can you give me some suggestions? I believe the key issue lies in the Squash function, but I don't know how to amend it.
Thank you!

The format of the input data

Could you please tell us the format of the input data, i.e., how to use your code on our own data? Thank you very much.
