slei109 / tart

PyTorch Implementation of TART: Improved Few-shot Text Classification Using Task-Adaptive Reference Transformation, ACL 2023

License: MIT License

Shell 1.43% Python 98.57%

TART: Improved Few-shot Text Classification Using Task-Adaptive Reference Transformation

This repository contains the code and data for our ACL 2023 paper:

TART: Improved Few-shot Text Classification Using Task-Adaptive Reference Transformation

If you find this work useful and use it in your own research, please cite our paper.

@inproceedings{lei-etal-2023-tart,
    title = "{TART}: Improved Few-shot Text Classification Using Task-Adaptive Reference Transformation",
    author = "Lei, Shuo  and
      Zhang, Xuchao  and
      He, Jianfeng  and
      Chen, Fanglan  and
      Lu, Chang-Tien",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2023",
    address = "Toronto, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.acl-long.617",
    doi = "10.18653/v1/2023.acl-long.617",
    pages = "11014--11026",
}

Overview

Meta-learning has emerged as a popular technique for few-shot text classification and achieves state-of-the-art performance. However, the performance of existing approaches depends heavily on the inter-class variance of the support set. As a result, they perform well on tasks where the semantics of the sampled classes are distinct but fail to differentiate classes with similar semantics. In this paper, we propose a novel Task-Adaptive Reference Transformation (TART) network, which aims to enhance generalization by transforming the class prototypes to per-class fixed reference points in task-adaptive metric spaces. To further maximize the divergence between transformed prototypes in task-adaptive metric spaces, TART introduces a discriminative reference regularization among the transformed prototypes. Extensive experiments on four benchmark datasets show that our method clearly outperforms state-of-the-art models on all of them.
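The core idea can be illustrated with a minimal NumPy sketch. It assumes (our assumption, not necessarily the authors' exact formulation) that the task-adaptive transformation is a linear map W fitted by least squares, W = pinv(P) @ R, where P stacks the class prototypes and R stacks the per-class reference points; queries are then classified by Euclidean distance to the transformed prototypes. The regularizer is a hypothetical stand-in for discriminative reference regularization: it softmax-normalizes the transformed prototypes and rewards large pairwise symmetric KL divergence. All function names are illustrative and do not come from the repository.

```python
import numpy as np


def class_prototypes(support, labels, n_way):
    """Average the support embeddings of each class -> array of shape (n_way, dim)."""
    return np.stack([support[labels == c].mean(axis=0) for c in range(n_way)])


def task_adaptive_transform(prototypes, references):
    """Least-squares linear map W such that prototypes @ W approximates references.

    Assumption: W = pinv(P) @ R. When the n_way prototypes are linearly
    independent (n_way <= dim), P @ W reproduces R exactly.
    """
    return np.linalg.pinv(prototypes) @ references


def classify(queries, prototypes, references):
    """Predict labels by squared Euclidean distance in the transformed space."""
    W = task_adaptive_transform(prototypes, references)
    q = queries @ W                      # (n_query, ref_dim)
    p = prototypes @ W                   # (n_way, ref_dim)
    dists = ((q[:, None, :] - p[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)


def softmax(x):
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def divergence_regularizer(transformed):
    """Hypothetical regularizer: negative mean pairwise symmetric KL divergence.

    Minimizing this value pushes the softmax-normalized transformed
    prototypes apart from each other.
    """
    probs = softmax(transformed)
    n = probs.shape[0]
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            kl_ij = np.sum(probs[i] * np.log(probs[i] / probs[j]))
            kl_ji = np.sum(probs[j] * np.log(probs[j] / probs[i]))
            total += kl_ij + kl_ji
    return -total / (n * (n - 1) / 2)


# Usage: a toy 5-way 2-shot episode with random 16-d embeddings.
rng = np.random.default_rng(0)
n_way, k_shot, dim = 5, 2, 16
support = rng.normal(size=(n_way * k_shot, dim))
labels = np.repeat(np.arange(n_way), k_shot)
references = rng.normal(size=(n_way, dim))   # stand-in for learned reference points

protos = class_prototypes(support, labels, n_way)
W = task_adaptive_transform(protos, references)
print(np.allclose(protos @ W, references))    # prototypes land on their references
print(classify(protos, protos, references))   # each prototype maps to its own class
print(divergence_regularizer(protos @ W))
```

Mapping every task's prototypes onto the same fixed reference points is what makes the metric space task-adaptive: W changes per episode, while the targets R do not.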

The figure gives an overview of our model.

Data

We ran experiments on a total of 4 datasets. Unzip our processed data file data.zip and put the data files under the data/ folder.

Dataset notes:
  • 20 Newsgroups: processed data available. We used the 20news-18828 version.
  • Reuters-21578: processed data available.
  • Amazon reviews: processed data available. We used a subset of the product review data.
  • HuffPost headlines: processed data available.
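Few-shot experiments sample N-way K-shot episodes from these datasets. The sketch below assumes that each data file (for example ../data/20news.json) is in JSON Lines format with a "label" field and a tokenized "text" field; this schema is an assumption, so check it against the files in data.zip. The function names load_jsonl and sample_episode are hypothetical and not part of the repository.

```python
import json
import os
import random
import tempfile
from collections import defaultdict


def load_jsonl(path):
    """Group examples by label. Assumes one JSON object per line with
    "label" and "text" keys (hypothetical schema; check data.zip)."""
    by_label = defaultdict(list)
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            example = json.loads(line)
            by_label[example["label"]].append(example["text"])
    return dict(by_label)


def sample_episode(by_label, n_way, k_shot, n_query, rng=random):
    """Sample an N-way K-shot episode with disjoint support and query sets per class.

    Labels are re-indexed 0..n_way-1 within the episode.
    """
    eligible = [c for c, xs in by_label.items() if len(xs) >= k_shot + n_query]
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_label[c], k_shot + n_query)
        support += [(x, episode_label) for x in picked[:k_shot]]
        query += [(x, episode_label) for x in picked[k_shot:]]
    return classes, support, query


# Usage: build a toy file with 8 classes x 10 examples, then sample a 5-way 1-shot episode.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.json")
    with open(path, "w", encoding="utf-8") as f:
        for label in range(8):
            for i in range(10):
                f.write(json.dumps({"label": label, "text": ["tok", str(i)]}) + "\n")
    data = load_jsonl(path)
    classes, support, query = sample_episode(
        data, n_way=5, k_shot=1, n_query=3, rng=random.Random(0))
    print(classes, len(support), len(query))
```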

Please download the pre-trained word embedding file wiki.en.vec from here and put it under the pretrain_wordvec/ folder.
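wiki.en.vec uses fastText's plain-text .vec format: a header line with the vocabulary size and the vector dimension, then one word per line followed by its vector components. A minimal loader that keeps only the words in a given vocabulary might look like this; the function name load_vec and the vocabulary filter are illustrative, not the repository's own loading code.

```python
import os
import tempfile

import numpy as np


def load_vec(path, vocab=None):
    """Read a fastText .vec text file into ({word: np.ndarray}, dim).

    If `vocab` is given, only those words are kept, which saves memory for
    large files such as wiki.en.vec.
    """
    vectors = {}
    with open(path, encoding="utf-8", errors="ignore") as f:
        _, dim = map(int, f.readline().split())  # header: "<n_words> <dim>"
        for line in f:
            parts = line.rstrip().split(" ")
            if len(parts) != dim + 1:
                continue  # skip malformed lines
            word = parts[0]
            if vocab is None or word in vocab:
                vectors[word] = np.asarray(parts[1:], dtype=np.float32)
    return vectors, dim


# Usage: a tiny 4-dimensional file in the same format.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.vec")
    with open(path, "w", encoding="utf-8") as f:
        f.write("3 4\n")
        f.write("the 0.1 0.2 0.3 0.4 \n")
        f.write("cat 1.0 0.0 0.0 1.0 \n")
        f.write("dog 0.5 0.5 0.5 0.5 \n")
    vecs, dim = load_vec(path, vocab={"cat", "dog"})
```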

Pre-trained Weights

We have released pre-trained weights for each task; they can be downloaded from Google Drive.

Quickstart

After you have finished configuring the data/ folder and the pretrain_wordvec/ folder, you can run our model with the following commands.

cd bin
sh tart.sh

You can also adjust the model by modifying the parameters in the tart.sh file.

Dependencies

  • Python 3.7.11
  • PyTorch 1.11.0
  • numpy 1.19.1
  • torchtext 0.13.0
  • termcolor 1.1.0
  • tqdm 4.62.3
  • CUDA 11.1
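Before running tart.sh, it can help to check the installed packages against the versions listed above. The sketch below compares installed distribution versions with that list, ignoring local build suffixes such as "+cu113"; Python and CUDA are not checked. importlib.metadata requires Python 3.8+, so on the listed Python 3.7 the importlib_metadata backport is used instead. Note that torchtext's own compatibility table pairs torchtext 0.13.0 with PyTorch 1.12.0 (and 0.12.0 with 1.11.0), so pip may not install exactly this combination. The names EXPECTED and compare_versions are illustrative.

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: pip install importlib_metadata
    import importlib_metadata as metadata

# Package versions listed in this README (Python and CUDA are checked separately).
EXPECTED = {
    "torch": "1.11.0",
    "numpy": "1.19.1",
    "torchtext": "0.13.0",
    "termcolor": "1.1.0",
    "tqdm": "4.62.3",
}


def compare_versions(expected, lookup=metadata.version):
    """Return {package: (expected, installed or None)} for every mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = lookup(pkg).split("+")[0]  # drop local tags like "+cu113"
        except metadata.PackageNotFoundError:
            have = None
        if have != want:
            mismatches[pkg] = (want, have)
    return mismatches


print(compare_versions(EXPECTED))
```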

Acknowledgement

The implementation is based on MLADA.

tart's Issues

Hyperparameter setting

Hello! Can the hyperparameters be set through tart.sh? The results I reproduced deviate considerably from those in the paper, especially on HuffPost. The 1-shot result on 20 Newsgroups is similar, but the 5-shot result is slightly off. The other datasets also show some discrepancies.

1-shot
0.6716720322668552,20newsgroup,0.98,,[],False,,2,../data/20news.json,20newsgroup,0.1,bilstm,False,1,0.001,0.001,,train,7,8,5,10,False,reuters_False_data.json,20,,25,result/5-way_1-shot_20newsgroup_result.csv,False,3,1,,1000,100,1000,,100,5,../pretrain_wordvec/wiki.en.vec,../pretrain_wordvec

5-shot
0.7930320369005204,20newsgroup,0.98,,[],False,,2,../data/20news.json,20newsgroup,0.1,bilstm,False,1,0.001,0.001,,train,7,8,5,10,False,reuters_False_data.json,20,,25,result/5-way_5-shot_20newsgroup_result.csv,False,3,5,,1000,100,1000,,100,5,../pretrain_wordvec/wiki.en.vec,../pretrain_wordvec

1-shot
0.39820801581442355,0.08724249734104635,huffpost,0.98,,[],False,,2,../data/huffpost.json,huffpost,0.1,bilstm,False,1,0.001,0.001,,train,16,20,5,10,False,reuters_False_data.json,20,,25,result/5-way_1-shot_huffpost_result.csv,False,3,1,,1000,100,1000,,100,5,../pretrain_wordvec/wiki.en.vec,../pretrain_wordvec
5-shot
0.5700000269412995,0.07295171397582322,huffpost,0.98,,[],False,,2,../data/huffpost.json,huffpost,0.1,bilstm,False,1,0.001,0.001,,train,16,20,5,10,False,reuters_False_data.json,20,,25,result/5-way_5-shot_huffpost_result.csv,False,3,5,,1000,100,1000,,100,5,../pretrain_wordvec/wiki.en.vec,../pretrain_wordvec

20news.json

Hi, sorry to bother you, but would you mind sharing "../data/20news.json"?

How to use BERT as the encoder?

I am trying to run your model for my research. For a fair comparison, I want to use BERT as the encoder; your paper also reports an experiment with BERT as the encoder. However, I cannot find how to use BERT as the encoder in your code.

I found an argument (--bert) in the shell script, but I cannot find any source code that uses BERT as an encoder.
