relbers / info-nce-pytorch

PyTorch implementation of the InfoNCE loss for self-supervised learning.

License: MIT License

Language: Python 100.0%
Topics: contrastive-learning, contrastive-loss, self-supervised-learning

info-nce-pytorch's Introduction

InfoNCE

PyTorch implementation of the InfoNCE loss from "Representation Learning with Contrastive Predictive Coding". In contrastive learning, we want to learn how to map high dimensional data to a lower dimensional embedding space. This mapping should place semantically similar samples close together in the embedding space, whilst placing semantically distinct samples further apart. The InfoNCE loss function can be used for the purpose of contrastive learning.

This package is available on PyPI and can be installed via:

pip install info-nce-pytorch

Example usage

Import PyTorch and this package.

import torch
from info_nce import InfoNCE, info_nce

Can be used without explicit negative keys, whereby each sample is compared with the other samples in the batch.

loss = InfoNCE()
batch_size, embedding_size = 32, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
output = loss(query, positive_key)

Can be used with negative keys, whereby every combination between query and negative key is compared.

loss = InfoNCE(negative_mode='unpaired') # negative_mode='unpaired' is the default value
batch_size, num_negative, embedding_size = 32, 48, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
negative_keys = torch.randn(num_negative, embedding_size)
output = loss(query, positive_key, negative_keys)

Can be used with negative keys, whereby each query sample is compared with only the negative keys it is paired with.

loss = InfoNCE(negative_mode='paired')
batch_size, num_negative, embedding_size = 32, 6, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
negative_keys = torch.randn(batch_size, num_negative, embedding_size)
output = loss(query, positive_key, negative_keys)

Loss graph

Suppose we have some initial mean vectors µ_q, µ_p, µ_n and a covariance matrix Σ = I/10, then we can plot the value of the InfoNCE loss by sampling from distributions with interpolated mean vectors. Given interpolation weights α and β, we define the distribution Q ~ N(µ_q, Σ) for the query samples, the distribution P_α ~ N(αµ_q + (1-α)µ_p, Σ) for the positive samples and the distribution N_β ~ N(βµ_q + (1-β)µ_n, Σ) for the negative samples. Shown below is the value of the loss with inputs sampled from the distributions defined above for different values of α and β.

[Image: InfoNCE loss as a function of the interpolation weights α and β]
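For reference, a minimal sketch of how such a plot could be produced, assuming the functional form info_nce accepts the same (query, positive_key, negative_keys) arguments as the InfoNCE module; the specific mean vectors and grid are illustrative assumptions, not values from the repo.

import torch
from info_nce import info_nce

embedding_size, n_samples = 2, 1024
mu_q = torch.tensor([1.0, 0.0])      # assumed example means
mu_p = torch.tensor([0.0, 1.0])
mu_n = torch.tensor([-1.0, 0.0])
std = (1 / 10) ** 0.5                # covariance Σ = I/10

def loss_at(alpha, beta):
    # Sample from Q, P_α and N_β as defined above and evaluate the loss.
    query = mu_q + std * torch.randn(n_samples, embedding_size)
    positive = (alpha * mu_q + (1 - alpha) * mu_p) + std * torch.randn(n_samples, embedding_size)
    negative = (beta * mu_q + (1 - beta) * mu_n) + std * torch.randn(n_samples, embedding_size)
    return info_nce(query, positive, negative).item()

for alpha in (0.0, 0.5, 1.0):
    for beta in (0.0, 0.5, 1.0):
        print(f"alpha={alpha}, beta={beta}, loss={loss_at(alpha, beta):.3f}")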

info-nce-pytorch's People

Contributors

lee-plus-plus, relbers


info-nce-pytorch's Issues

should I add the parameters of InfoNCE to the optimizer?

I installed the package with pip, and I know that mutual information can be estimated by a network (as in infoGAN-pytorch). When I use this code, should I add the parameters of the InfoNCE class to the optimizer, or can I use the function directly and it will give me the correct loss? Thank you!
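For context, a minimal training-loop sketch, under the assumption that the InfoNCE module here is a parameter-free loss (it only computes cross-entropy over similarity logits), so only the encoder's parameters need to be passed to the optimizer:

import torch
from info_nce import InfoNCE

encoder = torch.nn.Linear(64, 128)   # stand-in for an actual encoder network
loss_fn = InfoNCE()                  # assumed to hold no learnable parameters
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

x_query = torch.randn(32, 64)        # toy batch of raw inputs
x_positive = torch.randn(32, 64)     # paired / augmented views

optimizer.zero_grad()
loss = loss_fn(encoder(x_query), encoder(x_positive))
loss.backward()
optimizer.step()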

multiple positive samples per query

Hi,
Thanks very much for your work.
I wonder how I should modify your loss to adapt it to the case of multiple positive samples per query.
For example, query.shape = (1,128), positive_keys.shape = (5, 128), negative_keys.shape = (5, 128).

Looking forward to your reply.
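One possible workaround, sketched here as an assumption rather than a documented feature of the package, is to call the loss once per positive key and average, re-using the same negatives each time:

import torch
from info_nce import InfoNCE

loss_fn = InfoNCE(negative_mode='unpaired')

query = torch.randn(1, 128)
positive_keys = torch.randn(5, 128)    # 5 positives for the single query
negative_keys = torch.randn(5, 128)

# Average the InfoNCE loss over the positives, one positive key per call.
loss = torch.stack([
    loss_fn(query, positive_keys[i:i + 1], negative_keys)
    for i in range(len(positive_keys))
]).mean()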

How to import info-nce-pytorch

Hi, I am new to contrastive learning. When I installed the package on my server using pip install info-nce-pytorch, I could not import it using from info-nce-pytorch import InfoNCE(). Could you tell me the module name to use for the import?
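For reference, the PyPI package name (info-nce-pytorch) differs from the importable module name; as shown in the README above, the import is:

from info_nce import InfoNCE

loss = InfoNCE()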

Contribute to PyTorch?

Have you ever considered contributing this loss function to PyTorch? TorchGeo would love to use it for SimCLR/MoCo SSL, but we are trying to avoid adding yet another dependency. If it were built into PyTorch, it would be much easier to use.

I have one question about this part of the code

if negative_keys is not None:
    # Explicit negative keys

    # Cosine between positive pairs
    positive_logit = torch.sum(query * positive_key, dim=1, keepdim=True)

    # Cosine between all query-negative combinations
    negative_logits = query @ transpose(negative_keys)

    # First index in last dimension are the positive samples
    logits = torch.cat([positive_logit, negative_logits], dim=1)
    labels = torch.zeros(len(logits), dtype=torch.long, device=query.device)

1) Why are the labels all zero? Shouldn't the positive sample pairs be labeled 1?
2) Is this cosine similarity? Isn't it just an inner product?
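For context, a small illustrative snippet (not from the repo): F.cross_entropy treats each label as the column index of the correct class, and the positive logit was concatenated at column 0, so a target of 0 points at the positive pair; and since query and keys are L2-normalized beforehand, the inner product equals cosine similarity.

import torch
import torch.nn.functional as F

# Toy logits: 2 queries, positive logit in column 0, 3 negative logits each.
logits = torch.tensor([[0.9, 0.1, 0.2, 0.0],
                       [0.8, 0.3, 0.1, 0.2]])
labels = torch.zeros(len(logits), dtype=torch.long)  # target class = column 0 = positive
loss = F.cross_entropy(logits, labels)               # pushes the positive logit up

# With L2-normalized vectors, the inner product equals cosine similarity.
a = F.normalize(torch.randn(4), dim=0)
b = F.normalize(torch.randn(4), dim=0)
assert torch.isclose(a @ b, F.cosine_similarity(a, b, dim=0))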

Questions about the code when negative_keys is None

I have a question about the code for the case negative_keys = None:

# If negative_keys is None, negative keys are implicitly the off-diagonal positive keys.

# Cosine between all combinations
logits = query @ transpose(positive_key)

# Positive keys are the entries on the diagonal
labels = torch.arange(len(query), device=query.device)

Why are the labels torch.arange(len(query)), i.e. 0, 1, 2, 3, 4, 5, ...? I think the labels for query-positive pairs should be torch.ones(len(query)).
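For context, a small illustrative snippet (not from the repo): with in-batch negatives, row i of the similarity matrix has its positive at column i, so the cross-entropy target for row i is the class index i, which is exactly what torch.arange produces.

import torch
import torch.nn.functional as F

batch, dim = 4, 8
query = F.normalize(torch.randn(batch, dim), dim=-1)
positive_key = F.normalize(query + 0.01 * torch.randn(batch, dim), dim=-1)

logits = query @ positive_key.T          # (batch, batch) similarity matrix
labels = torch.arange(batch)             # row i's positive sits at column i (the diagonal)
loss = F.cross_entropy(logits, labels)   # off-diagonal entries act as in-batch negatives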

Question about input data

I have a quick question. My data is (data1, data2, label), where the label indicates whether the pair of data points represents the same thing: 0 means not the same and 1 means the same. My objective is contrastive learning/loss. I take batches of the data, so I may have multiple positive pairs in each batch. What would be the optimal way to use this loss function with this data? Any help would be appreciated. Thanks.
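One possible way to use this loss with pair-labelled data, sketched under the assumption that the positive pairs can be filtered out of the batch and the remaining in-batch samples serve as implicit negatives (this is not an official recipe from the repo; encoder and the batch layout are hypothetical):

import torch
from info_nce import InfoNCE

loss_fn = InfoNCE()

def contrastive_step(encoder, data1, data2, label):
    # data1, data2: (batch, ...) raw inputs; label: (batch,), 1 = same, 0 = different.
    pos = label.bool()
    if pos.sum() < 2:                    # need at least two positive pairs for in-batch negatives
        return None
    query = encoder(data1[pos])          # first element of each positive pair
    positive_key = encoder(data2[pos])   # matching second element
    # Without explicit negatives, every other sample in the filtered batch
    # acts as a negative for each query.
    return loss_fn(query, positive_key)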

calculate similarity for 3-dim input

My inputs are:
anchors: torch.Size([8, 20, 128])
positives: torch.Size([8, 20, 128])
negatives: torch.Size([100, 8, 20, 128])

8 = batch size, 20 = num pairs, 128 = embedding dim, 100 = num negative samples (shuffled from the positive samples)

So, can I calculate similarity for 3D inputs with your code?
Thanks
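One possible way to handle this, sketched as an assumption based on the paired example in the README (not a documented 3D mode of the package), is to flatten the batch and pair dimensions into a single query dimension and move the negatives into the negative_mode='paired' layout:

import torch
from info_nce import InfoNCE

loss_fn = InfoNCE(negative_mode='paired')

anchors   = torch.randn(8, 20, 128)         # (batch, pairs, dim)
positives = torch.randn(8, 20, 128)
negatives = torch.randn(100, 8, 20, 128)    # (num_neg, batch, pairs, dim)

query         = anchors.reshape(-1, 128)                               # (160, 128)
positive_key  = positives.reshape(-1, 128)                             # (160, 128)
negative_keys = negatives.permute(1, 2, 0, 3).reshape(-1, 100, 128)    # (160, 100, 128)

loss = loss_fn(query, positive_key, negative_keys)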

[Issue] InfoCE rather than InfoNCE

According to the implementation, you used torch.nn.functional.cross_entropy for the computation. This ignores off-diagonal components of the similarity matrix. I implemented an NCE version of this and would like to contribute it to this repo.

I wonder if the way I used InfoNCE was wrong ( ´•̥̥̥ω•̥̥̥` )

Hi,

I am trying to optimize my model by using an InfoNCE loss together with an AAM loss. The InfoNCE-related code, which is based on your code, is as follows:

def contrastive(self, embeddings_z: t.Tensor, embeddings: t.Tensor, logits: t.Tensor):
    logits1 = logits
    high = embeddings.shape[0]
    idx = random.randint(0, int(high) - 1)

    query = embeddings[idx:idx + 1]
    positive_key = embeddings_z[idx:idx + 1]
    # negative_keys = embeddings_z
    negative_keys = t.cat((embeddings_z[:idx], embeddings_z[idx + 1:]))
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)
    negative_keys = F.normalize(negative_keys, dim=-1)

    # Cosine between positive pairs
    positive_logit = t.sum(query * positive_key, dim=1, keepdim=True)

    negative_logits = query @ self.transpose(negative_keys)
    logits = t.cat([positive_logit, negative_logits], dim=1)
    labels = t.zeros(len(logits), dtype=t.long, device=query.device)

    loss = F.cross_entropy(logits / self.temperature, labels, reduction=self.reduction)

    with t.no_grad():
        # put predictions into [0, 1] range for later calculation of accuracy
        prediction = F.softmax(logits1, dim=1).detach()

    return loss, prediction

The joint AAM and InfoNCE loss is as follows:

self.c_contrastive = nn.Parameter(torch.rand(1))
loss = self.c_aam * aam_loss + self.c_contrastive * contrastive_loss

A smaller loss should mean better performance, but when I ran the code, c_contrastive always became negative, which would mean that a larger contrastive loss gives better performance. So I wonder if my InfoNCE code is wrong.

I have been stuck on this for a long time. Soooo looking forward to your reply :)
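A common way to keep a learnable loss weight positive (a general sketch, not specific to this repo) is to learn an unconstrained parameter and map it through softplus, so the optimizer cannot flip the sign of the contrastive term:

import torch
import torch.nn as nn
import torch.nn.functional as F

raw_weight = nn.Parameter(torch.zeros(1))   # hypothetical replacement for self.c_contrastive
c_contrastive = F.softplus(raw_weight)      # always > 0
# loss = self.c_aam * aam_loss + c_contrastive * contrastive_loss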

reduction mode

Thanks for your implementation!
Have you ever considered exposing the reduction mode as a setting? It seems it could be easily added.

Correspondence between query and negative_key

Hi, thanks for your implementation. I would like to confirm the following case with you:
Suppose each query has one positive key and two negative keys; should I organise my input as:

query[0] -->positive_key[0] --> negative_key[0] & negative_key[1]
query[1] -->positive_key[1] --> negative_key[2] & negative_key[3]
....

Thanks very much.
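In that case, based on the negative_mode='paired' example in the README, the flat list of negatives above would be stacked per query into shape (batch_size, 2, embedding_size); a sketch (shapes are assumptions for illustration):

import torch
from info_nce import InfoNCE

loss_fn = InfoNCE(negative_mode='paired')

batch_size, embedding_size = 4, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)

# Flat negatives ordered as in the question:
# negative_key[0], negative_key[1] belong to query[0], and so on.
flat_negative_keys = torch.randn(batch_size * 2, embedding_size)
negative_keys = flat_negative_keys.reshape(batch_size, 2, embedding_size)

loss = loss_fn(query, positive_key, negative_keys)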
