My evaluation code with your model (ViT-L-14) is below.
from huggingface_hub import hf_hub_download
import open_clip
import numpy as np
import torchvision
import os
import json
import torch
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from clip_benchmark.metrics.zeroshot_retrieval import recall_at_k, batchify, dataloader_with_indices
class Dataset(torch.utils.data.Dataset):
def __init__(self, base_path, transforms, filename):
self.base_path = base_path
self.transforms = transforms
self.filename = filename
self.data = self.load_annotation()
def __getitem__(self, idx):
image = self.data[idx]
filename = os.path.join(self.base_path, image['filename'])
raws = [i['raw'].replace(' .', '') for i in image['sentences']]
return self.transforms(Image.open(filename)), raws
def __len__(self):
return len(self.data)
def load_annotation(self):
with open(self.filename, 'r') as f:
data = json.load(f)
images = data['images']
test_images = list()
for image in images:
if image['split'] == 'test':
test_images.append(image)
return test_images
def main():
device = torch.device('cuda')
model_name = 'ViT-L-14'
checkpoint_path = hf_hub_download(
"chendelong/RemoteCLIP",
f"RemoteCLIP-{model_name}.pt",
cache_dir='checkpoints',
)
print(f'{model_name} is downloaded to {checkpoint_path}.')
model, _, preprocess = open_clip.create_model_and_transforms(model_name)
tokenizer = open_clip.get_tokenizer(model_name)
path_to_your_checkpoints = 'checkpoints/models--chendelong--RemoteCLIP/snapshots/bf1d8a3ccf2ddbf7c875705e46373bfe542bce38'
ckpt = torch.load(f"{path_to_your_checkpoints}/RemoteCLIP-{model_name}.pt", map_location="cpu")
message = model.load_state_dict(ckpt)
print(message)
model = model.cuda().eval()
dataset = Dataset(
base_path='rsicd/RSICD_images',
transforms=preprocess,
filename='rsicd/RSICD_optimal/dataset_rsicd.json',
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=128,
num_workers=4,
drop_last=False,
shuffle=False,
)
n_batches = len(dataloader)
# list of batch of images embedding
batch_images_emb_list = []
# list of batch of text embedding
batch_texts_emb_list = []
# for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
texts_image_index = []
dataloader = dataloader_with_indices(dataloader)
for batch_images, batch_texts, inds in tqdm(dataloader, total=n_batches):
batch_images = batch_images.to(device)
batch_texts_tok = [tokenizer(text) for i, texts in enumerate(batch_texts) for text in texts]
batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]
with torch.no_grad():
batch_image_features = model.encode_image(batch_images)
batch_text_features = [model.encode_text(t.to(device)) for t in batch_texts_tok]
batch_text_features = torch.cat(batch_text_features)
batch_images_emb = F.normalize(batch_image_features, dim=-1)
batch_texts_emb = F.normalize(batch_text_features, dim=-1)
batch_images_emb_list.append(batch_images_emb.cpu())
batch_texts_emb_list.append(batch_texts_emb.cpu())
texts_image_index.extend(batch_texts_image_index)
batch_size = len(batch_images_emb_list[0])
images_emb = torch.cat(batch_images_emb_list)
texts_emb = torch.cat(batch_texts_emb_list)
# get the score for each text and image pair
scores = texts_emb @ images_emb.t()
positive_pairs = torch.zeros_like(scores, dtype=bool)
positive_pairs[torch.arange(len(scores)), texts_image_index] = True
metrics = {}
recall_k_list = [1, 5, 10]
for recall_k in recall_k_list:
metrics[f"retrieval-image2text-R@{recall_k}"] = (batchify(recall_at_k, scores.T, positive_pairs.T, batch_size, device, k=recall_k)>0).float().mean().item() * 100
for recall_k in recall_k_list:
metrics[f"retrieval-text2image-R@{recall_k}"] = (batchify(recall_at_k, scores, positive_pairs, batch_size, device, k=recall_k)>0).float().mean().item() * 100
metrics[f"retrieval-mean-recall"] = np.mean(list(metrics.values()))
for key, item in metrics.items():
metrics[key] = round(float(item), 2)
for key in metrics.keys():
print(key, metrics[key])
The evaluation number is printed as below.
retrieval-image2text-R@1 0.37
retrieval-image2text-R@5 1.56
retrieval-image2text-R@10 2.29
retrieval-text2image-R@1 0.6
retrieval-text2image-R@5 2.76
retrieval-text2image-R@10 4.68
retrieval-mean-recall 2.04