
RelViT

This repository hosts the code for the paper:

RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning (ICLR 2022)

by Xiaojian Ma, Weili Nie, Zhiding Yu, Huaizu Jiang, Chaowei Xiao, Yuke Zhu, Song-Chun Zhu and Anima Anandkumar

arXiv | Poster | Slides

News

  • πŸ”₯πŸ”₯ 09/10/2022: Pre-trained models on GQA are now released.

Abstract

Reasoning about visual relationships is central to how humans interpret the visual world. This task remains challenging for current deep learning algorithms since it requires addressing three key technical problems jointly: 1) identifying object entities and their properties, 2) inferring semantic relations between pairs of entities, and 3) generalizing to novel object-relation combinations, i.e., systematic generalization. In this work, we use vision transformers (ViTs) as our base model for visual reasoning and make better use of concepts defined as object entities and their relations to improve the reasoning ability of ViTs. Specifically, we introduce a novel concept-feature dictionary to allow flexible image feature retrieval at training time with concept keys. This dictionary enables two new concept-guided auxiliary tasks: 1) a global task for promoting relational reasoning, and 2) a local task for facilitating semantic object-centric correspondence learning. To examine the systematic generalization of visual reasoning models, we introduce systematic splits for the standard HICO and GQA benchmarks. We show that the resulting model, Concept-guided Vision Transformer (or RelViT for short), significantly outperforms prior approaches on HICO and GQA by 16% and 13% in the original split, and by 43% and 18% in the systematic split. Our ablation analyses also reveal our model's compatibility with multiple ViT variants and robustness to hyper-parameters.
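
The concept-feature dictionary is easier to picture with a toy sketch. The class below is purely illustrative (every name, the capacity, and the sampling policy are assumptions, not the repository's implementation): image features are stored under concept keys during training and later retrieved, uniformly or most-recently, to form the concept-guided auxiliary tasks.

# Illustrative toy sketch only -- not the repository's implementation.
import random
from collections import defaultdict, deque

class ConceptFeatureDictionary:
    """Stores image features under concept keys and retrieves them with
    either uniform or "most-recent" sampling (cf. relvit_sample_uniform)."""

    def __init__(self, capacity_per_concept=128):
        self.bank = defaultdict(lambda: deque(maxlen=capacity_per_concept))

    def update(self, concept_id, feature):
        # Oldest features are evicted once a concept's queue is full.
        self.bank[concept_id].append(feature)

    def sample(self, concept_id, uniform=True):
        feats = self.bank[concept_id]
        if not feats:
            return None
        return random.choice(feats) if uniform else feats[-1]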

Installation

  • Install PyTorch:

    conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
  • Install the necessary packages with requirements.txt:

    pip install -r requirements.txt

The code has been tested with Python 3.8, PyTorch 1.11.0, and CUDA 11.6 on Ubuntu 20.04.
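
Before moving on, a quick sanity check that PyTorch sees the GPU may save time (plain PyTorch, nothing repo-specific):

# Quick environment sanity check.
import torch
print(torch.__version__)          # tested with 1.11.0
print(torch.version.cuda)         # CUDA version the binary was built against
print(torch.cuda.is_available())  # True if a usable GPU is visible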

Data Preparation

Please refer to the data preparation instructions in this repository.

Training

HICO
bash scripts/train_hico_image.sh configs/train_hico.yaml

Note

  • In configs/train_hico.yaml you may find some configurable options (a sketch for overriding them programmatically follows this note):

    • use eval_mode to run different experiments: original or systematic generalization test
    • use model_args.encoder_args.encoder and load_encoder to select the vision backbone. There are five options available: pvtv2_b2, pvtv2_b3, swin_small, swin_base and vit_small_16.
    • use relvit to turn the RelViT auxiliary loss on or off
    • use relvit_weight to adjust the coefficient of the RelViT auxiliary loss
    • use relvit_local_only to use only the RelViT local task (rather than both the local and global tasks)
    • use relvit_mode to control whether the EsViT loss is included
    • use relvit_sample_uniform to choose between uniform and "most-recent" concept sampling
    • use relvit_concept_use and relvit_num_concepts to choose the concept type used by RelViT: HOI, verb, or object

    In general, we don't recommend modifying other parameters.

  • All the GPUs will be used by default. To run with the recommended batch size, you may need one V100 32GB GPU.
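
A minimal sketch for overriding these options without hand-editing the file, assuming only that PyYAML is installed and the config is plain YAML with the keys listed above (the exact schema may differ):

# Hedged sketch: write a modified copy of the config, then train with it.
import yaml

with open('configs/train_hico.yaml') as f:
    cfg = yaml.safe_load(f)

cfg['relvit'] = True                 # enable the RelViT auxiliary loss
cfg['relvit_weight'] = 0.1           # illustrative value, not a recommendation
cfg['relvit_sample_uniform'] = True  # uniform concept sampling

with open('configs/train_hico_custom.yaml', 'w') as f:
    yaml.safe_dump(cfg, f)

# Then: bash scripts/train_hico_image.sh configs/train_hico_custom.yaml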

GQA
bash scripts/train_gqa_image.sh configs/train_gqa.yaml

Note

  • In configs/train_gqa.yaml you may find some configurable options:

    • use eval_mode to run different experiments: original or systematic generalization test
    • use model_args.encoder_args.encoder and load_encoder to select the vision backbone. There are five options available: pvtv2_b2, pvtv2_b3, swin_small, swin_base and vit_small_16.
    • use relvit to turn the RelViT auxiliary loss on or off
    • use relvit_weight to adjust the coefficient of the RelViT auxiliary loss
    • use relvit_local_only to use only the RelViT local task (rather than both the local and global tasks)
    • use relvit_mode to control whether the EsViT loss is included
    • use relvit_sample_uniform to choose between uniform and "most-recent" concept sampling

    In general, we don't recommend modifying other parameters.

  • All the GPUs will be used by default (see the sketch after this note for restricting visible devices). To run with the recommended batch size, you may need up to 64 V100 32GB GPUs, because the vision backbone is fine-tuned during training.
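
If you cannot dedicate the whole machine, the standard CUDA_VISIBLE_DEVICES environment variable (a generic CUDA mechanism, not a repo-specific flag) limits which GPUs the launcher sees, e.g. CUDA_VISIBLE_DEVICES=0,1 bash scripts/train_gqa_image.sh configs/train_gqa.yaml. A quick Python check that the restriction takes effect:

# The variable must be set before torch initializes CUDA (or exported in the
# shell that launches the training script).
import os
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')
import torch
print(torch.cuda.device_count())  # should report 2 with the setting above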

Testing

HICO
bash scripts/train_hico_image.sh configs/train_hico.yaml --test_only --test_model <path to best_model.pth>
GQA
bash scripts/train_gqa_image.sh configs/train_gqa.yaml --test_only --test_model <path to best_model.pth>

Pre-trained models

tag                 encoder      experiment   result (accuracy, %)   URL
swin-small-relvit   swin_small   GQA (val)    61.38                  link
swin-base-relvit    swin_base    GQA (val)    65.54                  link
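
Loading these checkpoints for standalone feature extraction is not documented, and a user report (see the issues further below) shows that loading the raw file into swin_base() yields missing keys. The sketch below is hedged guesswork: the 'state_dict' key and the models.swin_transformer.swin_base import path come from that report, while the 'encoder.' prefix is an outright assumption; print the checkpoint's keys and adjust before trusting it.

# Hedged sketch: extract backbone weights from a released checkpoint.
# ASSUMPTIONS: weights sit under a 'state_dict' key and backbone entries carry
# an 'encoder.' prefix; inspect the printed keys to confirm both.
import torch
from models.swin_transformer import swin_base  # import path from a user report

model = swin_base()
ckpt = torch.load('swin_base_gqa.pth', map_location='cpu')  # downloaded file
sd = ckpt.get('state_dict', ckpt)
print(list(sd)[:5])                  # inspect key names before stripping
prefix = 'encoder.'                  # assumed; adjust to the actual prefix
sd = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
missing, unexpected = model.load_state_dict(sd, strict=False)
print('missing:', len(missing), 'unexpected:', len(unexpected))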

License

Please check the LICENSE file for both the code and the released pre-trained models. This work may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].

Acknowledgement

The authors have referred to the following projects:

SimCLR

DenseCL

EsViT

Swin-Transformer

PVT

HICODet

MCAN

Citation

Please consider citing our paper if you find our work helpful for your research:

@inproceedings{ma2022relvit,
    title={RelViT: Concept-guided Vision Transformer for Visual Relational Reasoning},
    author={Xiaojian Ma and Weili Nie and Zhiding Yu and Huaizu Jiang and Chaowei Xiao and Yuke Zhu and Song-Chun Zhu and Anima Anandkumar},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=afoV8W3-IYp}
}

relvit's Issues

training time

Hello, could you please tell me the training time for HICO and GQA, and on which GPUs it was measured? Thanks!

Load checkpoint to inference for image embeddings

Hi, could you please provide instructions for loading the provided checkpoint? I want to run the model to extract image embeddings.

I have tried the following, but it fails:

import torch
from models.swin_transformer import swin_base

model = swin_base()
sd = torch.load('./swin_base_original_gqa.pth', map_location='cpu')['state_dict']
model.load_state_dict(sd, strict=False)

# _IncompatibleKeys(missing_keys=['patch_embed.proj.weight', 'patch_embed.proj.bias', 'patch_embed.norm.weight', 'patch_embed.norm.bias', 'layers.0.blocks.0.norm1.weight', ...])

Pretrained model

Could the trained models be provided, especially on the GQA dataset?
