
chendelong1999 / remoteclip

🛰️ Official repository of paper "RemoteCLIP: A Vision Language Foundation Model for Remote Sensing" (IEEE TGRS)

Home Page: https://arxiv.org/abs/2306.11029

License: Apache License 2.0

Jupyter Notebook 97.34% Python 2.66%
remote-sensing vision-language contrastive-language-image-pretraining

remoteclip's Issues

code and dataset release

Hi, thanks for your excellent work. Do you plan to release the code and the dataset you used?

Fine-tuning?

Great paper and work! Thank you for publishing this.

In your GitHub repo I don't see any training or fine-tuning code or details. Are those available?

Thanks!

Fine-tuning training time?

I'd like to know more about how you fine-tuned your model from the base OpenCLIP weights. How long did it take, and which GPUs did you use? We are thinking about fine-tuning RemoteCLIP itself on more domain-specific imagery and want a general sense of the cost and time it took you. Thanks :)

Test set for Zero-shot Benchmark Datasets

Hi,

Some of the zero-shot classification datasets have no predefined test split. So that we can compare against RemoteCLIP and reproduce your results, do you plan to share which test sets you used for each dataset? That would be very helpful. Thanks.
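Until the authors share their lists, one way to keep comparisons reproducible is to fix the split with a seeded, per-class shuffle and publish the resulting filenames. The sketch below is not the paper's protocol: the `(path, label)` input format, the 20% test fraction, and the function name are assumptions for illustration.

```python
import random
from collections import defaultdict


def stratified_split(samples, test_fraction=0.2, seed=0):
    """Split (path, label) pairs into train/test lists, per class, reproducibly.

    The same samples, test_fraction and seed always give the same split,
    regardless of input order, because items are sorted before shuffling.
    """
    by_label = defaultdict(list)
    for path, label in samples:
        by_label[label].append((path, label))

    rng = random.Random(seed)
    train, test = [], []
    for label in sorted(by_label):
        items = sorted(by_label[label])
        rng.shuffle(items)
        n_test = round(len(items) * test_fraction)
        test.extend(items[:n_test])
        train.extend(items[n_test:])
    return train, test
```

Writing the `test` filenames to a text file and sharing it alongside reported numbers is what makes results comparable across groups.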

Expected state_dict to be dict-like, got <class 'torch.jit._script.RecursiveScriptModule'>.

```python
import torch, open_clip
from PIL import Image
from IPython.display import display

model_name = 'RN50' # 'RN50' or 'ViT-B-32' or 'ViT-L-14'
model, _, preprocess = open_clip.create_model_and_transforms(model_name)
tokenizer = open_clip.get_tokenizer(model_name)

path_to_your_checkpoints = 'checkpoints'

ckpt = torch.load(f"{path_to_your_checkpoints}/{model_name}.pt", map_location="cpu")
message = model.load_state_dict(ckpt)
print(message)
model = model.cuda().eval()
```

I ran the code and got this error:

```
Module.load_state_dict(self, state_dict, strict)
1971 r"""Copies parameters and buffers from :attr:state_dict into
1972 this module and its descendants. If :attr:strict is True, then
1973 the keys of :attr:state_dict must exactly match the keys returned
(...)
1991 RuntimeError.
1992 """
1993 if not isinstance(state_dict, Mapping):
-> 1994 raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
1996 missing_keys: List[str] = []
1997 unexpected_keys: List[str] = []

TypeError: Expected state_dict to be dict-like, got <class 'torch.jit._script.RecursiveScriptModule'>.
```
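The traceback shows that `torch.load` returned a TorchScript module (`RecursiveScriptModule`) rather than a dict. `torch.load` falls back to `torch.jit.load` when it detects a TorchScript archive, so one plausible explanation (an assumption) is that `checkpoints/RN50.pt` is a TorchScript file, such as OpenAI's original CLIP release, rather than a RemoteCLIP checkpoint; the retrieval script later on this page loads `RemoteCLIP-{model_name}.pt` instead. A small, hypothetical helper that normalizes whatever `torch.load` returns into a state dict:

```python
from collections.abc import Mapping


def to_state_dict(obj):
    """Return a plain name -> tensor mapping from whatever torch.load returned.

    Handles three cases (assumed, not exhaustive):
      * a plain state dict, returned unchanged;
      * a dict wrapping the weights under a "state_dict" key;
      * a module object (e.g. a TorchScript RecursiveScriptModule),
        whose .state_dict() is called.
    """
    if isinstance(obj, Mapping):
        inner = obj.get("state_dict")
        return inner if isinstance(inner, Mapping) else obj
    if callable(getattr(obj, "state_dict", None)):
        return obj.state_dict()
    raise TypeError(f"Cannot extract a state dict from {type(obj).__name__}")
```

Usage: `model.load_state_dict(to_state_dict(torch.load(path, map_location="cpu")))`. Even then, the keys of an OpenAI TorchScript checkpoint may not match the open_clip model exactly, so pointing at the RemoteCLIP checkpoint file is likely the more reliable fix.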

data bias

Hi! Thank you for your excellent work. How did you deal with data bias during training? Your generated data contains far more samples than RET-3, so in theory the model parameters would tend to favor the generated data. Have you encountered this issue, and if so, how did you resolve it? Thank you!
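The issue does not say how RemoteCLIP handles this. One common mitigation (not confirmed as the authors' approach) is to draw each data source with probability proportional to its size raised to a power alpha, where alpha = 1 keeps the natural mix and alpha = 0 samples every source equally. A sketch; the source names `ret3` and `generated` and the default alpha are hypothetical:

```python
def source_sampling_probs(source_sizes, alpha=0.5):
    """Probability of drawing from each source, proportional to size ** alpha.

    alpha=1 keeps the natural mix (large sources dominate); alpha=0 samples
    every source equally often; values in between soften the imbalance.
    """
    scaled = {name: float(n) ** alpha for name, n in source_sizes.items()}
    total = sum(scaled.values())
    return {name: s / total for name, s in scaled.items()}


def per_sample_weights(sample_sources, source_sizes, alpha=0.5):
    """Weight for each sample so that its source is drawn with the probability above."""
    probs = source_sampling_probs(source_sizes, alpha)
    return [probs[src] / source_sizes[src] for src in sample_sources]
```

The per-sample weights can be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)`. With 100 `ret3` samples and 900 `generated` samples, alpha = 0.5 gives the smaller source a 25% share instead of 10%.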

For Zero-shot inference

In the paper, the zero-shot results on the AID dataset with the ViT-B-32 backbone are 68.62% for raw CLIP and 77.96% for RemoteCLIP. Using the same template-based prompt as you ("a satellite photo of {class name}"), I get only 0.195% with raw CLIP. Since that is far from your result, I would like to ask: is the CLIP in your paper a continually trained CLIP, or the original model released by OpenAI?
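For reference, a minimal sketch of template-based zero-shot classification, assuming image and prompt embeddings have already been computed (e.g. with `model.encode_image` and `model.encode_text`). AID has 30 classes, so random guessing gives about 3.3% accuracy; a result of 0.195% is well below chance, which usually points to a mismatch between label indices and prompt order (or a similar evaluation bug) rather than to the model itself. The helper names are hypothetical.

```python
import numpy as np


def build_prompts(class_names, template="a satellite photo of {}"):
    """One prompt per class, in the same order as class_names."""
    return [template.format(name) for name in class_names]


def zero_shot_predict(image_emb, class_text_emb):
    """Predict class indices by cosine similarity between image and prompt embeddings.

    image_emb: (N, D) image features; class_text_emb: (C, D) text features,
    where row c encodes the prompt for class_names[c].
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = class_text_emb / np.linalg.norm(class_text_emb, axis=1, keepdims=True)
    return (img @ txt.T).argmax(axis=1)


def accuracy(pred, labels):
    """Top-1 accuracy in percent; labels must index into the same class_names order."""
    return 100.0 * float(np.mean(np.asarray(pred) == np.asarray(labels)))
```

The key invariant is that `labels[i]` and row `c` of `class_text_emb` refer to the same `class_names` list; sorting class folders one way and prompts another is enough to push accuracy below chance.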

RET-3 Deduplication

Hi,

Deduplication: remove (near-)identical images that appear both in the test set of RSICD (one of the evaluation sets) and in the training set of RSITMD (part of RET-3), and vice versa.

In the paper you mentioned, "We generate p-Hash values for all images and used these values to detect duplicate images. If the number of different digits between two images is less than threshold 2, they are considered duplicates. Finally, the number of removed duplicated samples ranges from 40 to 3k in different datasets"

Would you mind providing a list of filenames of the deduplicated RET-3 images (i.e., those that do not appear in the test sets of RSICD and RSITMD)?

Thanks
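For anyone reproducing this step, a sketch of the p-Hash comparison described in the quoted passage. It follows the common pHash recipe (32x32 grayscale input, 2-D DCT, 8x8 low-frequency block thresholded at its median) and reads "number of different digits" as the bit-level Hamming distance; both choices are assumptions, as is leaving image loading and downscaling to the caller.

```python
import numpy as np


def dct_matrix(n):
    """Orthonormal DCT-II basis: row k holds the k-th cosine basis vector."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    basis = np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    basis[0] /= np.sqrt(2)
    return basis * np.sqrt(2 / n)


def phash(gray32):
    """64-bit perceptual hash of a 32x32 grayscale array (downscaling is up to the caller)."""
    gray32 = np.asarray(gray32, dtype=np.float64)
    assert gray32.shape == (32, 32), "expects a 32x32 grayscale image"
    d = dct_matrix(32)
    coeffs = d @ gray32 @ d.T          # separable 2-D DCT
    low = coeffs[:8, :8]               # keep the 8x8 lowest frequencies
    bits = (low > np.median(low)).ravel()
    return int("".join("1" if b else "0" for b in bits), 2)


def hamming(a, b):
    """Number of differing bits between two integer hashes."""
    return bin(a ^ b).count("1")


def cross_duplicates(train_hashes, test_hashes, threshold=2):
    """Indices of training images within `threshold` bits (exclusive) of any test image."""
    return sorted(
        i for i, h in enumerate(train_hashes)
        if any(hamming(h, t) < threshold for t in test_hashes)
    )
```

This is a brute-force pairwise comparison, adequate for a few thousand images per side; the `imagehash` package offers a ready-made `phash` if you prefer not to roll your own.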

T-SNE visualization?

Hi, I would like to know which t-SNE tool you used to visualize the image samples in Fig. 6 of your paper. The t-SNE implementation in scikit-learn takes a very long time when I apply it to large-scale data. I would appreciate an answer!
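Whatever tool the authors used, two standard ways to speed up t-SNE on large feature sets are to subsample per class and to reduce the embeddings to about 50 dimensions with PCA first, which scikit-learn's `TSNE` documentation itself recommends for high-dimensional data. A NumPy sketch of both steps; the function names and the 50-component default are illustrative assumptions:

```python
import numpy as np


def stratified_subsample(labels, per_class, seed=0):
    """Sorted indices keeping at most `per_class` random samples of each label."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    keep = []
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        if len(idx) > per_class:
            idx = rng.choice(idx, size=per_class, replace=False)
        keep.append(idx)
    return np.sort(np.concatenate(keep))


def pca_reduce(features, n_components=50):
    """Project (N, D) features onto their top principal components (via SVD)."""
    centered = features - features.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    k = min(n_components, vt.shape[0])
    return centered @ vt[:k].T
```

The reduced array can then be passed to `sklearn.manifold.TSNE(n_components=2, init="pca").fit_transform(...)`.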

Pretrained model loading

When I tried to load the authors' pretrained model with ITRA, I found that it can only be loaded through the '--resume' argument for existing checkpoints. Moreover, the model could not be loaded successfully because some layers are inconsistent. I think code for loading the model should be added to the 'openclip' loading method in model.py.
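To see which layers are inconsistent, it can help to compare the checkpoint's keys and shapes with the model's before loading. A hypothetical helper operating on `{name: shape}` dicts (e.g. `{k: tuple(v.shape) for k, v in sd.items()}`); stripping a `module.` prefix, which appears when a `DistributedDataParallel`-wrapped model's state dict is saved, is an assumption about how the checkpoint may have been written:

```python
def diff_state_dicts(ckpt_shapes, model_shapes, strip_prefix="module."):
    """Compare {name: shape} dicts from a checkpoint and a model.

    Returns (missing, unexpected, mismatched): keys the model needs but the
    checkpoint lacks, keys only the checkpoint has, and shared keys whose
    shapes differ. `strip_prefix` is removed from checkpoint keys first.
    """
    def normalize(key):
        if strip_prefix and key.startswith(strip_prefix):
            return key[len(strip_prefix):]
        return key

    ckpt = {normalize(k): tuple(v) for k, v in ckpt_shapes.items()}
    model = {k: tuple(v) for k, v in model_shapes.items()}
    missing = sorted(set(model) - set(ckpt))
    unexpected = sorted(set(ckpt) - set(model))
    mismatched = sorted(k for k in set(ckpt) & set(model) if ckpt[k] != model[k])
    return missing, unexpected, mismatched
```

Once the differing keys are known, the mismatched ones can be dropped and the rest loaded with `model.load_state_dict(filtered, strict=False)`; whether the model is still usable without those weights needs to be checked separately.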

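Interference in the similarity-matrix computation when one image has multiple captions

Hello, your work is very innovative and offers a simple, practical approach to caption generation. Thank you very much for it.
I have a question: when CLIP computes the similarity matrix and an image has multiple captions, the other four captions are treated as negative samples, which interferes with the results. How did you handle this "mutual interference among multiple captions of the same image" during training and validation? Thanks!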
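One common way to handle this (not necessarily what RemoteCLIP does) is to stop treating other captions of the same image as negatives by masking their logits before the softmax; another is to sample at most one caption per image in each batch. A NumPy sketch of the masked, symmetric CLIP (InfoNCE) loss, where `image_ids` is a hypothetical array giving the source image of each image-caption pair in the batch:

```python
import numpy as np


def masked_clip_loss(image_emb, text_emb, image_ids, temperature=0.07):
    """Symmetric CLIP (InfoNCE) loss that ignores same-image captions as negatives.

    image_emb, text_emb: (B, D) arrays where row i of each forms a positive pair.
    image_ids: length-B ids; pairs i != j sharing an id are masked out of the
    softmax instead of being counted as negatives.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # (B, B): image i vs caption j

    ids = np.asarray(image_ids)
    same_image = ids[:, None] == ids[None, :]
    false_negatives = same_image & ~np.eye(len(ids), dtype=bool)
    logits = np.where(false_negatives, -np.inf, logits)

    def cross_entropy_diag(l):
        # stable log-sum-exp per row; the target for row i is column i
        m = l.max(axis=1, keepdims=True)
        lse = m[:, 0] + np.log(np.exp(l - m).sum(axis=1))
        return float(np.mean(lse - np.diag(l)))

    # image->text uses rows, text->image uses the transpose
    return 0.5 * (cross_entropy_diag(logits) + cross_entropy_diag(logits.T))
```

With all-distinct `image_ids` this reduces to the standard CLIP loss. For validation, retrieval metrics usually already count every caption of an image as correct, as the `positive_pairs` matrix in the RSICD script on this page does.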

The retrieval evaluation code does not produce correct results on the RSICD dataset.

My evaluation code with your model (ViT-L-14) is below.

```python
from huggingface_hub import hf_hub_download
import open_clip
import numpy as np
import torchvision
import os
import json
import torch
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from clip_benchmark.metrics.zeroshot_retrieval import recall_at_k, batchify, dataloader_with_indices

class Dataset(torch.utils.data.Dataset):

    def __init__(self, base_path, transforms, filename):

        self.base_path = base_path
        self.transforms = transforms
        self.filename = filename

        self.data = self.load_annotation()

    def __getitem__(self, idx):
        image = self.data[idx]
        filename = os.path.join(self.base_path, image['filename'])
        raws = [i['raw'].replace(' .', '') for i in image['sentences']]
        return self.transforms(Image.open(filename)), raws

    def __len__(self):
        return len(self.data)

    def load_annotation(self):
        with open(self.filename, 'r') as f:
            data = json.load(f)
        images = data['images']
        test_images = list()
        for image in images:
            if image['split'] == 'test':
                test_images.append(image)
        return test_images

def main():

    device = torch.device('cuda')
    model_name = 'ViT-L-14'
    checkpoint_path = hf_hub_download(
        "chendelong/RemoteCLIP",
        f"RemoteCLIP-{model_name}.pt",
        cache_dir='checkpoints',
    )
    print(f'{model_name} is downloaded to {checkpoint_path}.')
    model, _, preprocess = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    path_to_your_checkpoints = 'checkpoints/models--chendelong--RemoteCLIP/snapshots/bf1d8a3ccf2ddbf7c875705e46373bfe542bce38'
    ckpt = torch.load(f"{path_to_your_checkpoints}/RemoteCLIP-{model_name}.pt", map_location="cpu")
    message = model.load_state_dict(ckpt)
    print(message)
    model = model.cuda().eval()

    dataset = Dataset(
        base_path='rsicd/RSICD_images',
        transforms=preprocess,
        filename='rsicd/RSICD_optimal/dataset_rsicd.json',
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        num_workers=4,
        drop_last=False,
        shuffle=False,
    )
    n_batches = len(dataloader)

    # list of batch of images embedding
    batch_images_emb_list = []
    # list of batch of text embedding
    batch_texts_emb_list = []
    # for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
    texts_image_index = []

    dataloader = dataloader_with_indices(dataloader)
    for batch_images, batch_texts, inds in tqdm(dataloader, total=n_batches):
        batch_images = batch_images.to(device)
        batch_texts_tok = [tokenizer(text) for i, texts in enumerate(batch_texts) for text in texts]
        batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]

        with torch.no_grad():
            batch_image_features = model.encode_image(batch_images)
            batch_text_features = [model.encode_text(t.to(device)) for t in batch_texts_tok]
            batch_text_features = torch.cat(batch_text_features)

        batch_images_emb = F.normalize(batch_image_features, dim=-1)
        batch_texts_emb = F.normalize(batch_text_features, dim=-1)

        batch_images_emb_list.append(batch_images_emb.cpu())
        batch_texts_emb_list.append(batch_texts_emb.cpu())
        texts_image_index.extend(batch_texts_image_index)

    batch_size = len(batch_images_emb_list[0])

    images_emb = torch.cat(batch_images_emb_list)
    texts_emb = torch.cat(batch_texts_emb_list)

    # get the score for each text and image pair
    scores  = texts_emb @ images_emb.t()

    positive_pairs = torch.zeros_like(scores, dtype=bool)
    positive_pairs[torch.arange(len(scores)), texts_image_index] = True
    metrics = {}
    recall_k_list = [1, 5, 10]
    for recall_k in recall_k_list:
        metrics[f"retrieval-image2text-R@{recall_k}"] = (batchify(recall_at_k, scores.T, positive_pairs.T, batch_size, device, k=recall_k)>0).float().mean().item() * 100

    for recall_k in recall_k_list:
        metrics[f"retrieval-text2image-R@{recall_k}"] = (batchify(recall_at_k, scores, positive_pairs, batch_size, device, k=recall_k)>0).float().mean().item() * 100

    metrics[f"retrieval-mean-recall"] = np.mean(list(metrics.values()))

    for key, item in metrics.items():
        metrics[key] = round(float(item), 2)

    for key in metrics.keys():
        print(key, metrics[key])

if __name__ == "__main__":
    main()
```

The printed evaluation results are:

```
retrieval-image2text-R@1 0.37
retrieval-image2text-R@5 1.56
retrieval-image2text-R@10 2.29
retrieval-text2image-R@1 0.6
retrieval-text2image-R@5 2.76
retrieval-text2image-R@10 4.68
retrieval-mean-recall 2.04
```
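A likely cause (an assumption, not confirmed by the maintainers): `Dataset.__getitem__` returns a Python list of captions, and PyTorch's default collate function transposes a batch of equal-length lists. With 5 captions per RSICD image, `batch_texts` therefore holds 5 lists of up to 128 captions each rather than one list of 5 captions per image, and `zip(inds, batch_texts)` attaches image indices to the wrong captions, which would push recall towards chance. A sketch of a collate function that keeps each image's captions together, plus a helper that flattens them with correct indices (both names are hypothetical):

```python
def image_captions_collate_fn(batch):
    """Collate (image_tensor, [captions...]) pairs without transposing the caption lists."""
    import torch  # imported here so flatten_captions below has no torch dependency

    images, captions = zip(*batch)
    return torch.stack(images), [list(c) for c in captions]


def flatten_captions(inds, batch_captions):
    """Flatten per-image caption lists and record which image each caption belongs to."""
    texts, image_index = [], []
    for ind, captions in zip(inds, batch_captions):
        for caption in captions:
            texts.append(caption)
            image_index.append(int(ind))
    return texts, image_index
```

Pass `collate_fn=image_captions_collate_fn` to the `DataLoader`, then build the texts with `texts, batch_texts_image_index = flatten_captions(inds, batch_texts)` and tokenize them in one call with `tokenizer(texts)`.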

Hugging Face cannot be accessed

Hi,
I have tried many times recently, but I cannot access the Hugging Face website. Would it be possible to also upload the pretrained models to another site? That would make it much easier for me to use them.

Mean/Standard deviation over training dataset to remove rogue embedding dimensions outside normal magnitudes?

I've previously been experimenting with the RSICD CLIP model [1], which was trained only on the RSICD dataset. I'm very impressed by how many data sources you've used to train RemoteCLIP, which you document in a table in your paper.

In my experiments with RSICD, I found that normalizing with the mean and standard deviation of the training data was important for getting good results: input imagery then follows the remote sensing distribution of the training data rather than the mean/std of the parent OpenCLIP model, since remote sensing imagery has a very different distribution than standard consumer photography.

Your large number of datasets is a strength, but it unfortunately makes it quite hard for us to compute our own mean/std. Would it be possible for you to compute a per-band mean and std over the training datasets you have locally? A sampling strategy is probably fine, as long as the sample is large enough.

Thanks again for such a great paper and open sourcing your project :)

[1] https://github.com/arampacha/CLIP-rsicd
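While waiting for official statistics, per-band mean and std can be estimated on a random sample of images. A NumPy sketch, assuming RGB images already loaded as HxWxC arrays scaled to [0, 1] (the scale `torchvision.transforms.Normalize` sees after `ToTensor`); `sample_paths` and the sample size are illustrative assumptions. Note that a pretrained model generally expects the normalization it was trained with, so new statistics matter most if you fine-tune.

```python
import random

import numpy as np


def sample_paths(paths, k, seed=0):
    """Reproducible random sample of k paths (all of them if k >= len(paths))."""
    paths = sorted(paths)
    if k >= len(paths):
        return paths
    return random.Random(seed).sample(paths, k)


def per_channel_mean_std(images):
    """Per-channel mean and std over all pixels of an iterable of HxWxC arrays."""
    count = 0
    total = None
    total_sq = None
    for img in images:
        arr = np.asarray(img, dtype=np.float64)
        pixels = arr.reshape(-1, arr.shape[-1])
        if total is None:
            total = np.zeros(pixels.shape[1])
            total_sq = np.zeros(pixels.shape[1])
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        count += pixels.shape[0]
    if count == 0:
        raise ValueError("no images given")
    mean = total / count
    # E[x^2] - E[x]^2 in float64; clip tiny negative values from rounding
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean, std
```

For example, `per_channel_mean_std(np.asarray(Image.open(p).convert("RGB")) / 255.0 for p in sample_paths(all_paths, 5000))`, where `all_paths` is a hypothetical list of your image files.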

Box to caption

Hi, I am very interested in your work. We have some remote sensing data that needs to be captioned. Could you please share your box-to-caption code? I would be grateful for a reply.
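The authors' box-to-caption (B2C) rules are not reproduced here. As a hypothetical starting point, a sketch that turns per-image detection boxes into a counting caption; the input format and caption wording are assumptions, and box geometry (position, size) is ignored in this simple version:

```python
from collections import Counter


def boxes_to_caption(boxes):
    """Turn [(class_name, (x_min, y_min, x_max, y_max)), ...] into a counting caption.

    Pluralization is naive (appends "s"); box coordinates are not used here.
    """
    counts = Counter(name for name, _box in boxes)
    if not counts:
        return "The image contains no annotated objects."
    parts = [
        f"{n} {name}" if n == 1 else f"{n} {name}s"
        for name, n in sorted(counts.items())
    ]
    if len(parts) == 1:
        listed = parts[0]
    else:
        listed = ", ".join(parts[:-1]) + " and " + parts[-1]
    return f"The image contains {listed}."
```

For example, three `airplane` boxes and one `car` box yield "The image contains 3 airplanes and 1 car."; relative positions could be added by mapping each box centre to a region of the image.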

dataset

Could you make the RemoteCLIP training dataset available?
