Comments (4)

huuquan1994 commented on April 28, 2024

@ChandanVerma
Here is an example of how to load the mPLUG-Large-V2 model

Load the mPLUG from checkpoint

import torch
import torch.nn as nn
import yaml
from models.model_caption_mplug import MPLUG
from models.tokenization_bert import BertTokenizer
from models.vit import interpolate_pos_embed, resize_pos_embed

# Load the config file
config_path = 'AliceMind/mPLUG/configs/caption_mplug_large.yaml'
config = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)

config["min_length"] = 8
config["max_length"] = 25
config["add_object"] = True
config["beam_size"] = 5
config['text_encoder'] = 'bert-base-uncased'
config['text_decoder'] = 'bert-base-uncased'

# initialize the mPLUG-Large-V2 model
mplug_tokenizer = BertTokenizer.from_pretrained(config['text_encoder'])
mPLUG_model = MPLUG(config=config, tokenizer=mplug_tokenizer)

mPLUG_ckpt_path = '/path/to/mplug_large_v2.pth' # replace the path of your ckpt here
checkpoint = torch.load(mPLUG_ckpt_path, map_location='cpu')
state_dict = checkpoint['model']

# reshape positional embedding to accommodate the image resolution change
# ref: https://github.com/alibaba/AliceMind/blob/main/mPLUG/caption_mplug.py#L227
if config["clip_name"] == "ViT-B-16":
    num_patches = int(config["image_res"] * config["image_res"]/(16*16))
elif config["clip_name"] == "ViT-L-14":
    num_patches = int(config["image_res"] * config["image_res"]/(14*14))

pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float())

pos_embed = resize_pos_embed(state_dict['visual_encoder.visual.positional_embedding'].unsqueeze(0),
                                         pos_embed.unsqueeze(0))
state_dict['visual_encoder.visual.positional_embedding'] = pos_embed

# load and move the mPLUG-Large model to GPU
mPLUG_model.load_state_dict(state_dict, strict=False); # adding ; to avoid printing long messages
mPLUG_model = mPLUG_model.to('cuda').eval();
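
For intuition about the reshape step above: the positional embedding has one row per image patch plus one row for the class token, so its length depends on the image resolution and the patch size. The sketch below only redoes that arithmetic in plain Python. `PATCH_SIZES` and `num_pos_tokens` are names made up for illustration (they are not part of AliceMind), and the resolutions are example values, not necessarily what your config uses.

```python
# Hypothetical helper, not part of AliceMind: patch size per CLIP backbone name.
PATCH_SIZES = {"ViT-B-16": 16, "ViT-L-14": 14}


def num_pos_tokens(clip_name: str, image_res: int) -> int:
    """Return the positional-embedding length: number of patches + 1 class token."""
    patch = PATCH_SIZES[clip_name]
    # Same formula as the snippet above: image area divided by patch area.
    num_patches = int(image_res * image_res / (patch * patch))
    return num_patches + 1


print(num_pos_tokens("ViT-B-16", 224))  # 196 patches + 1 = 197
print(num_pos_tokens("ViT-L-14", 336))  # 576 patches + 1 = 577
```

If the checkpoint was trained at a different resolution than `config["image_res"]`, these two lengths differ, which is why the embedding is resized before `load_state_dict`.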

Inference from single input

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# define preprocess function
mPLUG_transform = transforms.Compose([
    transforms.Resize((config['image_res'], config['image_res']), 
                      interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 
                         (0.26862954, 0.26130258, 0.27577711)),
])

# load image
img_path = 'dog.jpg' # path to your image here
image = Image.open(img_path).convert('RGB')

# run inference
mplug_image = mPLUG_transform(image).unsqueeze(0) # [1, 3, image_res, image_res]
mplug_image = mplug_image.to('cuda')
    
# gradients are not needed at inference time
with torch.no_grad():
    topk_ids, topk_probs = mPLUG_model(mplug_image, None, train=False)
    
output_caption = mplug_tokenizer.decode(topk_ids[0][0]).replace("[SEP]", "").replace("[CLS]", "").replace("[PAD]", "").strip()

print(output_caption)
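
If you caption many images, the chained `.replace(...)` calls can be pulled into a small helper. This is only a sketch: `clean_caption` and `SPECIAL_TOKENS` are hypothetical names, and the input string below is a made-up example of decoder output, not a real model result.

```python
# Hypothetical helper: strip BERT special tokens from a decoded caption.
SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")


def clean_caption(decoded: str) -> str:
    """Remove special tokens and collapse the leftover whitespace."""
    for token in SPECIAL_TOKENS:
        decoded = decoded.replace(token, " ")
    # split() with no arguments drops runs of whitespace, join() rebuilds the text
    return " ".join(decoded.split())


print(clean_caption("[CLS] a dog sitting on the grass [SEP] [PAD] [PAD]"))
# a dog sitting on the grass
```

With this in place, the last step would become `clean_caption(mplug_tokenizer.decode(topk_ids[0][0]))`.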

from alicemind.

ChandanVerma commented on April 28, 2024

Hi @huuquan1994, is it possible to generate video captions from a single video file as well?
Thanks in advance.


huuquan1994 commented on April 28, 2024

@ChandanVerma
I don't have time to test on videos, but you can refer to https://github.com/alibaba/AliceMind/blob/main/mPLUG/videocap_mplug.py to see how both the videos and the pre-trained model are loaded.
Also take a look at the video dataset class that is used to load video inputs (here: https://github.com/alibaba/AliceMind/blob/main/mPLUG/dataset/video_dataset.py)
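
For a single video file, one common approach is to pick a fixed number of frames spread evenly over the clip and run each through the same image preprocessing. The sketch below shows only the index selection, in plain Python. It is an assumption on my side and untested with mPLUG; `sample_frame_indices` is a hypothetical name, and the repo's `video_dataset.py` may sample frames differently, so check it before relying on this.

```python
from typing import List


def sample_frame_indices(num_frames: int, num_samples: int) -> List[int]:
    """Pick num_samples frame indices spread evenly over a clip of num_frames frames.

    The clip is split into num_samples equal segments and the middle frame of
    each segment is taken. Indices repeat if the clip has fewer frames than samples.
    """
    if num_frames <= 0 or num_samples <= 0:
        raise ValueError("num_frames and num_samples must be positive")
    segment = num_frames / num_samples
    return [int((i + 0.5) * segment) for i in range(num_samples)]


print(sample_frame_indices(100, 4))  # [12, 37, 62, 87]
```

The frames at those indices would then still need to be decoded from the video and stacked in whatever layout the video captioning model expects.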


ChandanVerma commented on April 28, 2024

Sure.. thanks for the heads up 👍

