
lancopku / text-autoaugment

[EMNLP 2021] Text AutoAugment: Learning Compositional Augmentation Policy for Text Classification

Home Page: https://arxiv.org/abs/2109.00523

License: MIT License

Python 99.02% Shell 0.98%
text-classification automl data-augmentation

text-autoaugment's Introduction

Text-AutoAugment (TAA)

This repository contains the code for our paper Text AutoAugment: Learning Compositional Augmentation Policy for Text Classification (EMNLP 2021 main conference).

Updates

  • [22.02.23]: We added an example of how to use TAA with your custom (local) dataset.
  • [21.10.27]: We made taa installable as a package and adapted it to huggingface/transformers. Now you can search for an augmentation policy on a huggingface dataset with just two lines of code.

Overview

  1. We present a learnable and compositional framework for data augmentation. Our proposed algorithm automatically searches for the optimal compositional policy, which improves the diversity and quality of augmented samples.

  2. In low-resource and class-imbalanced regimes of six benchmark datasets, TAA significantly improves the generalization ability of deep neural networks like BERT and effectively boosts text classification performance.
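To make the idea of a compositional policy concrete, here is a minimal, self-contained sketch using only the Python standard library. It assumes an AutoAugment-style structure: a policy is a list of sub-policies, each sub-policy is a sequence of (operation, probability, magnitude) triples applied in order, and one sub-policy is drawn at random for each augmented copy. The two operations are toy stand-ins (a simplified random_word_delete plus a hypothetical random_word_swap), and apply_policy is likewise a hypothetical name; this is not TAA's implementation and it does not perform the policy search.

```python
import random


def random_word_delete(text, m, rng):
    """Toy op: drop each word with probability m (keep the original if every word is dropped)."""
    words = text.split()
    kept = [w for w in words if rng.random() >= m]
    return " ".join(kept) if kept else text


def random_word_swap(text, m, rng):
    """Toy op: swap roughly m * len(words) random pairs of words (at least one pair)."""
    words = text.split()
    if len(words) < 2:
        return text
    for _ in range(max(1, int(m * len(words)))):
        i, j = rng.sample(range(len(words)), 2)
        words[i], words[j] = words[j], words[i]
    return " ".join(words)


# A policy is a list of sub-policies; each sub-policy is a sequence of
# (operation, probability, magnitude) triples, so num_op = 2 and num_policy = 2 here.
POLICY = [
    [(random_word_delete, 0.5, 0.1), (random_word_swap, 0.3, 0.2)],
    [(random_word_swap, 0.8, 0.1), (random_word_delete, 0.2, 0.3)],
]


def apply_policy(text, policy, n_aug, seed=0):
    """Return n_aug augmented copies of text, each produced by one randomly drawn sub-policy."""
    rng = random.Random(seed)
    copies = []
    for _ in range(n_aug):
        sub_policy = rng.choice(policy)
        augmented = text
        for op, prob, magnitude in sub_policy:
            # Each operation in the sub-policy fires independently with its own probability.
            if rng.random() < prob:
                augmented = op(augmented, magnitude, rng)
        copies.append(augmented)
    return copies


print(apply_policy("the movie was surprisingly good", POLICY, n_aug=4))
```

Here the probabilities and magnitudes are fixed by hand; in a learned policy, values like these are what the search would tune.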

Getting Started

Prepare environment

Install PyTorch and a few other small dependencies, then install this repo as a Python package. Note that the CUDA version of the PyTorch build (cu102, i.e. CUDA 10.2, in the commands below) should match the CUDA version on your machine.

# Clone this repo
git clone https://github.com/lancopku/text-autoaugment.git
cd text-autoaugment

# Create a conda environment
conda create -n taa python=3.6
conda activate taa

# Install dependencies
pip install torch==1.10.1+cu102 -f https://download.pytorch.org/whl/cu102/torch_stable.html
pip install git+https://github.com/wbaek/theconf
pip install git+https://github.com/ildoonet/pystopwatch2.git
pip install -r requirements.txt

# Install this library (no need to re-build after modifying the source code)
python setup.py develop

# Download the models in NLTK
python -c "import nltk; nltk.download('wordnet'); nltk.download('averaged_perceptron_tagger'); nltk.download('omw-1.4')"

Make sure your PyTorch installation supports the GPU. You can check with python -c "import torch; print(torch.cuda.is_available())", which should print True.

Use TAA with Huggingface

1. Get augmented training dataset with TAA policy

Option 1: Search for the optimal policy

You can search for the optimal policy on classification datasets supported by huggingface/datasets:

from taa.search_and_augment import search_and_augment

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = search_and_augment(configfile="/path/to/your/config.yaml")

The configfile (YAML file) contains all the arguments, including paths, model, dataset, and optimization hyperparameters. To run the code successfully, set these arguments carefully:

  • model:

    • type: backbone model
  • dataset:

    • path: Path or name of the dataset
    • name: Name of the dataset configuration
    • data_dir: data_dir of the dataset configuration
    • data_files: Path(s) to source data file(s)

    ATTENTION: All the arguments above are passed to the load_dataset() function in huggingface/datasets. See its documentation for details.

    • text_key: Key used to get the text from a data instance (instances are dicts in huggingface/datasets; for IMDB, the key is text).
  • abspath: Your working directory

  • aug: Pre-searched policy. Currently supported: IMDB, SST5, TREC, YELP2 and YELP5. See archive.py.

  • per_device_train_batch_size: Batch size per device for training

  • per_device_eval_batch_size: Batch size per device for evaluation

  • epoch: Number of training epochs

  • lr: Learning rate

  • max_seq_length: Maximum input sequence length

  • n_aug: Augment each text sample n_aug times

  • num_op: Number of operations per sub-policy

  • num_policy: Number of sub-policies per policy

  • method: Search method (taa)

  • topN: Number of top-ranked sub-policies ensembled into the final policy

  • ir: Imbalance rate

  • seed: Random seed

  • trail: Trial number under the current random seed

  • train:

    • npc: Number of examples per class in the training dataset
  • valid:

    • npc: Number of examples per class in the val dataset
  • test:

    • npc: Number of examples per class in the test dataset
  • num_search: Number of optimization iterations

  • num_gpus: Number of GPUs used by Ray

  • num_cpus: Number of CPUs used by Ray
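Before launching a search, it can help to check that a config defines the keys listed above. The sketch below represents the YAML as a nested Python dict (the shape a YAML loader would return) and reports missing keys as dotted paths such as train.npc. Which keys are truly mandatory is an assumption here: optional load_dataset() arguments (name, data_dir, data_files) and the pre-searched aug are left out. has_dotted_key and missing_keys are hypothetical helpers, and every value in EXAMPLE_CONFIG is a placeholder, not a repository default.

```python
# Keys assumed to be required; nested keys are written as dotted paths.
REQUIRED_KEYS = [
    "model.type",
    "dataset.path", "dataset.text_key",
    "abspath",
    "per_device_train_batch_size", "per_device_eval_batch_size",
    "epoch", "lr", "max_seq_length",
    "n_aug", "num_op", "num_policy", "method", "topN",
    "ir", "seed", "trail",
    "train.npc", "valid.npc", "test.npc",
    "num_search", "num_gpus", "num_cpus",
]


def has_dotted_key(config, dotted_key):
    """True if a dotted path such as 'train.npc' resolves in the nested dict."""
    node = config
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return False
        node = node[part]
    return True


def missing_keys(config, required=REQUIRED_KEYS):
    """List the required dotted keys that the config does not define, in REQUIRED_KEYS order."""
    return [key for key in required if not has_dotted_key(config, key)]


# Placeholder values only; they are NOT the repository's defaults.
EXAMPLE_CONFIG = {
    "model": {"type": "bert-base-uncased"},
    "dataset": {"path": "imdb", "text_key": "text"},
    "abspath": "/path/to/working/dir",
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "epoch": 10,
    "lr": 4e-5,
    "max_seq_length": 256,
    "n_aug": 4,
    "num_op": 2,
    "num_policy": 4,
    "method": "taa",
    "topN": 3,
    "ir": 1,
    "seed": 59,
    "trail": 1,
    "train": {"npc": 40},
    "valid": {"npc": 30},
    "test": {"npc": 1000},
    "num_search": 200,
    "num_gpus": 4,
    "num_cpus": 40,
}

print(missing_keys(EXAMPLE_CONFIG))  # []
partial = {k: v for k, v in EXAMPLE_CONFIG.items() if k not in ("train", "lr")}
print(missing_keys(partial))  # ['lr', 'train.npc']
```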

configfile example 1: TAA for huggingface dataset

bert_sst2_example.yaml is an example configfile for the BERT model and the SST2 dataset. You can follow it to create your own configfile for other huggingface datasets.

For instance, to switch the dataset from sst2 to imdb, delete sst2 from the 'path' argument, change 'name' to imdb, and change 'text_key' to text. The result should look like bert_imdb_example.yaml.

configfile example 2: TAA for custom (local) dataset

bert_custom_data_example.yaml is an example configfile for the BERT model and a custom (local) dataset. The custom dataset should be a CSV file whose columns are named text and label. custom_data.csv is an example of such a dataset.
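As a quick way to check that a local file matches the layout described above, the sketch below writes a two-row CSV with text and label columns to an in-memory buffer and reads it back with Python's csv module, failing loudly if either column is missing. load_custom_csv is a hypothetical helper and the sample rows are made up. TAA itself loads the file through huggingface/datasets as configured in the YAML, so this only checks the column layout.

```python
import csv
import io


def load_custom_csv(fileobj):
    """Read rows from a CSV that must have 'text' and 'label' columns."""
    reader = csv.DictReader(fileobj)
    missing = {"text", "label"} - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"CSV is missing required column(s): {sorted(missing)}")
    return [{"text": row["text"], "label": row["label"]} for row in reader]


# Build a tiny in-memory CSV in the expected layout (sample rows are made up).
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["text", "label"])
writer.writeheader()
writer.writerow({"text": "a gripping, well-acted film", "label": 1})
writer.writerow({"text": "dull and far too long", "label": 0})
buffer.seek(0)

rows = load_custom_csv(buffer)
print(rows)
```

The csv module quotes the comma inside the first sample text, so it round-trips intact; labels come back as strings ("1", "0").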

WARNING: The policy optimization framework is based on Ray. By default we use 4 GPUs and 40 CPUs for policy optimization. Make sure your computing resources meet this requirement; otherwise, you will need to create a new configuration file (for example, with adjusted num_gpus and num_cpus). Also specify the GPUs before running the code above, e.g., CUDA_VISIBLE_DEVICES=0,1,2,3. TPUs do not currently appear to be supported.
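A small pre-flight check can compare the requested Ray resources with what the machine exposes. The sketch below is a hypothetical helper, not part of TAA: it counts the GPUs listed in CUDA_VISIBLE_DEVICES and the CPUs reported by os.cpu_count(), then compares them with the num_gpus and num_cpus config values. Note that CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized, which is why it is usually set on the command line.

```python
import os


def visible_gpu_count(env=None):
    """Count GPUs listed in CUDA_VISIBLE_DEVICES; return None if the variable is unset."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    return len([d for d in value.split(",") if d.strip()])


def check_resources(num_gpus, num_cpus, env=None, cpu_count=None):
    """Return a list of problems; an empty list means the request looks satisfiable."""
    problems = []
    gpus = visible_gpu_count(env)
    if gpus is None:
        problems.append("CUDA_VISIBLE_DEVICES is not set")
    elif gpus < num_gpus:
        problems.append(f"num_gpus={num_gpus} but only {gpus} GPU(s) are visible")
    cpus = os.cpu_count() if cpu_count is None else cpu_count
    if cpus is not None and cpus < num_cpus:
        problems.append(f"num_cpus={num_cpus} but only {cpus} CPU(s) are available")
    return problems


# The default request mentioned above: 4 GPUs and 40 CPUs.
print(check_resources(4, 40, env={"CUDA_VISIBLE_DEVICES": "0,1,2,3"}, cpu_count=40))  # []
print(check_resources(4, 40, env={"CUDA_VISIBLE_DEVICES": "0,1"}, cpu_count=16))
```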

Option 2: Use our pre-searched policy

To train a model on a dataset augmented with our pre-searched policy (using IMDB as an example), use:

from taa.search_and_augment import augment_with_presearched_policy

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = augment_with_presearched_policy(configfile="/path/to/your/config.yaml")

Currently supported: IMDB, SST5, TREC, YELP2 and YELP5. See archive.py for details.

The table below lists the test accuracy (%) of the pre-searched TAA policies on the full datasets:

Dataset   IMDB    SST-5   TREC    YELP-2  YELP-5
No Aug    88.77   52.29   96.40   95.85   65.55
TAA       89.37   52.55   97.07   96.04   65.73
n_aug     4       4       4       2       2
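The gains in the table can be computed directly. The sketch below subtracts the No Aug accuracy from the TAA accuracy for each dataset; the numbers are copied from the table and nothing else is assumed.

```python
# Test accuracy (%) copied from the table: (No Aug, TAA).
results = {
    "IMDB": (88.77, 89.37),
    "SST-5": (52.29, 52.55),
    "TREC": (96.40, 97.07),
    "YELP-2": (95.85, 96.04),
    "YELP-5": (65.55, 65.73),
}

# Absolute gain in percentage points, rounded to hide floating-point noise.
gains = {name: round(taa - no_aug, 2) for name, (no_aug, taa) in results.items()}
for name, gain in gains.items():
    print(f"{name}: +{gain:.2f} points")
```

The absolute gains range from +0.18 points (YELP-5) to +0.67 points (TREC).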

More pre-searched policies and their performance will be COMING SOON.

2. Fine-tune a new model on the augmented training dataset

Once you have augmented_train_dataset, you can pass it directly to the huggingface Trainer. See search_augment_train.py for details.

Reproduce results in the paper

Please see examples/reproduce_experiment.py, and run script/huggingface_lowresource.sh or script/huggingface_imbalanced.sh.

Contact

If you have any questions related to the code or the paper, feel free to open an issue.

Acknowledgments

The code is based on fast-autoaugment.

Citation

If you find this code useful for your research, please consider citing:

@inproceedings{ren2021taa,
    title = "Text {A}uto{A}ugment: Learning Compositional Augmentation Policy for Text Classification",
    author = "Ren, Shuhuai and Zhang, Jinchao and Li, Lei and Sun, Xu and Zhou, Jie",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    year = "2021",
}

License

MIT

text-autoaugment's People

Contributors

markussagen, renshuhuai-andy, wolvecap


text-autoaugment's Issues

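How do I reproduce the code?

Hello! I am a master's student and very interested in your paper. I would like to reproduce the code for my research, but I have run into some difficulties and would appreciate it if you could answer a few questions.
1. The paper says the experiments were run on 8 Tesla P40 GPUs. My setup is 2 × GTX 1080 with 16 GB of GPU memory and CUDA 10.2, and I have configured the environment following your README. Can this setup run your code?
2. The README says reproducing the results mainly involves running reproduce_experiment.py, but it fails at line 46 with an error saying the required files are not in taa/models. I tried running search.py to generate the required policy files, but I don't know where to start, and the overall debugging workflow is a bit confusing to me. To fully reproduce the results, what steps should I follow to get the code running end to end?

I am still a beginner at coding. Thank you very much for taking the time to answer!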

Error when following the README after installing the taa library

After installing taa with pip and following the instructions in the README, I get the following error:

My usage is as follows:

from taa.search_and_augment import search_and_augment

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = search_and_augment(configfile="/path/to/your/config.yaml")

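A question

Hi, thank you very much for sharing this code. I have a question about the statement from theconf import Config as C. I couldn't find a file for it. Is it a Python library, or a JSON parameter file? The import fails as written.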

The return value in augmentation.py cannot be used as a data source

Each transform function in augmentation.py should return a str instead of the list it returns by default.
For example:
def random_word_delete(text, m):
    return aug.augment(text)
⬇⬇⬇⬇⬇⬇⬇⬇⬇
def random_word_delete(text, m):
    return aug.augment(text)[0]
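A slightly more defensive variant of this fix is sketched below. as_single_text is a hypothetical helper, not part of the repository. It assumes aug.augment() may return either a str (as the original code expected) or a list of strings, and falls back to the unaugmented text if the list is empty. Inside a transform such as random_word_delete, it would be called as as_single_text(aug.augment(text), text).

```python
def as_single_text(augmented, original):
    """Normalize an augmenter's output to one string.

    Returns the first element of a non-empty list, the original text for an
    empty list, and the value unchanged if it is already a string.
    """
    if isinstance(augmented, list):
        return augmented[0] if augmented else original
    return augmented


print(as_single_text(["the film was good"], "the movie was good"))
print(as_single_text([], "the movie was good"))
print(as_single_text("the film was good", "the movie was good"))
```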

module 'ray.tune' has no attribute 'trial_runner'

Traceback (most recent call last):
  File "reproduce_experiment.py", line 10, in <module>
    from taa.search import get_path, search_policy, train_model_parallel
  File "/root/yxyanyi/xiaozhu/text-autoaugment-main/taa/search.py", line 51, in <module>
    patch = gorilla.Patch(ray.tune.trial_runner.TrialRunner, 'step', step_w_log, settings=gorilla.Settings(allow_hit=True))
AttributeError: module 'ray.tune' has no attribute 'trial_runner'
Hello, after setting up the environment following your GitHub instructions, I ran reproduce_experiment.py directly and got this error. I would appreciate any guidance.

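Error in load_dataset when using a custom dataset

Hello, when I use a custom dataset (in the same format as the example dataset) and run with the example config file, the load_dataset function raises an error. Details: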
Traceback (most recent call last):
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/tune/ray_trial_executor.py", line 901, in get_next_executor_event
future_result = ray.get(ready_future)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
return func(*args, **kwargs)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/worker.py", line 1809, in get
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): ray::ImplicitFunc.train() (pid=51539, ip=10.10.25.1, repr=objective)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/tune/trainable.py", line 349, in train
result = self.step()
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/tune/function_runner.py", line 403, in step
self._report_thread_runner_error(block=True)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/tune/function_runner.py", line 567, in _report_thread_runner_error
raise TuneError(
ray.tune.error.TuneError: Trial raised an exception. Traceback:
ray::ImplicitFunc.train() (pid=51539, ip=10.10.25.1, repr=objective)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/tune/function_runner.py", line 272, in run
self._entrypoint()
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/tune/function_runner.py", line 348, in entrypoint
return self._trainable_func(
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/ray/tune/function_runner.py", line 640, in _trainable_func
output = fn()
File "/mnt/zhzhang_hdd/implicit-hate-corpus/text_autoaugment/taa/search.py", line 85, in objective
result = train_and_eval(config['tag'], policy_opt=True, save_path=save_path, only_eval=False)
File "/mnt/zhzhang_hdd/implicit-hate-corpus/text_autoaugment/taa/train.py", line 49, in train_and_eval
train_dataset, val_dataset, test_dataset = get_datasets(dataset_type, policy_opt=policy_opt)
File "/mnt/zhzhang_hdd/implicit-hate-corpus/text_autoaugment/taa/data.py", line 58, in get_datasets
test_dataset = load_dataset(path=path, name=dataset, data_dir=data_dir, data_files=data_files, split='test')
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/load.py", line 1714, in load_dataset
ds = builder_instance.as_dataset(split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/builder.py", line 763, in as_dataset
datasets = utils.map_nested(
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 250, in map_nested
return function(data_struct)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/builder.py", line 794, in _build_single_dataset
ds = self._as_dataset(
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/builder.py", line 862, in _as_dataset
dataset_kwargs = ArrowReader(self._cache_dir, self.info).read(
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/arrow_reader.py", line 211, in read
files = self.get_file_instructions(name, instructions, split_infos)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/arrow_reader.py", line 184, in get_file_instructions
file_instructions = make_file_instructions(
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/arrow_reader.py", line 107, in make_file_instructions
absolute_instructions = instruction.to_absolute(name2len)
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/arrow_reader.py", line 618, in to_absolute
return [_rel_to_abs_instr(rel_instr, name2len) for rel_instr in self._relative_instructions]
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/arrow_reader.py", line 618, in <listcomp>
return [_rel_to_abs_instr(rel_instr, name2len) for rel_instr in self._relative_instructions]
File "/opt/anaconda3/envs/forRL/lib/python3.8/site-packages/datasets/arrow_reader.py", line 433, in _rel_to_abs_instr
raise ValueError(f'Unknown split "{split}". Should be one of {list(name2len)}.')
ValueError: Unknown split "test". Should be one of ['train'].
Do I need to split the data into separate files in the data folder first? And how should I modify the config file? Thanks.

Class-imbalanced Regime

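Hello, I would like to reproduce your method's results on IMDB in the low-resource setting. The paper says 80 examples are taken from the IMDB training set and 60 from the validation set, so I took 40 examples with label 0 and 40 with label 1 from the IMDB training set, and likewise 30 per class for the validation set. Is this correct?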

"Exception: This class is a singleton!"

Hi there, when I try to use search_and_augment like this:
augmented_train_dataset = search_and_augment(configfile="./text-autoaugment/taa/confs/bert_sst2_example.yaml")
I get the error "Exception: This class is a singleton!". I did not modify anything in the file. How can I fix it?
