Unifying Vision-and-Language Tasks via Text Generation

[teaser image]

Setup

# Create python environment (optional)
conda create -n vlt5 python=3.7
source activate vlt5

# Install python dependencies
pip install -r requirements.txt

# Download T5/BART backbone checkpoint
python download_backbones.py

# For MSCOCO captioning evaluation (optional; needed only for captioning)
python -c "import language_evaluation; language_evaluation.download('coco')"

Code structure

# Store images, features, and annotations
./datasets
    COCO/
        images/
        features/
    VG/
        images/
        features/
    GQA/
        images/
        features/
    nlvr/
        images/
        features/
    RefCOCO/

    ...

# Run feature extraction
./feature_extraction

# Train VL-T5
./VL-T5/
    src/
        modeling_t5.py, modeling_bart.py                      <= VL-T5/VL-BART model classes
        pretrain.py, pretrain_data.py, pretrain_model.py      <= pretraining
        vqa.py, vqa_data.py, vqa_model.py ...                 <= fine-tuning on downstream tasks (e.g., VQA, GQA, NLVR2)
        multitask.py, multitask_data.py, multitask_model.py   <= multitask learning on 7 downstream tasks
        param.py                                              <= (argparse) configuration
        tokenization.py                                       <= custom tokenizer
        utils.py, dist_utils.py                               <= utility functions
    snap/                                                     <= store weight checkpoints
    scripts/                                                  <= bash scripts for pretraining and finetuning

API

import sys
sys.path.append('./VL-T5/src')

# Parse configuration
from param import parse_args
args = parse_args(
    backbone='t5-base',                   # Backbone architecture
    load='./snap/pretrain/VLT5/Epoch30',  # Pretrained checkpoint
    parse=False,                          # False for interactive environments (e.g., Jupyter)
)
# Assign GPU
args.gpu = 0

# Load data loaders
from vqa_data import get_loader
train_loader = get_loader(
    args,
    split=args.train,
    # ... (remaining loader options elided)
)
val_loader = get_loader(
    args,
    split=args.valid,
    # ... (remaining loader options elided)
)
test_loader = get_loader(
    args,
    split=args.test,
    # ... (remaining loader options elided)
)

# Import trainer
from vqa import Trainer
trainer = Trainer(
    args,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)

# model is attached to trainer
model = trainer.model

# Each task-specific model class inherits from the VLT5/VLBart classes, which in turn inherit from the Hugging Face Transformers T5/BART classes
print(model)
>>> VLT5VQA(
    (shared): Embedding(...)
    (encoder): JointEncoder(...)
    ...
)

# Training
train_batch = next(iter(train_loader))
model.train_step(train_batch)
>>> {'loss': ... }

# Inference
test_batch = next(iter(test_loader))
model.test_step(test_batch)
>>> {'pred_ans': ... }
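
The Trainer class drives the full training loop (see the scripts/ directory), but train_step and test_step can also be stepped manually. Below is a minimal sketch, assuming the 'loss' entry returned by train_step is a differentiable torch.Tensor (as the output above suggests); the optimizer and learning rate are illustrative, not the repo's actual configuration.

# Manual training loop (sketch only; the repo's Trainer additionally handles
# mixed precision, LR scheduling, distributed training, checkpointing, etc.)
import torch

model = model.cuda(args.gpu)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # assumed hyperparameters

model.train()
for batch in train_loader:
    # (batch tensors may need to be moved to the same device, depending on the loader)
    optimizer.zero_grad()
    loss = model.train_step(batch)['loss']  # assumes a differentiable tensor
    loss.backward()
    optimizer.step()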

To add a new task, you can start by writing three files, adapting them from the existing ones (a skeleton sketch follows the list below).

NEW_TASK_model.py  # Define a VLT5NewTask/VLBartNewTask model that inherits the VLT5/VLBart class
NEW_TASK_data.py   # Define the Dataset/DataLoader/Evaluator
NEW_TASK.py        # Define a trainer that inherits TrainerBase (trainer_base.py)
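
As a starting point, the model file might look like the skeleton below. This is a sketch only: VLT5 is imported from modeling_t5.py as shown in the code structure above, but the NewTask names, method bodies, and return keys are hypothetical.

# NEW_TASK_model.py (sketch; names and batch/return keys are hypothetical)
from modeling_t5 import VLT5

class VLT5NewTask(VLT5):
    def train_step(self, batch):
        # Compute the task loss from the batch (visual features + text),
        # mirroring the train_step pattern of the existing task models.
        ...
        return {'loss': ...}

    def test_step(self, batch):
        # Generate and post-process predictions for evaluation.
        ...
        return {'pred': ...}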

Download Pre-trained models / Pre-extracted features

We host model checkpoints and pre-extracted features on Google Drive. We recommend using the gdrive command-line tool to download them.

Pretrained Models

gdrive download 1_SBj4sZ0gUqfBon1gFBiNRAmfHv5w_ph --recursive

COCO+VG pretraining (default)

  • VL-T5/snap/pretrain/VLT5/Epoch30.pth: VL-T5 pretrained for 30 epochs on COCO+VG
  • VL-T5/snap/pretrain/VLBart/Epoch30.pth: VL-BART pretrained for 30 epochs on COCO+VG

VCR pretraining (2nd stage)

  • VL-T5/snap/vcr_pretrain/VLT5/Epoch20.pth: VL-T5 further pretrained for 20 epochs on VCR
  • VL-T5/snap/vcr_pretrain/VLBart/Epoch20.pth: VL-BART further pretrained for 20 epochs on VCR
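
Once downloaded, a checkpoint can be inspected before wiring it into training. A small sketch, assuming the .pth files are plain torch.save state dicts (the load='./snap/pretrain/VLT5/Epoch30' argument in the API section points at the same files):

# Checkpoint sanity check (sketch; assumes a plain state dict)
import torch

state_dict = torch.load('VL-T5/snap/pretrain/VLT5/Epoch30.pth', map_location='cpu')
print(len(state_dict), 'tensors')
print(list(state_dict)[:5])  # first few parameter names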

Dataset Preparation / Feature extraction

gdrive download 1MBBhlkP83VMKS2Qe0SmFfzkHhMpIG5wf --recursive
  • Multi30K only
    • git clone --recursive https://github.com/multi30k/dataset ./datasets/multi30k-dataset
    • gunzip train.en.gz, val.en.gz, test_2017_flickr.en.gz, and test_2018_flickr.en.gz in ./datasets/multi30k-dataset/data/task1/raw/
    • gunzip train.de.gz, val.de.gz, test_2017_flickr.de.gz, and test_2018_flickr.de.gz in ./datasets/multi30k-dataset/data/task1/raw/ (a decompression sketch follows this list)
  • For manual feature extraction, please check out ./feature_extraction
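
Since the raw Multi30K files are gzip archives, plain unzip will not handle them; a small Python helper can decompress them in place. A sketch, assuming the clone path from the list above:

# Decompress the Multi30K .gz files in place (they are gzip, not zip, archives).
import gzip, shutil
from pathlib import Path

raw = Path('./datasets/multi30k-dataset/data/task1/raw')
for gz in raw.glob('*.gz'):
    with gzip.open(gz, 'rb') as fin, open(gz.with_suffix(''), 'wb') as fout:
        shutil.copyfileobj(fin, fout)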

Pretraining on COCO+VG

# Pretraining with 4 gpus
cd VL-T5/
bash scripts/COCOVG_pretrain_VLT5.sh 4
bash scripts/COCOVG_pretrain_VLBart.sh 4

Downstream tasks

# VQA: finetuning with 4 gpus
cd VL-T5/
bash scripts/VQA_VLT5.sh 4
bash scripts/VQA_VLBart.sh 4

# GQA: finetuning with 4 gpus
cd VL-T5/
bash scripts/GQA_VLT5.sh 4
bash scripts/GQA_VLBart.sh 4

# NLVR2: finetuning with 4 gpus
cd VL-T5/
bash scripts/NLVR_VLT5.sh 4
bash scripts/NLVR_VLBart.sh 4

# RefCOCOg: finetuning with 4 gpus
cd VL-T5/
bash scripts/RefCOCOg_VLT5.sh 4
bash scripts/RefCOCOg_VLBart.sh 4

# VCR: pretraining with 4 gpus (optional 2nd-stage pretraining)
cd VL-T5/
bash scripts/VCR_pretrain_VLT5.sh 4
bash scripts/VCR_pretrain_VLBart.sh 4

# VCR: finetuning with 4 gpus
cd VL-T5/
bash scripts/VCR_VLT5.sh 4
bash scripts/VCR_VLBart.sh 4

# COCO captioning: finetuning with 4 gpus
cd VL-T5/
bash scripts/COCOCaption_VLT5.sh 4
bash scripts/COCOCaption_VLBart.sh 4

# Multi30K: finetuning with 4 gpus
cd VL-T5/
bash scripts/Multi30K_VLT5.sh 4
bash scripts/Multi30K_VLBart.sh 4

Reference

Please cite our paper if you use our models in your work:

@inproceedings{cho2021vlt5,
  title     = {Unifying Vision-and-Language Tasks via Text Generation},
  author    = {Jaemin Cho and Jie Lei and Hao Tan and Mohit Bansal},
  booktitle = {ICML},
  year      = {2021}
}
