An Image is Worth Multiple Words: Discovering Object Level Concepts using Multi-Concept Prompt Learning (ICML 2024)


Chen Jin¹, Ryutaro Tanno², Amrutha Saseendran¹, Tom Diethe¹, Philip Teare¹

Multi-Concept Prompt Learning (MCPL) pioneers mask-free text-guided learning for multiple prompts from one scene. Our approach not only enhances current methodologies but also paves the way for novel applications, such as facilitating knowledge discovery through natural language-driven interactions between humans and machines.

Motivation

We compare: Textual Inversion (T.I.), which learns concepts from either masked (left, first column) or cropped (left, second column) images; MCPL-one, which learns both concepts jointly from the full image with a single string; and MCPL-diverse, which additionally accounts for per-image specific relationships.

Naively learning multiple text embeddings from a single image-sentence pair without imagery guidance leads to misalignment in the per-word cross attention (top). We propose three regularisation terms (PromptCL, adjective binding and attention masking; see the Regularisation runs below) to enhance the accuracy of the prompt-object level correlation (bottom).

Method

Input images from our natural_2_concepts dataset.

Applications

Multiple concepts from a single image

Input images from our natural_2_concepts dataset.

Per-image different multiple concepts

Input images from the Prompt-to-Prompt (P2P) demo images.

Out-of-Distribution concept discovery and hypothesis generation

Input images from the LGE CMR and MIMIC-CXR datasets.

Dataset

We generated and collected a Multi-Concept Dataset comprising ~1,400 images in total, together with masked objects/concepts, organised as follows:

/medical_2_concepts (370 images)
/natural_2_concepts
/natural_345_concepts
/real_natural_concepts

Data file name          Size    # of images
medical_2_concepts      2.5M    370
natural_2_concepts      36M     415
natural_345_concepts    13M     525
real_natural_concepts   5.6M    137
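
As a quick sanity check against the table above, the per-folder image counts can be verified with a few lines of Python (a minimal sketch; the root path below is an assumption, adjust it to wherever you placed the data):

from pathlib import Path

# Hypothetical download location; change to your own data directory.
root = Path("data")

for name in ["medical_2_concepts", "natural_2_concepts",
             "natural_345_concepts", "real_natural_concepts"]:
    images = [p for p in (root / name).rglob("*")
              if p.suffix.lower() in {".jpg", ".jpeg", ".png"}]
    print(f"{name}: {len(images)} images")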

Setup

Our code builds on, and shares requirements with, Latent Diffusion Models (LDM). To set up their environment, please run:

conda env create -f environment.yaml
conda activate ldm
pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
cd ./src/taming-transformers
pip install -e .

You will also need the official LDM text-to-image checkpoint, available through the LDM project page.

Currently, the model can be downloaded by running:

mkdir -p models/ldm/text2img-large/
wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt

Learning

MCPL-all: a naive approach that learns embeddings for all prompts in the string (including adjectives, prepositions, nouns, etc.)

  • set placeholder_string to a string describing your multi-concept images;
  • in presudo_words (the flag is spelled this way throughout the codebase), list every word of the placeholder_string so that all of them are learned;
python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'green * and orange @' \
                --presudo_words 'green,*,and,orange,@'

MCPL-one: simplifies the objective by learning a single prompt (noun) per concept

  • in this case, presudo_words lists only the subset of words in the placeholder_string that should be learned;
python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'green * and orange @' \
                --presudo_words '*,@'

MCPL-diverse: learns different strings per image, to capture per-image variations among the training examples

  • before starting, name each training image with a single word that represents the relation it depicts (see the sketch after the command below);
  • e.g. in the ball-and-box experiment, we train with 'front.jpg', 'next.jpg', 'on.jpg' and 'under.jpg';
  • in placeholder_string, we describe the multi-concept scene, using 'RELATE' as a placeholder for the relationship between the concepts;
  • in presudo_words, we list all pseudo-words to be learnt, including the relations; the per-image relation is injected by replacing 'RELATE' with the relation given by each image's file name;
python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'green * RELATE orange @' \
                --presudo_words '*,@,on,under,next,front'
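
For the ball-and-box example above, the file-name-to-relation mapping can be previewed with a short Python sketch (the directory path is the same placeholder as in the command; the file names are the hypothetical ones listed above):

from pathlib import Path

# Each training image is named after its relation, e.g. on.jpg, under.jpg;
# the file-name stem is the word that replaces 'RELATE' for that image.
for img in sorted(Path("</path/to/directory/with/images>").glob("*.jpg")):
    print(f"{img.name} -> relation '{img.stem}'")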

Regularisation-1: adding PromptCL and Bind adjective (teddybear skateboard example)

python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'a brown @ on a rolling * at times square' \
                --presudo_words 'a,brown,on,rolling,at,times,square,@,*' \
                --attn_words 'brown,rolling,@,*' \
                --presudo_words_softmax '@,*' \
                --presudo_words_infonce '@,*' \
                --infonce_temperature 0.2 \
                --infonce_scale 0.0005 \
                --adj_aug_infonce 'brown,rolling' \
                --attn_mask_type 'skip'
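
Here PromptCL is an InfoNCE-style contrastive loss over the learnable prompt embeddings, controlled by infonce_temperature and infonce_scale above. A minimal sketch of such a loss, as an illustration of the idea rather than the repository's exact implementation:

import torch
import torch.nn.functional as F

def infonce(anchor, positives, negatives, temperature=0.2):
    """InfoNCE-style term: pull an embedding towards its positives (e.g.
    augmented views or its bound adjective) and away from the other
    concepts' embeddings. anchor: (D,), positives: (P, D), negatives: (N, D)."""
    anchor = F.normalize(anchor, dim=-1)
    pos = F.normalize(positives, dim=-1) @ anchor / temperature  # (P,)
    neg = F.normalize(negatives, dim=-1) @ anchor / temperature  # (N,)
    logits = torch.cat([pos, neg])
    # Log-probability of each positive against all candidates, averaged.
    return -(logits[: len(positives)] - torch.logsumexp(logits, dim=0)).mean()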

Regularisation-2: adding PromptCL, Bind adjective and Attention Mask (teddybear skateboard example)

python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'a brown @ on a rolling * at times square' \
                --presudo_words 'a,brown,on,rolling,at,times,square,@,*' \
                --attn_words 'brown,rolling,@,*' \
                --presudo_words_softmax '@,*' \
                --presudo_words_infonce '@,*' \
                --infonce_temperature 0.3 \
                --infonce_scale 0.00075 \
                --adj_aug_infonce 'brown,rolling'
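
Compared with Regularisation-1, this run omits the attn_mask_type 'skip' flag, which enables the attention-mask regularisation. The idea is to restrict learning to the regions the learnable prompts actually attend to; a rough sketch of such a mask (an assumed form, not the exact repository code):

import torch

def attention_mask(attn_maps: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """attn_maps: (num_prompts, H, W) cross-attention maps of the learnable
    words, averaged over heads and layers. Returns a binary foreground mask
    that can gate the reconstruction loss."""
    fused = attn_maps.mean(dim=0)                                  # fuse prompts
    fused = (fused - fused.min()) / (fused.max() - fused.min() + 1e-8)
    return (fused > threshold).float()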

Generation

To generate new images of the learned concept, run:

python scripts/txt2img.py --ddim_eta 0.0 \
            --n_samples 8 \
            --n_iter 2 \
            --scale 10.0 \
            --ddim_steps 50 \
            --embedding_path /path/to/logs/trained_model/checkpoints/embeddings_gs-6099.pt \
            --ckpt_path /path/to/pretrained/model.ckpt \
            --prompt "a photo of green * and orange @"

where * and @ are the placeholder strings used during inversion.

Code structure

Our code builds on the Textual Inversion (MIT licence) codebase as well as the Prompt-to-Prompt (Apache-2.0 licence) codebase.

The majority of the modifications are in the following files, where we provide docstrings for all functions:

./main.py
./src/p2p/p2p_ldm_utils.py
./src/p2p/ptp_utils.py
./ldm/modules/embedding_manager.py
./ldm/models/diffusion/ddpm.py

The remaining library files are mostly unchanged and inherited from prior work.

FAQ

BERT tokenizer error. Sometimes you may get the following error because the tokenizer maps a word to more than one token; simply try a different word with a similar meaning. For example, for the error below, replacing 'peachy' in your prompt with 'splendid' resolves the issue.

File "/YOUR-HOME-PATH/MCPL/ldm/modules/embedding_manager.py", line 22, in get_bert_token_for_string
    assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
AssertionError: String 'peachy' maps to more than a single token. Please use another string
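
To pre-check a candidate word before training, you can run it through the same bert-base-uncased tokenizer (a sketch assuming the Hugging Face transformers package; an encoded single-token word yields exactly three ids, [CLS] word [SEP], matching the assertion above):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

for word in ["peachy", "splendid"]:
    ids = tokenizer(word)["input_ids"]   # [CLS] + word pieces + [SEP]
    status = "OK" if len(ids) == 3 else "maps to multiple tokens"
    print(f"{word}: {ids} -> {status}")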

Citation

If you make use of our work, please cite our paper:

@inproceedings{anonymous2024an,
  title={An Image is Worth Multiple Words: Discovering Object Level Concepts using Multi-Concept Prompt Learning},
  author={Anonymous},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024},
  url={https://openreview.net/forum?id=F3x6uYILgL}
}
