
Comments (2)

VicenteVivan commented on July 17, 2024

Hi Oskar,

Thank you for your interest in our work. Regarding your questions:

i. It seems like you are using the same index $i$ for the $i^{\text{th}}$ sample of a batch, the sum over the batch, and the sum over the dynamic queue.

Yes, you are correct. As you mentioned, we do contrastive learning by comparing each image's embedding to the embeddings of all GPS coordinates (from both the batch and the queue), while keeping the image being compared fixed (i.e., we apply the cross-entropy loss independently to each row of the similarity matrix). More concretely, for a given batch of images and GPS coordinates, the code to compute the loss, after obtaining the features of the images and of the GPS coordinates in the batch and the queue, would be as follows:

import torch
from torch import nn
import torch.nn.functional as F

BATCH_SIZE = 256
QUEUE_SIZE = 4096
FEATURE_DIM = 512
TEMPERATURE = torch.rand([])  # random stand-in for the model's (log) temperature parameter

image_embeddings = torch.rand(BATCH_SIZE, FEATURE_DIM)
gps_embeddings = torch.rand(BATCH_SIZE, FEATURE_DIM)
gps_queue_embeddings = torch.rand(QUEUE_SIZE, FEATURE_DIM)

# Criterion & Targets
criterion = nn.CrossEntropyLoss()
targets_img_gps = torch.arange(BATCH_SIZE)  # the positive for row i is column i

# for (img_batch, gps_batch) in epoch:

#  ... forward pass & queue update ...

# Normalize the Embeddings
image_embeddings = F.normalize(image_embeddings, dim=1)
gps_embeddings = F.normalize(gps_embeddings, dim=1)
gps_queue_embeddings = F.normalize(gps_queue_embeddings, dim=1)

# Append GPS Queue
gps_embeddings_all = torch.cat([gps_embeddings, gps_queue_embeddings], dim=0)

# Exponentiate the (learnable) log-temperature to get the logit scale
temp = TEMPERATURE.exp()

# Compute the logits
logits_img_gps = temp * (image_embeddings @ gps_embeddings_all.T)

# Compute the loss
img_gps_loss = criterion(logits_img_gps, targets_img_gps)

print(logits_img_gps.shape) # (BATCH_SIZE, BATCH_SIZE + QUEUE_SIZE)
print(img_gps_loss)

ii. If it is true that you do contrastive learning of each image over all other coordinates, why did you decide not to do contrastive learning of each GPS coordinate over all other images? In fact, in the original CLIP paper, the cross-entropy loss is applied both horizontally and vertically (over the rows and the columns of the similarity matrix), yet you have chosen to use it only horizontally. Is there a specific reason for this decision?

That's a good observation. In fact, we considered and implemented this idea during the early stages of the project. From what we observed, this modification did not provide any improvement over the horizontal-only counterpart. On top of that, given that adding a queue of GPS coordinates significantly improved GeoCLIP's performance, applying the vertical loss would have complicated the loss function, since the GPS coordinates in the queue have no positive image samples in the batch. Thus, we decided not to include it in the final method.
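
For completeness, here is a rough sketch of what the symmetric (image-to-GPS and GPS-to-image) CLIP-style loss would look like within a single batch, and why the queue gets in the way. The names and random values below are stand-ins for illustration, not our actual implementation:

import torch
from torch import nn
import torch.nn.functional as F

BATCH_SIZE = 256
FEATURE_DIM = 512

image_embeddings = F.normalize(torch.rand(BATCH_SIZE, FEATURE_DIM), dim=1)
gps_embeddings = F.normalize(torch.rand(BATCH_SIZE, FEATURE_DIM), dim=1)

criterion = nn.CrossEntropyLoss()
targets = torch.arange(BATCH_SIZE)
logit_scale = torch.rand([]).exp()

# Symmetric CLIP-style loss, batch only (no queue): the similarity matrix is
# square, so every row *and* every column has exactly one positive pair.
logits = logit_scale * (image_embeddings @ gps_embeddings.T)  # (BATCH_SIZE, BATCH_SIZE)
loss_img_to_gps = criterion(logits, targets)    # horizontal (rows)
loss_gps_to_img = criterion(logits.T, targets)  # vertical (columns)
symmetric_loss = (loss_img_to_gps + loss_gps_to_img) / 2

# With the GPS queue appended, the logits become
# (BATCH_SIZE, BATCH_SIZE + QUEUE_SIZE): each row still has one positive,
# but the queued GPS coordinates have no paired image in the batch, so a
# column-wise cross-entropy over those columns has no valid target.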

iii. Going back to the $P$ augmented views, you mention in your paper that a benefit of using a frozen CLIP backbone is that one can pre-encode all images, making the training process faster. Yet if you perform $P$ augmentations for each image in each batch, wouldn't you have to re-encode the augmented images, and thus not be able to take advantage of this benefit?

We didn't pre-encode just a single embedding for each image in our training set; instead, we pre-encoded $n$ augmentations of each image (with $n = 10$ in our particular case). Then, during training, we sample a subset of these pre-computed augmentations for each corresponding image.
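
As a rough illustration of this pre-encoding scheme (the encoder call, cache shape, and sizes below are hypothetical stand-ins, not our exact pipeline):

import torch

NUM_AUGMENTATIONS = 10  # n = 10 in our case
FEATURE_DIM = 512
BATCH_SIZE = 256
NUM_IMAGES = 1_000      # hypothetical training-set size

# Offline, before training (done once with the frozen CLIP backbone):
# for every image, encode NUM_AUGMENTATIONS augmented views and cache them,
# giving a tensor of shape (NUM_IMAGES, NUM_AUGMENTATIONS, FEATURE_DIM), e.g.
#   cache[i, j] = clip_image_encoder(augment(dataset[i]))
cached_embeddings = torch.rand(NUM_IMAGES, NUM_AUGMENTATIONS, FEATURE_DIM)  # random stand-in

# During training: no image-encoder forward pass is needed; for each image in
# the batch, simply sample one of its pre-computed augmented embeddings.
batch_indices = torch.randint(0, NUM_IMAGES, (BATCH_SIZE,))
aug_indices = torch.randint(0, NUM_AUGMENTATIONS, (BATCH_SIZE,))
image_embeddings = cached_embeddings[batch_indices, aug_indices]  # (BATCH_SIZE, FEATURE_DIM)

print(image_embeddings.shape)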

Please let us know if you have any more questions.


Oshkr commented on July 17, 2024

Hi Vicente,
Thank you for your thorough response!

