
xinyu1205 / recognize-anything


Open-source and strong foundation image recognition models.

Home Page: https://recognize-anything.github.io/

License: Apache License 2.0

Python 2.81% Jupyter Notebook 97.19%
recognize-anything tag2text-iclr2024

recognize-anything's Introduction

recognize-anything's People

Contributors

0x4007, amorporkian, coler1994, crazycth, demoulinv, dnth, eltociear, fcakyon, ganymedesky, guillaume-rochette-oxb, majinyu666, mhd-medfa, mitpitt, positive666, tuofeilunhifi, xinyu1205, zhaoyangli-nju, zylo117


recognize-anything's Issues

How can I get detection locations?

When running recognition with inference_ram.py, how can I get the location information for each tag at the same time as the tags themselves?

Not an issue, but RAM can generate captions and tags

Thanks for the great work!
I was able to generate both captions and tags with RAM.
Please change the following two pieces of code.
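The main change is that `MyRAM.generate` first predicts tags with RAM's recognition head and then feeds those tags to the caption decoder. The tag-extraction part it shares with `generate_tag` (sigmoid probabilities compared against per-class thresholds, `delete_tag_index` entries zeroed, surviving tag names joined with `' | '`) and the per-beam repetition of the tag strings can be sketched on their own with NumPy. This is a minimal sketch, not code from the repository: `probs_to_tag_strings` and `repeat_for_beams` are hypothetical helper names, and it assumes the probabilities are already a CPU `(batch, num_class)` array.

```python
import numpy as np


def probs_to_tag_strings(probs, class_threshold, tag_list, delete_tag_index=()):
    """Hypothetical helper: per-class thresholding of tag probabilities.

    probs: (batch, num_class) sigmoid probabilities, assumed already on the CPU.
    class_threshold: (num_class,) threshold per tag, like self.class_threshold.
    tag_list: tag names in class-index order, like self.tag_list.
    """
    probs = np.asarray(probs, dtype=float)
    # A tag is predicted when its probability exceeds its own threshold;
    # the (num_class,) thresholds broadcast across the batch dimension.
    tag = (probs > np.asarray(class_threshold, dtype=float)).astype(int)
    # Zero out tags that may disturb captioning (delete_tag_index).
    tag[:, list(delete_tag_index)] = 0
    names = np.asarray(tag_list)
    tag_output = []
    for row in tag:
        index = np.flatnonzero(row == 1)
        tag_output.append(' | '.join(names[index]))
    return tag_output


def repeat_for_beams(tag_strings, num_beams):
    """Hypothetical helper: repeat each tag string num_beams times, in the same
    order as image_embeds.repeat_interleave(num_beams, dim=0) for beam search."""
    return [t for t in tag_strings for _ in range(num_beams)]


if __name__ == "__main__":
    tags = probs_to_tag_strings([[0.9, 0.5, 0.8], [0.1, 0.95, 0.7]],
                                [0.68, 0.68, 0.85],
                                ['dog', 'grass', 'ball'])
    print(tags)                        # ['dog', 'grass']
    print(repeat_for_beams(tags, 2))   # ['dog', 'dog', 'grass', 'grass']
```

In the first row, `'ball'` is dropped because 0.8 does not clear its stricter 0.85 threshold, which is how the per-class overrides (`ram_class_threshold`, `tag_thrshold`) change the result compared with a single global threshold.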

models/tag2text.py

'''
 * The Recognize Anything Model (RAM) & Tag2Text Model
 * Written by Xinyu Huang
'''
import numpy as np
import json
import torch
import warnings

from torch import nn
from models.bert import BertConfig, BertModel, BertLMHeadModel
from models.vit import VisionTransformer
from models.swin_transformer import SwinTransformer
from data.ram_tag_list_threshold import ram_class_threshold

from models.utils import *

warnings.filterwarnings("ignore")

#####################################

class MyRAM(nn.Module):
    def __init__(self,
                 med_config=f'{CONFIG_PATH}/configs/med_config.json',
                 image_size=384,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 prompt='a picture of ',
                 threshold=0.68,
                 delete_tag_index=[],
                 tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt',
                 tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'):
        r""" The Recognize Anything Model (RAM) inference module.
        RAM is a strong image tagging model, which can recognize any common category with high accuracy.
        Described in the paper "Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/
        
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
            threshold (float): tagging threshold
            delete_tag_index (list): delete some tags that may disturb captioning
        """
        super().__init__()

        # create image encoder
        if vit == 'swin_b':
            if image_size == 224:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
            elif image_size == 384:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
            vision_config = read_json(vision_config_path)
            assert image_size == vision_config['image_res']
            # assert config['patch_size'] == 32
            vision_width = vision_config['vision_width']

            self.visual_encoder = SwinTransformer(
                img_size=vision_config['image_res'],
                patch_size=4,
                in_chans=3,
                embed_dim=vision_config['embed_dim'],
                depths=vision_config['depths'],
                num_heads=vision_config['num_heads'],
                window_size=vision_config['window_size'],
                mlp_ratio=4.,
                qkv_bias=True,
                drop_rate=0.0,
                drop_path_rate=0.1,
                ape=False,
                patch_norm=True,
                use_checkpoint=False)

        elif vit == 'swin_l':
            if image_size == 224:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
            elif image_size == 384:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
            vision_config = read_json(vision_config_path)
            assert image_size == vision_config['image_res']
            # assert config['patch_size'] == 32
            vision_width = vision_config['vision_width']

            self.visual_encoder = SwinTransformer(
                img_size=vision_config['image_res'],
                patch_size=4,
                in_chans=3,
                embed_dim=vision_config['embed_dim'],
                depths=vision_config['depths'],
                num_heads=vision_config['num_heads'],
                window_size=vision_config['window_size'],
                mlp_ratio=4.,
                qkv_bias=True,
                drop_rate=0.0,
                drop_path_rate=0.1,
                ape=False,
                patch_norm=True,
                use_checkpoint=False)

        else:
            self.visual_encoder, vision_width = create_vit(
                vit, image_size, vit_grad_ckpt, vit_ckpt_layer)

        # create tokenizer
        self.tokenizer = init_tokenizer()

        # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
        # create image-tag interaction encoder
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = 512
        self.tag_encoder = BertModel(config=encoder_config,
                                     add_pooling_layer=False)

        # create image-tag-text decoder
        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)

        self.delete_tag_index = delete_tag_index
        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

        # load tag list
        self.tag_list = self.load_tag_list(tag_list)
        self.tag_list_chinese = self.load_tag_list(tag_list_chinese)

        # create image-tag recognition decoder
        self.threshold = threshold
        self.num_class = len(self.tag_list)
        q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
        q2l_config.encoder_width = 512
        self.tagging_head = BertModel(config=q2l_config,
                                      add_pooling_layer=False)
        self.tagging_head.resize_token_embeddings(len(self.tokenizer))
        self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)

        if q2l_config.hidden_size != 512:
            self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
        else:
            self.wordvec_proj = nn.Identity()

        self.fc = nn.Linear(q2l_config.hidden_size, 1)

        self.del_selfattention()

        # share the weights of the lowest 2 layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
        tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
                                    ' ')
        self.image_proj = nn.Linear(vision_width, 512)
        self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth', map_location='cpu').float())

        # adjust thresholds for some tags
        self.class_threshold = torch.ones(self.num_class) * self.threshold
        for key,value in enumerate(ram_class_threshold):
            self.class_threshold[key] = value

    def load_tag_list(self, tag_list_file):
        with open(tag_list_file, 'r') as f:
            tag_list = f.read().splitlines()
        tag_list = np.array(tag_list)
        return tag_list

    # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
    def del_selfattention(self):
        del self.tagging_head.embeddings
        for layer in self.tagging_head.encoder.layer:
            del layer.attention
            
    def generate(self,
                 image,
                 sample=False,
                 num_beams=3,
                 max_length=30,
                 min_length=10,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 threshold=0.68,
                 tag_input=None,
                 return_tag_predict=False):
        label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))

        image_embeds = self.image_proj(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        # recognize image tags using the image-tag recognition decoder
        image_cls_embeds = image_embeds[:, 0, :]
        image_spatial_embeds = image_embeds[:, 1:, :]

        bs = image_spatial_embeds.shape[0]
        label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
        tagging_embed = self.tagging_head(
            encoder_embeds=label_embed,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=False,
            mode='tagging',
        )

        logits = self.fc(tagging_embed[0]).squeeze(-1)

        targets = torch.where(
            torch.sigmoid(logits) > self.class_threshold.to(image.device),
            torch.tensor(1.0).to(image.device),
            torch.zeros(self.num_class).to(image.device))

        tag = targets.cpu().numpy()
        tag[:,self.delete_tag_index] = 0
        tag_output = []
        
        for b in range(bs):
            index = np.argwhere(tag[b] == 1)
            token = self.tag_list[index].squeeze(axis=1)
            tag_output.append(' | '.join(token))

        tag_input = tag_output
            
        # beam search for text generation (default)
        if not sample:
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
            tag_input_temp = []
            for tag in tag_input:
                for i in range(num_beams):
                    tag_input_temp.append(tag)
            tag_input = tag_input_temp

        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        # tokenizer input tags
        tag_input_tokenzier = self.tokenizer(tag_input,
                                             padding='max_length',
                                             truncation=True,
                                             max_length=40,
                                             return_tensors="pt").to(
                                                 image.device)
        encoder_input_ids = tag_input_tokenzier.input_ids
        encoder_input_ids[:, 0] = self.tokenizer.enc_token_id

        # put input tag into image-tag interaction encoder to interact with image embeddings
        output_tagembedding = self.tag_encoder(
            encoder_input_ids,
            attention_mask=tag_input_tokenzier.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        # prompt trick for better captioning, following BLIP
        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
            image.device)
        input_ids[:, 0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]
        
        if sample:
            # nucleus sampling
            model_kwargs = {
                "encoder_hidden_states": output_tagembedding.last_hidden_state,
                "encoder_attention_mask": None
            }
            outputs = self.text_decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.sep_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                repetition_penalty=1.1,
                **model_kwargs)
        else:
            # beam search (default)
            model_kwargs = {
                "encoder_hidden_states": output_tagembedding.last_hidden_state,
                "encoder_attention_mask": None
            }
            outputs = self.text_decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                eos_token_id=self.tokenizer.sep_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                repetition_penalty=repetition_penalty,
                **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            captions.append(caption[len(self.prompt):])
        if return_tag_predict:
            return captions, tag_output
        return captions
        

    def generate_tag(self,
                 image,
                 threshold=0.68,
                 tag_input=None,
                 ):
            
        label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))

        image_embeds = self.image_proj(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        # recognize image tags using the image-tag recognition decoder
        image_cls_embeds = image_embeds[:, 0, :]
        image_spatial_embeds = image_embeds[:, 1:, :]

        bs = image_spatial_embeds.shape[0]
        label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
        tagging_embed = self.tagging_head(
            encoder_embeds=label_embed,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=False,
            mode='tagging',
        )

        logits = self.fc(tagging_embed[0]).squeeze(-1)

        targets = torch.where(
            torch.sigmoid(logits) > self.class_threshold.to(image.device),
            torch.tensor(1.0).to(image.device),
            torch.zeros(self.num_class).to(image.device))

        tag = targets.cpu().numpy()
        tag[:,self.delete_tag_index] = 0
        tag_output = []
        tag_output_chinese = []
        for b in range(bs):
            index = np.argwhere(tag[b] == 1)
            token = self.tag_list[index].squeeze(axis=1)
            tag_output.append(' | '.join(token))
            token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
            tag_output_chinese.append(' | '.join(token_chinese))
            
        


        return tag_output, tag_output_chinese



#####################################

class RAM(nn.Module):
    def __init__(self,
                 med_config=f'{CONFIG_PATH}/configs/med_config.json',
                 image_size=384,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 prompt='a picture of ',
                 threshold=0.68,
                 delete_tag_index=[],
                 tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt',
                 tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'):
        r""" The Recognize Anything Model (RAM) inference module.
        RAM is a strong image tagging model, which can recognize any common category with high accuracy.
        Described in the paper "Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/
        
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
            threshold (float): tagging threshold
            delete_tag_index (list): delete some tags that may disturb captioning
        """
        super().__init__()

        # create image encoder
        if vit == 'swin_b':
            if image_size == 224:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
            elif image_size == 384:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
            vision_config = read_json(vision_config_path)
            assert image_size == vision_config['image_res']
            # assert config['patch_size'] == 32
            vision_width = vision_config['vision_width']

            self.visual_encoder = SwinTransformer(
                img_size=vision_config['image_res'],
                patch_size=4,
                in_chans=3,
                embed_dim=vision_config['embed_dim'],
                depths=vision_config['depths'],
                num_heads=vision_config['num_heads'],
                window_size=vision_config['window_size'],
                mlp_ratio=4.,
                qkv_bias=True,
                drop_rate=0.0,
                drop_path_rate=0.1,
                ape=False,
                patch_norm=True,
                use_checkpoint=False)

        elif vit == 'swin_l':
            if image_size == 224:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
            elif image_size == 384:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
            vision_config = read_json(vision_config_path)
            assert image_size == vision_config['image_res']
            # assert config['patch_size'] == 32
            vision_width = vision_config['vision_width']

            self.visual_encoder = SwinTransformer(
                img_size=vision_config['image_res'],
                patch_size=4,
                in_chans=3,
                embed_dim=vision_config['embed_dim'],
                depths=vision_config['depths'],
                num_heads=vision_config['num_heads'],
                window_size=vision_config['window_size'],
                mlp_ratio=4.,
                qkv_bias=True,
                drop_rate=0.0,
                drop_path_rate=0.1,
                ape=False,
                patch_norm=True,
                use_checkpoint=False)

        else:
            self.visual_encoder, vision_width = create_vit(
                vit, image_size, vit_grad_ckpt, vit_ckpt_layer)

        # create tokenizer
        self.tokenizer = init_tokenizer()

        # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
        # create image-tag interaction encoder
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = 512
        self.tag_encoder = BertModel(config=encoder_config,
                                     add_pooling_layer=False)

        # create image-tag-text decoder
        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)

        self.delete_tag_index = delete_tag_index
        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

        # load tag list
        self.tag_list = self.load_tag_list(tag_list)
        self.tag_list_chinese = self.load_tag_list(tag_list_chinese)

        # create image-tag recognition decoder
        self.threshold = threshold
        self.num_class = len(self.tag_list)
        q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
        q2l_config.encoder_width = 512
        self.tagging_head = BertModel(config=q2l_config,
                                      add_pooling_layer=False)
        self.tagging_head.resize_token_embeddings(len(self.tokenizer))
        self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)

        if q2l_config.hidden_size != 512:
            self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
        else:
            self.wordvec_proj = nn.Identity()

        self.fc = nn.Linear(q2l_config.hidden_size, 1)

        self.del_selfattention()

        # share the weights of the lowest 2 layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
        tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
                                    ' ')
        self.image_proj = nn.Linear(vision_width, 512)
        self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth', map_location='cpu').float())

        # adjust thresholds for some tags
        self.class_threshold = torch.ones(self.num_class) * self.threshold
        for key,value in enumerate(ram_class_threshold):
            self.class_threshold[key] = value

    def load_tag_list(self, tag_list_file):
        with open(tag_list_file, 'r') as f:
            tag_list = f.read().splitlines()
        tag_list = np.array(tag_list)
        return tag_list

    # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
    def del_selfattention(self):
        del self.tagging_head.embeddings
        for layer in self.tagging_head.encoder.layer:
            del layer.attention

    def generate_tag(self,
                 image,
                 threshold=0.68,
                 tag_input=None,
                 ):
            
        label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))

        image_embeds = self.image_proj(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        # recognize image tags using the image-tag recognition decoder
        image_cls_embeds = image_embeds[:, 0, :]
        image_spatial_embeds = image_embeds[:, 1:, :]

        bs = image_spatial_embeds.shape[0]
        label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
        tagging_embed = self.tagging_head(
            encoder_embeds=label_embed,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=False,
            mode='tagging',
        )

        logits = self.fc(tagging_embed[0]).squeeze(-1)

        targets = torch.where(
            torch.sigmoid(logits) > self.class_threshold.to(image.device),
            torch.tensor(1.0).to(image.device),
            torch.zeros(self.num_class).to(image.device))

        tag = targets.cpu().numpy()
        tag[:,self.delete_tag_index] = 0
        tag_output = []
        tag_output_chinese = []
        for b in range(bs):
            index = np.argwhere(tag[b] == 1)
            token = self.tag_list[index].squeeze(axis=1)
            tag_output.append(' | '.join(token))
            token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
            tag_output_chinese.append(' | '.join(token_chinese))


        return tag_output, tag_output_chinese


class Tag2Text_Caption(nn.Module):

    def __init__(self,
                 med_config=f'{CONFIG_PATH}/configs/med_config.json',
                 image_size=384,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 prompt='a picture of ',
                 threshold=0.68,
                 delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359],
                 tag_list=f'{CONFIG_PATH}/data/tag_list.txt'):
        r""" Tag2Text inference module, both captioning and tagging are included.
        Tag2Text is an efficient and controllable vision-language pre-training framework.
        Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657

        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
            threshold (float): tagging threshold
            delete_tag_index (list): delete some tags that may disturb captioning
        """
        super().__init__()

        # create image encoder
        if vit == 'swin_b':
            if image_size == 224:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
            elif image_size == 384:
                vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
            vision_config = read_json(vision_config_path)
            assert image_size == vision_config['image_res']
            # assert config['patch_size'] == 32
            vision_width = vision_config['vision_width']

            self.visual_encoder = SwinTransformer(
                img_size=vision_config['image_res'],
                patch_size=4,
                in_chans=3,
                embed_dim=vision_config['embed_dim'],
                depths=vision_config['depths'],
                num_heads=vision_config['num_heads'],
                window_size=vision_config['window_size'],
                mlp_ratio=4.,
                qkv_bias=True,
                drop_rate=0.0,
                drop_path_rate=0.1,
                ape=False,
                patch_norm=True,
                use_checkpoint=False)

        else:
            self.visual_encoder, vision_width = create_vit(
                vit, image_size, vit_grad_ckpt, vit_ckpt_layer)

        # create tokenizer
        self.tokenizer = init_tokenizer()

        # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
        # create image-tag interaction encoder
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.tag_encoder = BertModel(config=encoder_config,
                                     add_pooling_layer=False)

        # create image-tag-text decoder
        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)

        # delete some tags that may disturb captioning
        # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
        self.delete_tag_index = delete_tag_index
        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

        # load tag list
        self.tag_list = self.load_tag_list(tag_list)

        # create image-tag recognition decoder
        self.threshold = threshold
        self.num_class = len(self.tag_list)
        q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
        q2l_config.encoder_width = vision_width
        self.tagging_head = BertModel(config=q2l_config,
                                      add_pooling_layer=False)
        self.tagging_head.resize_token_embeddings(len(self.tokenizer))
        self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
        self.fc = GroupWiseLinear(self.num_class,
                                  q2l_config.hidden_size,
                                  bias=True)
        self.del_selfattention()

        # share the weights of the lowest 2 layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
        tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
                                    ' ')

        # adjust thresholds for some tags
        # default threshold: 0.68
        # 2701: "person"; 2828: "man"; 1167: "woman"; 
        tag_threshold = {2701: 0.7, 2828: 0.7, 1167: 0.7}
        self.class_threshold = torch.ones(self.num_class) * self.threshold
        for key, value in tag_threshold.items():
            self.class_threshold[key] = value

    def load_tag_list(self, tag_list_file):
        with open(tag_list_file, 'r') as f:
            tag_list = f.read().splitlines()
        tag_list = np.array(tag_list)
        return tag_list

    # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
    def del_selfattention(self):
        del self.tagging_head.embeddings
        for layer in self.tagging_head.encoder.layer:
            del layer.attention

    def generate(self,
                 image,
                 sample=False,
                 num_beams=3,
                 max_length=30,
                 min_length=10,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 tag_input=None,
                 return_tag_predict=False):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        # if the user did not specify tags, recognize image tags using the image-tag recognition decoder
        if tag_input is None:
            image_cls_embeds = image_embeds[:, 0, :]
            image_spatial_embeds = image_embeds[:, 1:, :]

            bs = image_spatial_embeds.shape[0]
            label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
            tagging_embed = self.tagging_head(
                encoder_embeds=label_embed,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=False,
                mode='tagging',
            )

            logits = self.fc(tagging_embed[0])

            targets = torch.where(
                torch.sigmoid(logits) > self.class_threshold.to(image.device),
                torch.tensor(1.0).to(image.device),
                torch.zeros(self.num_class).to(image.device))

            tag = targets.cpu().numpy()

            # delete some tags that may disturb captioning
            tag[:, self.delete_tag_index] = 0

            tag_input = []
            for b in range(bs):
                index = np.argwhere(tag[b] == 1)
                token = self.tag_list[index].squeeze(axis=1)
                tag_input.append(' | '.join(token))
                
        tag_output = tag_input

        # beam search for text generation (default)
        if not sample:
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
            tag_input_temp = []
            for tag in tag_input:
                for i in range(num_beams):
                    tag_input_temp.append(tag)
            tag_input = tag_input_temp

        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        # tokenizer input tags
        tag_input_tokenzier = self.tokenizer(tag_input,
                                             padding='max_length',
                                             truncation=True,
                                             max_length=40,
                                             return_tensors="pt").to(
                                                 image.device)
        encoder_input_ids = tag_input_tokenzier.input_ids
        encoder_input_ids[:, 0] = self.tokenizer.enc_token_id

        # put input tag into image-tag interaction encoder to interact with image embeddings
        output_tagembedding = self.tag_encoder(
            encoder_input_ids,
            attention_mask=tag_input_tokenzier.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        # prompt trick for better captioning, followed BLIP
        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
            image.device)
        input_ids[:, 0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]

        if sample:
            # nucleus sampling
            model_kwargs = {
                "encoder_hidden_states": output_tagembedding.last_hidden_state,
                "encoder_attention_mask": None
            }
            outputs = self.text_decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.sep_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                repetition_penalty=1.1,
                **model_kwargs)
        else:
            # beam search (default)
            model_kwargs = {
                "encoder_hidden_states": output_tagembedding.last_hidden_state,
                "encoder_attention_mask": None
            }
            outputs = self.text_decoder.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                eos_token_id=self.tokenizer.sep_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                repetition_penalty=repetition_penalty,
                **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            captions.append(caption[len(self.prompt):])
        if return_tag_predict:
            return captions, tag_output
        return captions


# load Tag2Text pretrained model parameters
def tag2text_caption(pretrained='', **kwargs):
    model = Tag2Text_Caption(**kwargs)
    if pretrained:
        if kwargs['vit'] == 'swin_b':
            model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
        else:
            model, msg = load_checkpoint(model, pretrained)
        print('vit:', kwargs['vit'])
        print('msg', msg)
    return model

# load RAM pretrained model parameters into the customized MyRAM model
def ram_own(pretrained="", **kwargs):
    model = MyRAM(**kwargs)
    if pretrained:
        if kwargs['vit'] == 'swin_b':
            model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
        elif kwargs['vit'] == 'swin_l':
            model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs)
        else:
            model, msg = load_checkpoint(model, pretrained)
        print('vit:', kwargs['vit'])
        print('msg', msg)
    return model

# load RAM pretrained model parameters
def ram(pretrained='', **kwargs):
    model = RAM(**kwargs)
    if pretrained:
        if kwargs['vit'] == 'swin_b':
            model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
        elif kwargs['vit'] == 'swin_l':
            model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs)
        else:
            model, msg = load_checkpoint(model, pretrained)
        print('vit:', kwargs['vit'])
        print('msg', msg)
    return model

inference_ram.py

'''
 * The Recognize Anything Model (RAM)
 * Written by Xinyu Huang
'''
import argparse
import numpy as np
import random

import torch
import torchvision.transforms as transforms

from PIL import Image
from models.tag2text import ram, ram_own

parser = argparse.ArgumentParser(
    description='Tag2Text inference for tagging and captioning')
parser.add_argument('--image',
                    metavar='DIR',
                    help='path to input image',
                    default='images/1641173_2291260800.jpg')
parser.add_argument('--pretrained',
                    metavar='DIR',
                    help='path to pretrained model',
                    default='pretrained/tag2text_swin_14m.pth')
parser.add_argument('--image-size',
                    default=384,
                    type=int,
                    metavar='N',
                    help='input image size (default: 384)')

def inference(image, model):
    with torch.no_grad():
        tags, tags_chinese = model.generate_tag(image)
    return tags[0], tags_chinese[0]


def inference_own(image, model):
    # generate a caption and, with return_tag_predict=True, the predicted tags as well
    with torch.no_grad():
        caption, tags = model.generate(image,
                                       tag_input=None,
                                       max_length=50,
                                       return_tag_predict=True)
    return caption[0], tags[0]


if __name__ == "__main__":

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(), normalize
    ])

    # load model
    model = ram_own(pretrained=args.pretrained,
                    image_size=args.image_size,
                    vit='swin_l')
    model.eval()

    model = model.to(device)
    raw_image = Image.open(args.image).convert("RGB").resize(
        (args.image_size, args.image_size))
    image = transform(raw_image).unsqueeze(0).to(device)

    # res = inference(image, model)
    # print("Image Tags: ", res[0])
    # print("图像标签: ", res[1])
    res = inference_own(image, model)
    print("Image Tags: ", res[1])
    print("Image Caption:", res[0])

inference command

python inference_ram.py --image images/1641173_2291260800.jpg --pretrained pretrained/ram_swin_large_14m.pth

Please enjoy the Tag2Text Life!
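The tag-selection step in the `generate` code above (sigmoid, `class_threshold`, `delete_tag_index`, then `' | '.join`) can be reproduced on its own with NumPy, which is handy for checking which tags would be passed to the captioner. A minimal sketch, assuming `logits` is a `(batch, num_class)` array and `tag_list` is an array of tag strings; `logits_to_tag_strings` is a hypothetical helper, not part of the repository:

```python
import numpy as np


def logits_to_tag_strings(logits, class_threshold, tag_list, delete_tag_index=()):
    """Turn (batch, num_class) logits into one ' | '-joined tag string per image.

    Follows the same steps as the tagging branch of `generate`: sigmoid,
    compare with a scalar or per-class threshold, zero out tags that should
    not reach the captioner, then join the remaining tag names.
    """
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    targets = (probs > class_threshold).astype(np.int64)

    # delete some tags that may disturb captioning (as in the original code)
    if len(delete_tag_index) > 0:
        targets[:, list(delete_tag_index)] = 0

    tag_list = np.asarray(tag_list)
    return [' | '.join(tag_list[np.flatnonzero(row)]) for row in targets]
```

Because `class_threshold` is broadcast against the probabilities, the same helper works with a single scalar threshold or with a per-class threshold array of length `num_class`.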

about transformers version compatibility

Is it compatible with newer versions of transformers?

If I run it in my environment with transformers 4.28.0, the following error occurs:

./models/bert.py", line 229, in forward
RuntimeError: The size of tensor a (3) must match the size of tensor b (9) at non-singleton dimension 0

requests.exceptions.ConnectionError

Hello, thanks for your great work.
When I run
python batch_inference.py --pretrained ../../recognize-anything/pretrain/ram_swin_large_14m.pth --image-dir image_dir --model-type ram
I get the following error. Could you please give me some hints on how to solve it? Thanks so much!

Traceback (most recent call last):
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 790, in urlopen
    response = self._make_request(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 491, in _make_request
    raise new_e
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 467, in _make_request
    self._validate_conn(conn)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 1092, in _validate_conn
    conn.connect()
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connection.py", line 635, in connect
    sock_and_verified = _ssl_wrap_socket_and_match_hostname(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connection.py", line 774, in _ssl_wrap_socket_and_match_hostname
    ssl_sock = ssl_wrap_socket(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/util/ssl_.py", line 459, in ssl_wrap_socket
    ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/util/ssl_.py", line 503, in _ssl_wrap_socket_impl
    return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
  File "/opt/conda/envs/*/lib/python3.8/ssl.py", line 500, in wrap_socket
    return self.sslsocket_class._create(
  File "/opt/conda/envs/*/lib/python3.8/ssl.py", line 1040, in _create
    self.do_handshake()
  File "/opt/conda/envs/*/lib/python3.8/ssl.py", line 1309, in do_handshake
    self._sslobj.do_handshake()
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/envs/*/lib/python3.8/site-packages/requests/adapters.py", line 486, in send
    resp = conn.urlopen(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 844, in urlopen
    retries = retries.increment(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/util/retry.py", line 470, in increment
    raise reraise(type(error), error, _stacktrace)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/util/util.py", line 38, in reraise
    raise value.with_traceback(tb)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 790, in urlopen
    response = self._make_request(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 491, in _make_request
    raise new_e
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 467, in _make_request
    self._validate_conn(conn)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connectionpool.py", line 1092, in _validate_conn
    conn.connect()
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connection.py", line 635, in connect
    sock_and_verified = _ssl_wrap_socket_and_match_hostname(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/connection.py", line 774, in _ssl_wrap_socket_and_match_hostname
    ssl_sock = ssl_wrap_socket(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/util/ssl_.py", line 459, in ssl_wrap_socket
    ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/urllib3/util/ssl_.py", line 503, in _ssl_wrap_socket_impl
    return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
  File "/opt/conda/envs/*/lib/python3.8/ssl.py", line 500, in wrap_socket
    return self.sslsocket_class._create(
  File "/opt/conda/envs/*/lib/python3.8/ssl.py", line 1040, in _create
    self.do_handshake()
  File "/opt/conda/envs/*/lib/python3.8/ssl.py", line 1309, in do_handshake
    self._sslobj.do_handshake()
urllib3.exceptions.ProtocolError: ('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "batch_inference.py", line 232, in <module>
    main()
  File "batch_inference.py", line 199, in main
    model = initialize_model(
  File "batch_inference.py", line 81, in initialize_model
    model = ram(pretrained=pretrained,
  File "/mnt/csi-data-aly/user/congcongli/recognize-anything/ram/models/ram.py", line 263, in ram
    model = RAM(**kwargs)
  File "/mnt/csi-data-aly/user/congcongli/recognize-anything/ram/models/ram.py", line 103, in __init__
    self.tokenizer = init_tokenizer()
  File "/mnt/csi-data-aly/user/congcongli/recognize-anything/ram/models/utils.py", line 131, in init_tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  File "/opt/conda/envs/*/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 1654, in from_pretrained
    fast_tokenizer_file = get_fast_tokenizer_file(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 3486, in get_fast_tokenizer_file
    all_files = get_list_of_files(
  File "/opt/conda/envs/*/lib/python3.8/site-packages/transformers/file_utils.py", line 2103, in get_list_of_files
    return list_repo_files(path_or_repo, revision=revision, token=token)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/huggingface_hub/utils/_deprecation.py", line 103, in inner_f
    return f(*args, **kwargs)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/huggingface_hub/hf_api.py", line 2087, in list_repo_files
    return [
  File "/opt/conda/envs/*/lib/python3.8/site-packages/huggingface_hub/hf_api.py", line 2087, in <listcomp>
    return [
  File "/opt/conda/envs/*/lib/python3.8/site-packages/huggingface_hub/hf_api.py", line 2053, in list_files_info
    for subpath_info in paginate(path=tree_url, headers=headers, params={"recursive": True, "expand": expand}):
  File "/opt/conda/envs/*/lib/python3.8/site-packages/huggingface_hub/utils/_pagination.py", line 35, in paginate
    r = session.get(path, params=params, headers=headers)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/requests/sessions.py", line 600, in get
    return self.request("GET", url, **kwargs)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/requests/sessions.py", line 587, in request
    resp = self.send(prep, **send_kwargs)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/requests/sessions.py", line 701, in send
    r = adapter.send(request, **kwargs)
  File "/opt/conda/envs/*/lib/python3.8/site-packages/requests/adapters.py", line 501, in send
    raise ConnectionError(err, request=request)
requests.exceptions.ConnectionError: ('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))

`UnpicklingError: invalid load key, 'v'.` when loading the model

I try to run this code:

import torch
import torchvision.transforms as transforms
from models.tag2text import tag2text_caption, ram

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGE_SIZE = 384
CHECKPOINT_RAM = "ram_swin_large_14m.pth"

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    normalize
])

model_ram = ram(pretrained=CHECKPOINT_RAM, image_size=IMAGE_SIZE, vit='swin_l' )

model_ram.eval()
model_ram = model_ram.to(DEVICE)

And get this exception:

/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
--------------
ram_swin_large_14m.pth
--------------
---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
<ipython-input-16-8a868b36a5c6> in <cell line: 8>()
      6 # model_tag2text = model_tag2text.to(DEVICE)
      7 
----> 8 model_ram = ram(pretrained=CHECKPOINT_RAM, image_size=IMAGE_SIZE, vit='swin_l' )
      9 
     10 model_ram.eval()

[... 3 frames omitted ...]
/usr/local/lib/python3.10/dist-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
   1031             "functionality.")
   1032 
-> 1033     magic_number = pickle_module.load(f, **pickle_load_args)
   1034     if magic_number != MAGIC_NUMBER:
   1035         raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: invalid load key, 'v'.

I'm running my code in Google Colab. Here is the link to my code: https://colab.research.google.com/drive/155jbQL31PrKxRrEq0V8TSs8KdreYitZC?usp=sharing

It would be awesome if you could help me set it up in a notebook. I'm thinking of making a tutorial about it.
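An `invalid load key` error usually means the file on disk is not a PyTorch checkpoint at all. Two frequent cases are a Git LFS pointer file, which is plain text beginning with `version https://git-lfs...` (hence key `'v'`), and an HTML page saved in place of the weights, for example from a redirect or login page (hence key `'<'`). A minimal sketch that inspects the first bytes of the file; `sniff_checkpoint` is a hypothetical helper and the byte signatures are heuristics:

```python
def sniff_checkpoint(path):
    """Guess what a downloaded .pth file really is from its first bytes."""
    with open(path, "rb") as f:
        head = f.read(64)
    if head.startswith(b"PK"):
        # torch.save writes a zip archive by default since PyTorch 1.6
        return "zip-based torch checkpoint"
    if head.startswith(b"\x80"):
        # pickle PROTO opcode: legacy (non-zip) torch.save format
        return "legacy pickle checkpoint"
    if head.startswith(b"version https://git-lfs"):
        return "git-lfs pointer file (weights were not downloaded)"
    if head.lstrip().startswith(b"<"):
        return "HTML/XML page (the download returned a web page)"
    return "unknown"
```

For example, call `print(sniff_checkpoint("ram_swin_large_14m.pth"))` before `ram(...)`; if it reports a pointer file or an HTML page, re-download the actual weights before loading.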

specified tag for ram in batch inference

Hey
I hope this message finds you well. Firstly, I would like to express my appreciation for your hard work and your prompt and helpful responses to our queries. I am writing to report an issue I have encountered while using the batch inference feature of your software. Specifically, when I specify the tag argument while using RAM, the function does not work properly and returns None. I have included a sample output below for your reference:

{'filepath': '/media/mmohseni/ubuntu/storage/ImageTaggerDataLake/2049.jpg', 'model_identified_tags': 'wall | ceiling | ceiling fan | pillar | equipment | fill | floor | gym | room', 'user_specified_tags': None, 'image_caption': None}
I would be grateful if you could look into this matter and provide any guidance or support to resolve the issue.

Thank you for your time and assistance.

Best regards,

RAM vs Multi Label classification

Could you clarify whether your model functions as a multi-label classifier with a large number of classes? If not, would you mind elaborating on the advantages of your model compared to a multi-label classifier?

Request for Code: Text Semantic Parser for Tag Extraction

Hi there! First of all, thank you for sharing your code and models. I've been reading your paper and came across Section 3.2, where you mentioned using a text semantic parser to extract tags from raw captions. It would be immensely helpful if you could also share the code for the semantic parser. With access to this code, I'll be able to test your models on other datasets effectively. Thank you once again for your contribution!

I encountered an error

File "/Users/chin/miniconda3/envs/python3.8/lib/python3.8/site-packages/torch/serialization.py", line 815, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/Users/chin/miniconda3/envs/python3.8/lib/python3.8/site-packages/torch/serialization.py", line 1033, in _legacy_load
magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '<'.

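About clustering

After segmenting with Grounding DINO, k-means++ clustering is applied for each tag. What exactly is being clustered: image features (extracted with what)? And how is the number of clusters set?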
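The question does not say which features are clustered, so the following is only a generic sketch of the k-means++ procedure it names: k-means++ seeding followed by Lloyd iterations over a hypothetical `(n_regions, dim)` array of per-region features gathered for one tag. `kmeans_pp` and its arguments are hypothetical, and `k` is a free choice here, not the value the authors used:

```python
import numpy as np


def kmeans_pp(features, k, n_iter=50, seed=0):
    """Cluster the rows of `features` into k groups.

    k-means++ seeding (each new centre is drawn with probability proportional
    to its squared distance from the nearest existing centre), then Lloyd updates.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(features, dtype=np.float64)
    n = x.shape[0]

    def sq_dists(points, centres):
        return ((points[:, None, :] - centres[None, :, :]) ** 2).sum(axis=-1)

    centres = x[[rng.integers(n)]]  # first centre: uniform at random, shape (1, dim)
    for _ in range(1, k):
        d2 = sq_dists(x, centres).min(axis=1)
        p = d2 / d2.sum() if d2.sum() > 0 else None  # all points identical: uniform
        centres = np.vstack([centres, x[rng.choice(n, p=p)]])

    for _ in range(n_iter):
        labels = sq_dists(x, centres).argmin(axis=1)
        new_centres = np.array([
            x[labels == j].mean(axis=0) if np.any(labels == j) else centres[j]
            for j in range(k)
        ])
        if np.allclose(new_centres, centres):
            break
        centres = new_centres

    labels = sq_dists(x, centres).argmin(axis=1)
    return labels, centres
```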

Training Code

Thanks for your fascinating work! Can't wait for the training code.

Example for recognising on an open-set of tags

Hi there,

Firstly, fantastic work and thank you for sharing!

Second, do you mind providing a small code example to recognize an open-set of tags?

Eg.

from ram import get_transform, inference_ram, inference_tag2text
from ram.models import ram, tag2text_caption

ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)

# Setup image
image = ... 

# Custom tags
tags = ["house", "car", "pig"] 

# Perform inference
result = inference_ram(ram_model, image, tags)

how to modify the results to be displayed in Chinese?

Hello, thank you for your work. May I ask how to modify the code so that the results are displayed in Chinese? I replaced bert-base-uncased in init_tokenizer with bert-base-chinese, but the results are wrong.

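Image-text feature similarity comparison

Hello, I want to use this project to compare the similarity between images and text, but I found that the image feature is 1x1024 while the text feature is 1x768, so the two sizes differ. Is my understanding correct? Looking forward to your reply.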
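Cosine similarity is only defined between vectors of the same length, so a 1x1024 image feature and a 1x768 text feature cannot be compared directly; both must first be mapped into a shared embedding space. The sketch below uses a hypothetical projection matrix `W_img` that is random, purely to show the shapes. A meaningful score would need a trained projection aligned with the text encoder, which this sketch does not provide:

```python
import numpy as np


def cosine_similarity(a, b):
    """Row-wise cosine similarity between two (n, d) arrays with equal d."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape[-1] != b.shape[-1]:
        raise ValueError(f"dimension mismatch: {a.shape[-1]} vs {b.shape[-1]}")
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return (a * b).sum(axis=-1)


rng = np.random.default_rng(0)
image_feat = rng.normal(size=(1, 1024))  # stand-in for a 1x1024 image feature
text_feat = rng.normal(size=(1, 768))    # stand-in for a 1x768 text feature

# Hypothetical, untrained projection into the 768-d text space (shapes only)
W_img = rng.normal(size=(1024, 768)) / np.sqrt(1024)
score = cosine_similarity(image_feat @ W_img, text_feat)
```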

Some questions about grad-CAM showing in fig7 in paper Tag2Text.

I used Grad-CAM to visualize the same image as in the paper, but I got a weird result that differs from Fig. 7.
When I use the word "cat" to compute the heatmap, the result looks like this:
[image]
When I change the word to "siamese" (a kind of cat), the result looks OK:
[image]
Sorry, while raising this issue I realized I was using the RAM weights rather than T2T, whereas the Grad-CAM figure is in the T2T paper. But the result still seems weird for RAM; can you tell me what causes this?
I don't remember whether Swin Transformer is also used in T2T. If it is not, I suspect that is the cause; other than that, all I can think of is the CLIP text encoder versus the text encoder you trained yourselves in T2T.
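For reference, the core Grad-CAM computation is: average the gradients of the target score over spatial positions to get one weight per channel, take the weighted sum of the activation maps, and apply ReLU. A NumPy sketch of just that combination step, assuming you already have `activations` and `gradients` of shape `(channels, H, W)` from the hooked layer. For a Swin or ViT backbone the token sequence must first be reshaped into an `H x W` grid, which is an extra step not shown here:

```python
import numpy as np


def grad_cam(activations, gradients, eps=1e-8):
    """Combine (C, H, W) activations and gradients into an (H, W) heatmap."""
    a = np.asarray(activations, dtype=np.float64)
    g = np.asarray(gradients, dtype=np.float64)
    weights = g.mean(axis=(1, 2))                                     # one weight per channel
    cam = np.maximum((weights[:, None, None] * a).sum(axis=0), 0.0)  # weighted sum + ReLU
    return cam / (cam.max() + eps)                                   # scale to [0, 1] for display
```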

Please try again or make sure your Internet connection is on.

When I run the RAM inference demo, something goes wrong. How can I fix this error?

Traceback (most recent call last):
File "inference_ram.py", line 52, in <module>
model = ram(pretrained=args.pretrained,
File "/home/developments/Recognize_Anything-Tag2Text/models/tag2text.py", line 476, in ram
model = RAM(**kwargs)
File "/home/developments/Recognize_Anything-Tag2Text/models/tag2text.py", line 103, in __init__
self.tokenizer = init_tokenizer()
File "/home/developments/Recognize_Anything-Tag2Text/models/utils.py", line 131, in init_tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
File "/root/anaconda3/envs/image2label/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 1695, in from_pretrained
resolved_vocab_files[file_id] = cached_path(
File "/root/anaconda3/envs/image2label/lib/python3.8/site-packages/transformers/file_utils.py", line 1776, in cached_path
output_path = get_from_cache(
File "/root/anaconda3/envs/image2label/lib/python3.8/site-packages/transformers/file_utils.py", line 2000, in get_from_cache
raise ValueError(
ValueError: Connection error, and we cannot find the requested files in the cached path. Please try again or make sure your Internet connection is on.

A question about the loss in t2t

loss = loss_t2t + loss_tag/(loss_tag/loss_t2t).detach()

I think the value of this loss is always equal to 2*loss_t2t, which would make the result independent of loss_tag. What am I getting wrong? I did consider that since "(loss_tag/loss_t2t).detach()" takes no part in the backward pass, it can be regarded as a constant; but this constant is not fixed and can change, which leads to a result that is essentially related only to loss_t2t. Is this what we want?
It seems that you want to dynamically balance the scales of the two losses, but is it really possible to do it this way? I hope to get your reply; this is very confusing to me.
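Treating `.detach()` as a stop-gradient operator sg[·] separates the loss value from its gradient. A short derivation (standard autograd reasoning, not a statement of the authors' intent):

```latex
\begin{aligned}
c &= \operatorname{sg}\!\left[\mathcal{L}_{tag} / \mathcal{L}_{t2t}\right] \\
\mathcal{L} &= \mathcal{L}_{t2t} + \mathcal{L}_{tag} / c = 2\,\mathcal{L}_{t2t} \quad \text{(forward value)} \\
\nabla_\theta \mathcal{L} &= \nabla_\theta \mathcal{L}_{t2t} + \tfrac{1}{c}\,\nabla_\theta \mathcal{L}_{tag}
  = \nabla_\theta \mathcal{L}_{t2t} + \tfrac{\mathcal{L}_{t2t}}{\mathcal{L}_{tag}}\,\nabla_\theta \mathcal{L}_{tag}
\end{aligned}
```

So the logged value is indeed 2 * loss_t2t, but the gradient still contains the gradient of loss_tag, rescaled by loss_t2t / loss_tag, so that the tagging term contributes at the same scale as the captioning loss at every step.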

About training code

Hello everyone,
Thanks for your great work.
Will your training code be released on schedule (7.8), and will it include the training code for RAM?

Training code and datasets

Thanks for your wonderful work. When can you release the training code and datasets? Can't wait to use it😊

Image sizes

Hello,

What image sizes are supported by the RAM and Tag2Text models?

missing dependency for clip

(.env) ➜  recognize-anything git:(main) ✗ python inference_ram_openset.py --image test_photos/dog.jpg --pretrained pretrained/ram_swin_large_14m.pth
Traceback (most recent call last):
  File "/Users/Arseny/dev/recognize-anything/inference_ram_openset.py", line 15, in <module>
    from models.openset_utils import build_openset_label_embedding
  File "/Users/Arseny/dev/recognize-anything/models/openset_utils.py", line 6, in <module>
    from clip import clip
ModuleNotFoundError: No module named 'clip'

Installing CLIP from git+https://github.com/openai/CLIP.git helps though.

pre-trained checkpoint with 4M data

First of all, thanks for sharing this great work!

I'm trying to use your pre-trained network for my research.
Could you also share networks pre-trained on the 4M dataset (without CC12M)?
(If possible, please share both the ViT and Swin Transformer versions.)

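Error when using a GIF image

Testing a GIF image on Hugging Face works fine. However, when I run it with the inference.py script, an error occurs. The error message is as follows: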

D:\work\Tag2Text>python D:/work/Tag2Text/inference.py --image C:/Users/gaoyo/Desktop/9346691egw1fb9c0fxebug20zk0jg4bl.gif
Traceback (most recent call last):
  File "D:/work/Tag2Text/inference.py", line 92, in <module>
    image = transform(raw_image).unsqueeze(0).to(device)
  File "C:\Users\gaoyo\.conda\envs\base\lib\site-packages\torchvision\transforms\transforms.py", line 95, in __call__
    img = t(img)
  File "C:\Users\gaoyo\.conda\envs\base\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\gaoyo\.conda\envs\base\lib\site-packages\torchvision\transforms\transforms.py", line 277, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "C:\Users\gaoyo\.conda\envs\base\lib\site-packages\torchvision\transforms\functional.py", line 363, in normalize
    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
  File "C:\Users\gaoyo\.conda\envs\base\lib\site-packages\torchvision\transforms\_functional_tensor.py", line 928, in normalize
    return tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 384, 384] doesn't match the broadcast shape [3, 384, 384]

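GIF image used for testing:
9346691egw1fb9c0fxebug20zk0jg4bl

Screenshot of it running correctly on Hugging Face:
20230606110204

Expected: inference.py should work correctly as well.

@xinyu1205 Could you please take a look at what is going wrong? Thanks a lot.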
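The traceback shows a 1-channel tensor `[1, 384, 384]` reaching `Normalize`, whose mean and std have 3 entries; GIF files are usually palette or grayscale images, so they do not decode to 3 channels by default. A common remedy is to take the first frame and convert it to RGB before applying the transform. A minimal Pillow sketch; the helper name `load_rgb` is hypothetical:

```python
from PIL import Image


def load_rgb(fp, size=None):
    """Open an image (GIF, PNG, JPEG, ...) and return its first frame as RGB."""
    img = Image.open(fp)
    img.seek(0)               # animated GIFs: use the first frame
    img = img.convert("RGB")  # palette/grayscale to 3 channels for Normalize
    if size is not None:
        img = img.resize((size, size))
    return img
```

The result can then be passed to the existing `transform(...)` call in place of the raw image.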

question about Figure3

Hi,

Great work! I just have a quick question: what does "Cross Attention" mean in Figure 3 of your paper?
Does it mean that the cross-attention maps or values are fed into the later processing?

Thanks!

Error in running openset example

Hello,

I am trying to run inference with openset. I get the following error:

checkpoint, the shape in current model is torch.Size([768, 512]).
size mismatch for tagging_head.encoder.layer.1.crossattention.self.value.weight: copying a param with shape torch.Size([768, 1024]) from checkpoint, the shape in current model is torch.Size([768, 512]).

File "/Users/csv610/Projects/SegmentAnything/RAM/recognize-anything/ram/models/ram.py", line 77, in __init__
vision_config = read_json(vision_config_path)
^^^^^^^^^^^^^^^^^^
UnboundLocalError: cannot access local variable 'vision_config_path' where it is not associated with a value
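The `UnboundLocalError` suggests that `vision_config_path` is assigned only on some branches before `read_json` uses it, which typically happens when the `vit` argument matches none of the expected values. A defensive pattern, sketched with hypothetical config file names (the repository's real paths are not shown in this issue), is to look the value up in a table and fail with a clear message:

```python
# Hypothetical mapping from backbone name to config file, for illustration only
VISION_CONFIGS = {
    "swin_b": "configs/hypothetical_swin_b.json",
    "swin_l": "configs/hypothetical_swin_l.json",
}


def resolve_vision_config(vit):
    """Return the config path for `vit`, or raise a descriptive ValueError."""
    try:
        return VISION_CONFIGS[vit]
    except KeyError:
        supported = ", ".join(sorted(VISION_CONFIGS))
        raise ValueError(f"unsupported vit={vit!r}; expected one of: {supported}") from None
```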

_IncompatibleKeys

Thanks for developing such an amazing project!

While trying to run the 'RAM Inference' and 'RAM Inference on Unseen Categories (Open-Set)' I met some problems.

When I ran 'RAM Inference', it showed 'msg _IncompatibleKeys ...', but after that it still printed the correct image tags.

(RAM) apple@appledeMacBook-Pro-7 recognize-anything % python inference_ram.py --image images/1641173_2291260800.jpg \
  --pretrained pretrained/ram_swin_large_14m.pth
/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
"--------------"
pretrained/ram_swin_large_14m.pth
"--------------"
load checkpoint from pretrained/ram_swin_large_14m.pth
vit: swin_l
msg _IncompatibleKeys(missing_keys=['visual_encoder.layers.0.blocks.0.attn.relative_position_index', 'visual_encoder.layers.0.blocks.1.attn_mask', 'visual_encoder.layers.0.blocks.1.attn.relative_position_index', 'visual_encoder.layers.1.blocks.0.attn.relative_position_index', 'visual_encoder.layers.1.blocks.1.attn_mask', 'visual_encoder.layers.1.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.0.attn.relative_position_index', 'visual_encoder.layers.2.blocks.1.attn_mask', 'visual_encoder.layers.2.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.2.attn.relative_position_index', 'visual_encoder.layers.2.blocks.3.attn_mask', 'visual_encoder.layers.2.blocks.3.attn.relative_position_index', 'visual_encoder.layers.2.blocks.4.attn.relative_position_index', 'visual_encoder.layers.2.blocks.5.attn_mask', 'visual_encoder.layers.2.blocks.5.attn.relative_position_index', 'visual_encoder.layers.2.blocks.6.attn.relative_position_index', 'visual_encoder.layers.2.blocks.7.attn_mask', 'visual_encoder.layers.2.blocks.7.attn.relative_position_index', 'visual_encoder.layers.2.blocks.8.attn.relative_position_index', 'visual_encoder.layers.2.blocks.9.attn_mask', 'visual_encoder.layers.2.blocks.9.attn.relative_position_index', 'visual_encoder.layers.2.blocks.10.attn.relative_position_index', 'visual_encoder.layers.2.blocks.11.attn_mask', 'visual_encoder.layers.2.blocks.11.attn.relative_position_index', 'visual_encoder.layers.2.blocks.12.attn.relative_position_index', 'visual_encoder.layers.2.blocks.13.attn_mask', 'visual_encoder.layers.2.blocks.13.attn.relative_position_index', 'visual_encoder.layers.2.blocks.14.attn.relative_position_index', 'visual_encoder.layers.2.blocks.15.attn_mask', 'visual_encoder.layers.2.blocks.15.attn.relative_position_index', 'visual_encoder.layers.2.blocks.16.attn.relative_position_index', 'visual_encoder.layers.2.blocks.17.attn_mask', 'visual_encoder.layers.2.blocks.17.attn.relative_position_index', 'visual_encoder.layers.3.blocks.0.attn.relative_position_index', 'visual_encoder.layers.3.blocks.1.attn.relative_position_index'], unexpected_keys=[])
Image Tags: brush | dirt road | flower | path | hillside | lake | lead to | mountain | mountain path | road | trail | tree | water | yellow
图像标签: 刷子 | 泥土路 | 花 | 小路 | 山坡 | 湖泊 | 通向 | 山 | 山路 | 路 | 小道 | 树 | 水 | 黄色

But when I ran 'RAM Inference on Unseen Categories (Open-Set)', the same thing happened, but no tags were printed:

(RAM) apple@appledeMacBook-Pro-7 recognize-anything % python inference_ram_openset.py --image images/openset_example.jpg \
  --pretrained pretrained/ram_swin_large_14m.pth
/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
"--------------"
pretrained/ram_swin_large_14m.pth
"--------------"
load checkpoint from pretrained/ram_swin_large_14m.pth
vit: swin_l
msg _IncompatibleKeys(missing_keys=['visual_encoder.layers.0.blocks.0.attn.relative_position_index', 'visual_encoder.layers.0.blocks.1.attn_mask', 'visual_encoder.layers.0.blocks.1.attn.relative_position_index', 'visual_encoder.layers.1.blocks.0.attn.relative_position_index', 'visual_encoder.layers.1.blocks.1.attn_mask', 'visual_encoder.layers.1.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.0.attn.relative_position_index', 'visual_encoder.layers.2.blocks.1.attn_mask', 'visual_encoder.layers.2.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.2.attn.relative_position_index', 'visual_encoder.layers.2.blocks.3.attn_mask', 'visual_encoder.layers.2.blocks.3.attn.relative_position_index', 'visual_encoder.layers.2.blocks.4.attn.relative_position_index', 'visual_encoder.layers.2.blocks.5.attn_mask', 'visual_encoder.layers.2.blocks.5.attn.relative_position_index', 'visual_encoder.layers.2.blocks.6.attn.relative_position_index', 'visual_encoder.layers.2.blocks.7.attn_mask', 'visual_encoder.layers.2.blocks.7.attn.relative_position_index', 'visual_encoder.layers.2.blocks.8.attn.relative_position_index', 'visual_encoder.layers.2.blocks.9.attn_mask', 'visual_encoder.layers.2.blocks.9.attn.relative_position_index', 'visual_encoder.layers.2.blocks.10.attn.relative_position_index', 'visual_encoder.layers.2.blocks.11.attn_mask', 'visual_encoder.layers.2.blocks.11.attn.relative_position_index', 'visual_encoder.layers.2.blocks.12.attn.relative_position_index', 'visual_encoder.layers.2.blocks.13.attn_mask', 'visual_encoder.layers.2.blocks.13.attn.relative_position_index', 'visual_encoder.layers.2.blocks.14.attn.relative_position_index', 'visual_encoder.layers.2.blocks.15.attn_mask', 'visual_encoder.layers.2.blocks.15.attn.relative_position_index', 'visual_encoder.layers.2.blocks.16.attn.relative_position_index', 'visual_encoder.layers.2.blocks.17.attn_mask', 'visual_encoder.layers.2.blocks.17.attn.relative_position_index', 'visual_encoder.layers.3.blocks.0.attn.relative_position_index', 'visual_encoder.layers.3.blocks.1.attn.relative_position_index'], unexpected_keys=[])
Image Tags:

Is this caused by a problem with my environment or by something else, and how can I fix it? Thank you again for this amazing project; I appreciate your help in resolving this issue.

Benchmark on RAM (Evaluation metric)

Hi,
I am currently running batch inference with RAM on the COCO test data and Open Images V6, using your repository as a reference. While the repository provides an excellent implementation of the RAM model, I noticed that there is no mention of an evaluation metric for benchmarking the detection method.

To ensure accurate performance assessment and enable meaningful comparisons, it would be immensely helpful to have an evaluation metric specifically tailored for the RAM model. I understand that evaluation metrics can vary depending on the specific task and dataset, but having a standard metric would greatly benefit users like me who are interested in benchmarking the detection capabilities of the RAM model.

Could you please provide guidance on the recommended evaluation metric for assessing the performance of the RAM model in the context of detection tasks? Any insights you can offer on how to achieve a benchmark using the RAM model would be greatly appreciated.
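For image-level tags (rather than boxes), a common choice in multi-label classification is mean average precision (mAP): for each tag, rank all images by the model's score, compute average precision against the ground truth, then average over tags. Below is a minimal pure-Python sketch, assuming you have a score for every (image, tag) pair (for example the sigmoid outputs before thresholding) and binary ground-truth labels. It is not the repository's evaluation code, and `average_precision` / `mean_average_precision` are hypothetical helper names.

```python
from typing import List, Optional, Sequence


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> Optional[float]:
    """AP for one tag: mean precision at the rank of each positive image."""
    # Rank images from the highest to the lowest score for this tag.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    precisions: List[float] = []
    for rank, idx in enumerate(order, start=1):
        if labels[idx]:
            hits += 1
            precisions.append(hits / rank)
    if not precisions:
        return None  # the tag never occurs in the ground truth, so AP is undefined
    return sum(precisions) / len(precisions)


def mean_average_precision(score_rows: Sequence[Sequence[float]],
                           label_rows: Sequence[Sequence[int]]) -> float:
    """mAP over tags; score_rows[i][t] and label_rows[i][t] refer to image i, tag t."""
    num_tags = len(score_rows[0])
    aps = []
    for t in range(num_tags):
        ap = average_precision([row[t] for row in score_rows],
                               [row[t] for row in label_rows])
        if ap is not None:  # skip tags with no positive images
            aps.append(ap)
    return sum(aps) / len(aps)


scores = [[0.9, 0.1], [0.8, 0.7], [0.2, 0.6]]  # 3 images x 2 tags
labels = [[1, 0], [0, 1], [1, 1]]
print(round(mean_average_precision(scores, labels), 4))  # 0.9167
```

Ties in score are broken by image order here; library implementations such as scikit-learn's `average_precision_score` group tied scores, so small differences from this sketch are expected.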

Thank you

The results I get are not as good as those in the example image

Great work! The tagging quality is a big improvement over BLIP. Nice!

ram_grounded_sam
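For this image on the home page, the RAM results you showed and highlighted include the tags lamp and door, but they are missing from the results I get when I run it.
What could be causing this?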

Tag to class names for downstream applications?

Hi,

Thanks for this amazing project! I'm planning to use it to generate class names for open-vocabulary segmentation. However, I found a small issue. When I pass an image of a room into RAM, it generates tags like "room, building, floor, ceiling, wall", etc. However, I only need "floor, ceiling, wall", since these are the most granular classes, and "room" and "building" are not good prompts for segmentation.

Is there a way to filter the RAM output to this kind of granular class? Any suggestion is appreciated.
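One way to keep only the most specific tags is to post-process RAM's output with a tag hierarchy: drop any predicted tag that is an ancestor of another predicted tag. Nothing in this thread says RAM exposes such a hierarchy, so `PARENT` below is a hypothetical, hand-written mapping you would need to build for your own vocabulary (the tag names come from the example above), and `ancestors` / `most_specific` are hypothetical helpers.

```python
from typing import Dict, List, Set

# Hypothetical child -> parent mapping; extend it for the tags you care about.
PARENT: Dict[str, str] = {
    "floor": "room",
    "ceiling": "room",
    "wall": "room",
    "room": "building",
}


def ancestors(tag: str, parent: Dict[str, str]) -> Set[str]:
    """All ancestors of `tag`, walking up the mapping (stops on cycles)."""
    found: Set[str] = set()
    while tag in parent and parent[tag] not in found:
        tag = parent[tag]
        found.add(tag)
    return found


def most_specific(tags: List[str], parent: Dict[str, str] = PARENT) -> List[str]:
    """Drop tags that are ancestors of another predicted tag; keep input order."""
    covered: Set[str] = set()
    for tag in tags:
        covered |= ancestors(tag, parent)
    return [t for t in tags if t not in covered]


print(most_specific(["room", "building", "floor", "ceiling", "wall"]))  # ['floor', 'ceiling', 'wall']
```

Tags missing from the mapping (for example "lamp") are always kept, so the filter only removes tags you have explicitly declared as coarser.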

Best,

DX

Tag2Text prediction error occurred

The RAM model works fine, but Tag2Text prediction throws an error.

Here is the error output (recognize_anything_demo.ipynb):

You selected Tag2Text
You selected one image
/root/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.25.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
--------------
pretrained/tag2text_swin_14m.pth
--------------
load checkpoint from pretrained/tag2text_swin_14m.pth
vit: swin_b
Traceback (most recent call last):
  File "/ext_hdd_02/yhkim/copyright_project/expirments/recognize-anything/inference_tag2text.py", line 71, in <module>
    res = inference(image, model, args.specified_tags)
  File "/ext_hdd_02/yhkim/copyright_project/expirments/recognize-anything/ram/inference.py", line 11, in inference_tag2text
    caption, tag_predict = model.generate(image,
  File "/ext_hdd_02/yhkim/copyright_project/expirments/recognize-anything/ram/models/tag2text.py", line 259, in generate
    caption = self.tokenizer.decode(output, skip_special_tokens=True)
  File "/root/anaconda3/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 3507, in decode
    token_ids = to_py_obj(token_ids)
  File "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py", line 207, in to_py_obj
    elif is_tf_tensor(obj):
  File "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py", line 166, in is_tf_tensor
    return False if not is_tf_available() else _is_tensorflow(x)
  File "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py", line 157, in _is_tensorflow
    import tensorflow as tf
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/__init__.py", line 41, in <module>
    from tensorflow.python.tools import module_util as _module_util
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/__init__.py", line 46, in <module>
    from tensorflow.python import data
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/__init__.py", line 25, in <module>
    from tensorflow.python.data import experimental
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/experimental/__init__.py", line 97, in <module>
    from tensorflow.python.data.experimental import service
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/experimental/service/__init__.py", line 353, in <module>
    from tensorflow.python.data.experimental.ops.data_service_ops import distribute
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/experimental/ops/data_service_ops.py", line 26, in <module>
    from tensorflow.python.data.experimental.ops import compression_ops
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/experimental/ops/compression_ops.py", line 20, in <module>
    from tensorflow.python.data.util import structure
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/util/structure.py", line 26, in <module>
    from tensorflow.python.data.util import nest
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/util/nest.py", line 40, in <module>
    from tensorflow.python.framework import sparse_tensor as _sparse_tensor
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/sparse_tensor.py", line 28, in <module>
    from tensorflow.python.framework import constant_op
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py", line 29, in <module>
    from tensorflow.python.eager import execute
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 27, in <module>
    from tensorflow.python.framework import dtypes
  File "/root/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/dtypes.py", line 585, in <module>
    np.object,
  File "/root/anaconda3/lib/python3.9/site-packages/numpy/__init__.py", line 313, in __getattr__
    raise AttributeError(__former_attrs__[attr])
AttributeError: module 'numpy' has no attribute 'object'.
`np.object` was a deprecated alias for the builtin `object`. To avoid this error in existing code, use `object` by itself. Doing this will not modify any behavior and is safe. 
The aliases was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
    https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
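The traceback shows the chain: `transformers` checks whether TensorFlow is available, importing TensorFlow then touches `np.object`, and that alias no longer exists in the installed NumPy 1.25.0 (the deprecated aliases were removed in NumPy 1.24). Possible ways out are installing a NumPy older than 1.24 (the SciPy warning at the top of the log already asks for <1.23.0), upgrading TensorFlow to a release that supports the installed NumPy, or uninstalling TensorFlow if it is not needed. `transformers` also reads the `USE_TF` environment variable at import time, so setting `USE_TF=0` before it is imported should keep it from importing TensorFlow at all. The stdlib-only check below is a hypothetical helper for the first condition, not part of this repository.

```python
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, int]:
    """Return (major, minor) from strings such as '1.25.0' or '1.24.0rc1'."""
    parts = version.split(".") + ["0"]
    minor_digits = ""
    for ch in parts[1]:
        if not ch.isdigit():
            break
        minor_digits += ch
    return int(parts[0]), int(minor_digits or "0")


def numpy_object_alias_removed(version: str) -> bool:
    """True if this NumPy version no longer provides `np.object` (removed in 1.24)."""
    return parse_version(version) >= (1, 24)


def installed_numpy_version() -> Optional[str]:
    """Version string of the installed NumPy, or None if it is not installed."""
    try:
        return metadata.version("numpy")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_numpy_version()
    if version is None:
        print("NumPy is not installed")
    elif numpy_object_alias_removed(version):
        print(f"NumPy {version}: np.object is gone; older TensorFlow releases will fail to import")
    else:
        print(f"NumPy {version}: np.object alias still available")
```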

Implementation of negative tags

Hi, is there any way to support negative tags in the project?
These tags would work the opposite way to the usual ones: anything marked as a negative tag would be ignored in the description of the images.

I made a small GUI for myself and would like to add negative tags to it, but I don't know whether that is really possible.
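If the model itself offers no such option, the simplest place to support negative tags is the GUI: post-process the tag string the model returns and drop anything on a user-supplied negative list. This sketch assumes the tags arrive as one string joined by `" | "` (adjust `sep` if your output differs) and that matching should be case-insensitive; `remove_negative_tags` is a hypothetical helper. It only filters the tag list and cannot change a caption that has already been generated.

```python
from typing import Iterable


def remove_negative_tags(tag_string: str, negative: Iterable[str], sep: str = " | ") -> str:
    """Drop every tag listed in `negative` (case-insensitive) from a separator-joined tag string."""
    blocked = {tag.strip().lower() for tag in negative}
    delimiter = sep.strip() or sep  # split on '|' so uneven spacing around it is tolerated
    pieces = [tag.strip() for tag in tag_string.split(delimiter)]
    kept = [tag for tag in pieces if tag and tag.lower() not in blocked]
    return sep.join(kept)


print(remove_negative_tags("cat | suitcase | pillow | bed", ["Bed", "pillow"]))  # cat | suitcase
```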

Can you confirm my understanding of the overall system architecture (Figure 3) in the paper?

Figure 3 of the paper:

My understanding of the architecture is as follows.

  1. Image Encoder

    • Extracts features from the input image.
  2. Image-Tag Recognition Decoder

    • Predicts tags from the image features.
    • The tag labels are obtained by parsing the text of the given image-text pairs.
  3. Image-Tag Interaction Encoder and Image-Tag-Text Generation Decoder

    • Takes the image features and the tags as input and generates a sentence that describes the image.
    • Example:
      • inputs: [cat, lay, suitcase, pillow] + image features
      • output: A cat lying in a suitcase next to the pillow
  4. CLIP Text Encoder

    • Combines the image features with an embedding of the tag list? (not sure)
  5. Textual Label Queries

    • Seem to help train the Image-Tag Recognition Decoder.
    • Used only during training.

I'm reading the paper, but I don't quite understand items 4 and 5.
Could you please explain them in simple terms?
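A rough illustration of how items 4 and 5 could fit together (an interpretation, not the authors' code): each tag name is encoded once by a text encoder such as CLIP's into a vector, and those vectors act as the queries of the recognition decoder. Each query attends over the image features and yields one score per tag, so in principle a new tag only needs its name encoded. The toy below uses tiny hand-made vectors in pure Python instead of real CLIP embeddings and a trained decoder; the single attention step, the dot-product scoring and the 0.5 threshold are all simplifying assumptions, and `tag_scores` is a hypothetical name.

```python
import math
from typing import Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def tag_scores(label_queries: Dict[str, Vector], image_patches: List[Vector]) -> Dict[str, float]:
    """One cross-attention step per tag query over the image patches, then a sigmoid score."""
    scores = {}
    for tag, query in label_queries.items():
        # Attention weights: how strongly this tag's query matches each image patch.
        weights = softmax([dot(query, patch) for patch in image_patches])
        # Weighted sum of patch features: what this tag "sees" in the image.
        attended = [sum(w * patch[d] for w, patch in zip(weights, image_patches))
                    for d in range(len(query))]
        # Agreement between the query and its attended feature, squashed to (0, 1).
        scores[tag] = 1.0 / (1.0 + math.exp(-dot(query, attended)))
    return scores


# Toy "text embeddings" (label queries) and toy image patch features.
queries = {"cat": [1.0, 0.0], "car": [0.0, 1.0]}
patches = [[3.0, 0.0], [2.0, -1.0]]
scores = tag_scores(queries, patches)
print([t for t, s in scores.items() if s > 0.5])  # ['cat']
```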

Questions about Open Images V6

I see that you released two datasets, OpenImages-common and OpenImages-rare. How can I obtain these two subsets from Open Images? Or do I need to download all of the Open Images V6 images locally and then filter them by the file names in ram_annots.txt?
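If local filtering turns out to be the way, the filtering step itself is small. This sketch assumes (unverified) that each line of `ram_annots.txt` starts with an image path or file name, optionally followed by tags separated by commas or whitespace, and that the Open Images V6 images sit in a single local directory. `wanted_filenames` and `copy_subset` are hypothetical helpers.

```python
import shutil
from pathlib import Path
from typing import Iterable, Set


def wanted_filenames(annotation_lines: Iterable[str]) -> Set[str]:
    """Base file names from the first field of each annotation line (format assumed)."""
    names: Set[str] = set()
    for line in annotation_lines:
        field = line.split(",")[0].strip()
        if not field:
            continue  # blank or malformed line
        names.add(Path(field.split()[0]).name)  # 'imgs/abc.jpg tag' -> 'abc.jpg'
    return names


def copy_subset(source_dir: Path, target_dir: Path, names: Set[str]) -> int:
    """Copy the listed images that exist in source_dir; return how many were copied."""
    target_dir.mkdir(parents=True, exist_ok=True)
    copied = 0
    for name in sorted(names):
        src = source_dir / name
        if src.is_file():
            shutil.copy2(src, target_dir / name)
            copied += 1
    return copied
```

Typical use would be `names = wanted_filenames(Path("ram_annots.txt").read_text(encoding="utf-8").splitlines())` followed by `copy_subset(source_dir, target_dir, names)`, where both directories are placeholders for your own paths. Comparing the returned count with `len(names)` shows how many listed images are missing locally.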

What are the prompts during training?

Hi, thank you for your excellent work.
I am curious about the prompt templates (e.g., 'a photo of a {}') used for the tags during training. Are these templates similar to those used in CLIP? The CLIP templates seem better suited to noun tags than to adjective or verb tags (e.g., 'red' or 'play').
Thanks.
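For reference, CLIP-style prompt ensembling works roughly like this: fill the tag into several templates, encode every prompt with the text encoder, L2-normalize each embedding, average them, and normalize again. Whether RAM used exactly this during training is the open question here, so this only illustrates the CLIP technique. The templates are examples, `encode` stands in for a real text encoder, and the toy encoder in the tests is purely hypothetical. A template set that reads naturally for verbs and adjectives could be passed in via `templates`.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]

# Example templates in the style of CLIP's prompt ensembles.
TEMPLATES = ["a photo of a {}.", "a picture of a {}.", "an image showing {}."]


def normalize(v: Sequence[float]) -> Vector:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def ensemble_embedding(tag: str, encode: Callable[[str], Sequence[float]],
                       templates: Sequence[str] = TEMPLATES) -> Vector:
    """Average the normalized embeddings of `tag` filled into each template."""
    embeddings = [normalize(encode(t.format(tag))) for t in templates]
    dim = len(embeddings[0])
    mean = [sum(e[d] for e in embeddings) / len(embeddings) for d in range(dim)]
    return normalize(mean)
```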

How to generate own textual_label_embedding?

Hello, thank you for your great work!

I am trying to implement zero-shot inference myself, but I don't know how the textual label embeddings are generated. I have tested several models, such as "openai/clip-vit-large-patch14", "openai/clip-vit-base-patch32", "laion/CLIP-ViT-B-16-laion2B-s34B-b88K", and "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", but I have not been able to reproduce the embeddings for "ram_tag_list".

P.S. When will zero-shot inference be open-sourced?
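When trying to reproduce a released embedding table, it helps to measure how close each reproduced row is rather than testing for exact equality: with the right text encoder and prompts, per-tag cosine similarity should be close to 1, allowing for small numeric differences. A minimal sketch, assuming both tables are lists of per-tag vectors in the same tag order (for example after converting a loaded tensor to nested lists); `row_cosines` is a hypothetical helper.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def row_cosines(reference: Sequence[Sequence[float]],
                candidate: Sequence[Sequence[float]]) -> List[float]:
    """Cosine similarity between matching rows (same tag order) of two embedding tables."""
    if len(reference) != len(candidate):
        raise ValueError("tables must have the same number of tags")
    return [cosine(r, c) for r, c in zip(reference, candidate)]
```

Looking at the minimum and mean of the result helps separate a systematic mismatch (wrong encoder or prompt set: all rows low) from a problem limited to a few tags.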

Is there a way to improve execution speed?

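I tried it on a dozen or so images and it is rather slow: on a personal computer, one image takes several tens of seconds (about 40 seconds). With a large number of images, processing becomes quite a burden. Is there any way to speed it up? It would be great if it could finish within 1 to 3 seconds.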
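One likely contributor, if each image is handled by a fresh run of the inference script, is that the checkpoint is loaded again for every image. A common pattern is to load the model once and feed the images in batches; running on a GPU, where available, helps further. The sketch only shows that structure: `load_model` and `predict_batch` are hypothetical stand-ins for however you build and run the model, not functions from this repository.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def tag_images(paths: Sequence[str],
               load_model: Callable[[], object],
               predict_batch: Callable[[object, Sequence[str]], List[str]],
               batch_size: int = 8) -> List[str]:
    """Load the model once, then run it over the images batch by batch."""
    model = load_model()  # expensive: do this a single time, not once per image
    results: List[str] = []
    for batch in batched(paths, batch_size):
        results.extend(predict_batch(model, batch))
    return results
```

The batch size is a trade-off between throughput and memory; on a GPU it is usually raised until memory runs out.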

Inference_ram.py example using Tag2Text

Hello,

I am confused. It seems the inference_ram.py example is using

default='pretrained/tag2text_swin_14m.pth'

which is the same checkpoint used in inference_tag2text.py.

Can someone clarify whether this is a typo?
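If it is a typo, the RAM script would presumably default to a RAM checkpoint instead, such as `pretrained/ram_swin_large_14m.pth`, which is used with `inference_ram_openset.py` in the log earlier on this page. Until that is confirmed, passing `--pretrained` explicitly sidesteps the question. The argparse fragment below is a hypothetical reduction that illustrates the pattern; it is not the actual contents of inference_ram.py.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Illustrative subset of command-line options for a RAM tagging script."""
    parser = argparse.ArgumentParser(description="RAM tagging (illustrative options only)")
    parser.add_argument("--image", required=True, help="path to the input image")
    # Assumed default: a RAM checkpoint rather than the Tag2Text one.
    parser.add_argument("--pretrained", default="pretrained/ram_swin_large_14m.pth",
                        help="path to the model checkpoint")
    return parser.parse_args(argv)
```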
