Multimodal Graph Learning

Most multimodal learning algorithms focus on modeling simple one-to-one pairs of data from two modalities, such as image-caption pairs or audio-text pairs. However, in most real-world settings, entities of different modalities interact with each other in more complex and multifaceted ways, going beyond one-to-one mappings.

We propose Multimodal Graph Learning (MMGL), a systematic framework for capturing information from multiple multimodal neighbors with relational structures among them. In particular, we focus on MMGL for generative tasks, building upon pretrained Language Models (LMs), aiming to augment their text generation with multimodal neighbor contexts.

The original paper can be found at https://arxiv.org/abs/2310.07478.

Setup

Create a new conda environment, then install the requirements and PyTorch:

conda create python==3.7 -n mmgl
conda activate mmgl
pip install -r requirements.txt
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
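
To confirm that the pinned CUDA 11.7 build installed correctly, a quick sanity check (illustrative, not part of the repository) is:

import torch

print(torch.__version__)          # expected: 1.13.1+cu117
print(torch.cuda.is_available())  # should print True if the cu117 wheels match your driver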

The code is implemented with PyTorch DistributedDataParallel (DDP) and supports the WikiWeb2M dataset.
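
For reference, the standard DDP pattern looks like the following; this is a minimal sketch of the generic PyTorch API, not the exact setup in language_modelling/run_generation.py.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model):
    # torchrun sets the RANK/WORLD_SIZE/LOCAL_RANK environment variables
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    # gradients are synchronized across processes after each backward pass
    return DDP(model, device_ids=[local_rank])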

Data preprocessing

First, make a folder to download the WikiWeb2M dataset: mkdir wikiweb2m/raw. Then download all Train/Validation/Test files from the WikiWeb2M page into wikiweb2m/raw. Next, make a folder to download images: mkdir wikiweb2m/raw/images. Finally, run preprocess_data.py to convert the WikiWeb2M dataset into PyTorch format.

python preprocess_data.py
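
To verify that the raw shards under wikiweb2m/raw are readable, something like the following counts the records; it assumes TensorFlow is available and that the files are gzip-compressed TFRecords (the format WikiWeb2M is distributed in), and the glob pattern is a placeholder for whatever filenames you downloaded.

import glob
import tensorflow as tf

files = glob.glob("wikiweb2m/raw/wikiweb2m-*")   # adjust the pattern to your filenames
dataset = tf.data.TFRecordDataset(files, compression_type="GZIP")
print(sum(1 for _ in dataset), "records found")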

The output training/validation/test set sizes for section summarization are as follows:

Split      Train   Validation   Test
Sections   680K    170K         170K

Training

Script

In script/train_generation.sh, you can specify the base model (MODEL_NAME), the task (TASK; currently only section summarization, 'section', is supported), and the neighbor context (CONTEXT). CONTEXT has four options:

CONTEXT description
section_only use only text in the target section
section_all use text and images in the target section
text_only use only text from the whole page
all use text and images from the whole page

You can set how text neighbors are encoded using NEIGHBOR_MODE. There are two options (a conceptual sketch follows the table):

NEIGHBOR_MODE description
raw concatenate text neighbors as raw text into the input text
embedding embed text neighbors with text_model and concatenate the resulting embeddings with the input token embeddings
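
To make the distinction concrete, here is an illustrative contrast between the two modes; the actual logic lives in wikiweb2m/data.py and the model files, and text_encoder below is a hypothetical encoder that returns one vector per neighbor.

import torch

def raw_mode(target_text, neighbor_texts):
    # "raw": neighbor texts are concatenated into the input string itself
    return " ".join(neighbor_texts) + " " + target_text

def embedding_mode(target_token_embeds, neighbor_texts, text_encoder):
    # "embedding": neighbors are encoded separately, and their embeddings are
    # prepended to the input token embeddings along the sequence axis
    neighbor_embeds = torch.stack([text_encoder(t) for t in neighbor_texts], dim=0)
    return torch.cat([neighbor_embeds, target_token_embeds], dim=0)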

You can set the parameter-efficient fine-tuning (PEFT) option in the script using PEFT_TYPE. There are four PEFT options, plus none for full fine-tuning (a LoRA sketch using the Hugging Face peft library follows the table):

PEFT_TYPE description
none full finetune
prompt prompt tuning
prefix prefix tuning
lora LoRA
flamingo fine-tune only the newly added cross-attention layers; can be used with decoder-only models and neighbor_mode = embedding
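
For reference, this is roughly what LoRA-style PEFT looks like with the Hugging Face peft library; it is an illustrative sketch and not necessarily how PEFT_TYPE=lora is wired inside this repository (the base checkpoint t5-base is just an example).

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForSeq2SeqLM.from_pretrained("t5-base")  # example base model
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # use CAUSAL_LM for decoder-only models
    r=8, lora_alpha=32, lora_dropout=0.1,
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the low-rank adapter weights are trainable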

In the script, you can also change max_input_length and max_output_length, as well as other optimization hyperparameters (e.g., epochs, learning_rate, per_device_train_batch_size). You can choose the models used to encode text and image neighbors via text_model and visual_model. All configurable arguments are defined in the Argument class in language_modelling/run_generation.py.
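
Such an Argument class is typically a dataclass parsed with transformers.HfArgumentParser; the sketch below mirrors the option names mentioned above with placeholder defaults, and the real definition in language_modelling/run_generation.py may differ.

from dataclasses import dataclass
from transformers import HfArgumentParser

@dataclass
class Argument:
    # Field names mirror the options described above; defaults are placeholders.
    model_name: str = "t5-base"
    task: str = "section"
    context: str = "section_only"      # section_only | section_all | text_only | all
    neighbor_mode: str = "raw"         # raw | embedding
    peft_type: str = "none"            # none | prompt | prefix | lora | flamingo
    max_input_length: int = 512
    max_output_length: int = 128
    epochs: int = 10
    learning_rate: float = 1e-4
    per_device_train_batch_size: int = 4
    text_model: str = "roberta-base"                    # placeholder text-neighbor encoder
    visual_model: str = "openai/clip-vit-base-patch32"  # placeholder image-neighbor encoder

args, = HfArgumentParser(Argument).parse_args_into_dataclasses()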

File description

We provide brief descriptions for each file as follows:

Directory/File description
wikiweb2m/ code related to the WikiWeb2M dataset
wikiweb2m/cider computes CIDEr scores
wikiweb2m/data.py prepares each training point based on context and neighbor_mode
wikiweb2m/preprocess_data.py code to preprocess the WikiWeb2M dataset and download images
script/ scripts to run MMGL
script/train_generation.sh sets hyperparameters
language_modelling/ main directory
language_modelling/run_generation.py prepares models, reads datasets, and runs the training/validation loops
language_modelling/utils.py utility functions
model/ language models
model/modelling_self_attention.py LMs with self-attention only, covering encoder-decoder and decoder-only models
model/modelling_cross_attention.py LMs with cross-attention to encode neighbor information; decoder-only models (see the sketch below)
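
As a conceptual illustration of the cross-attention used in modelling_cross_attention.py, the text hidden states attend to the neighbor embeddings as keys and values; the snippet below uses the generic torch.nn.MultiheadAttention API with invented tensor shapes, not the repository's actual layer.

import torch
import torch.nn as nn

dim = 768
cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=12, batch_first=True)

text_hidden = torch.randn(2, 16, dim)   # decoder hidden states: (batch, text_len, dim)
neighbor_emb = torch.randn(2, 5, dim)   # encoded multimodal neighbors: (batch, num_neighbors, dim)

# text tokens attend to neighbor embeddings as keys/values
fused, _ = cross_attn(query=text_hidden, key=neighbor_emb, value=neighbor_emb)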

Citation

If you find this work or our code useful, please consider citing:

@article{yoon2023multimodal,
  title={Multimodal Graph Learning for Generative Tasks},
  author={Yoon, Minji and Koh, Jing Yu and Hooi, Bryan and Salakhutdinov, Ruslan},
  journal={arXiv preprint arXiv:2310.07478},
  year={2023}
}
