lee-plus-plus / info-nce-pytorch

This project forked from relbers/info-nce-pytorch

PyTorch implementation of the InfoNCE loss for self-supervised learning.

License: MIT License

InfoNCE

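PyTorch implementation of the InfoNCE loss from "Representation Learning with Contrastive Predictive Coding" (van den Oord et al., 2018). In contrastive learning, we want to learn a mapping from high-dimensional data to a lower-dimensional embedding space. This mapping should place semantically similar samples close together in the embedding space, whilst placing semantically distinct samples further apart. The InfoNCE loss is one objective for learning such a mapping: for each query, it contrasts a positive key (a similar sample) with a set of negative keys (dissimilar samples), and the loss is low when the query is much more similar to its positive key than to any of the negatives.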
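To make the objective concrete, here is a minimal sketch of the InfoNCE loss for a single query in plain Python. It assumes cosine similarity as the score and a temperature of 0.1; both are illustrative choices, and the helpers `cosine` and `info_nce_single` are hypothetical names, not part of this package. The loss is the negative log-probability that a softmax over all keys assigns to the positive key.

```python
import math
from typing import Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce_single(query, positive, negatives, temperature=0.1):
    """InfoNCE for one query (hypothetical helper; temperature=0.1 is assumed)."""
    # Scaled similarity to the positive key (index 0) followed by each negative key.
    logits = [cosine(query, positive) / temperature]
    logits += [cosine(query, key) / temperature for key in negatives]
    # Numerically stable log-sum-exp over all candidate keys.
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    # -log softmax probability of the positive key.
    return log_denominator - logits[0]
```

When every key is equally similar to the query, the loss equals log(1 + number of negatives), i.e. chance level; it approaches 0 as the positive becomes much more similar than the negatives.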

This package is available on PyPI and can be installed via:

pip install info-nce-pytorch

Example usage

Import this package.

import torch
from info_nce import InfoNCE, info_nce

It can be used without explicit negative keys, in which case each sample is compared with the other samples in the batch: the positive keys of the other samples act as its negatives.

loss = InfoNCE()
batch_size, embedding_size = 32, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
output = loss(query, positive_key)
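Conceptually, the in-batch case is a classification over the batch: row i of a similarity matrix scores query i against every positive key, and the correct class is column i (the diagonal). The following is a minimal sketch, assuming L2-normalized embeddings (so dot products are cosine similarities) and a temperature of 0.1; `in_batch_info_nce` is a hypothetical helper, not necessarily how this package computes the loss.

```python
import torch
import torch.nn.functional as F


def in_batch_info_nce(query: torch.Tensor, positive_key: torch.Tensor,
                      temperature: float = 0.1) -> torch.Tensor:
    """Sketch of InfoNCE with in-batch negatives (hypothetical helper).

    query, positive_key: (batch_size, embedding_size)
    """
    # Normalize so that dot products are cosine similarities (assumed choice).
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)
    # logits[i, j] = similarity of query i to positive key j, scaled by temperature.
    logits = query @ positive_key.T / temperature
    # The matching key for query i sits on the diagonal, at column i.
    labels = torch.arange(query.shape[0], device=query.device)
    return F.cross_entropy(logits, labels)
```

For example, `in_batch_info_nce(query, positive_key)` with the tensors above returns a scalar tensor. Reusing the other batch elements as negatives costs no extra encoder passes, but it ties the number of negatives to the batch size.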

It can be used with explicit negative keys, in which case every query is compared with every negative key, so all queries share the same set of negatives.

loss = InfoNCE(negative_mode='unpaired') # negative_mode='unpaired' is the default value
batch_size, num_negative, embedding_size = 32, 48, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
negative_keys = torch.randn(num_negative, embedding_size)
output = loss(query, positive_key, negative_keys)
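With a shared pool of negatives, a common formulation places each query's positive similarity in column 0 and its similarities to all negatives in the remaining columns, so every target class is 0. The sketch below assumes cosine similarity and a temperature of 0.1; `unpaired_info_nce` is a hypothetical helper illustrating the shapes, not this package's code.

```python
import torch
import torch.nn.functional as F


def unpaired_info_nce(query, positive_key, negative_keys, temperature=0.1):
    """Sketch of InfoNCE with negatives shared by all queries (hypothetical helper).

    query, positive_key: (batch_size, embedding_size)
    negative_keys:       (num_negative, embedding_size)
    """
    # Normalize so that dot products are cosine similarities (assumed choice).
    query, positive_key, negative_keys = (
        F.normalize(x, dim=-1) for x in (query, positive_key, negative_keys))
    # (batch_size, 1): similarity of each query to its own positive key.
    positive_logit = (query * positive_key).sum(dim=-1, keepdim=True)
    # (batch_size, num_negative): every query against every negative key.
    negative_logits = query @ negative_keys.T
    logits = torch.cat([positive_logit, negative_logits], dim=1) / temperature
    # The positive key is always at column 0.
    labels = torch.zeros(query.shape[0], dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, labels)
```

Putting the positive in column 0 keeps the labels trivial (all zeros) regardless of how many negatives are supplied.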

With negative_mode='paired', each query is compared only with the negative keys it is paired with, so negative_keys carries an extra batch dimension.

loss = InfoNCE(negative_mode='paired')
batch_size, num_negative, embedding_size = 32, 6, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
negative_keys = torch.randn(batch_size, num_negative, embedding_size)
output = loss(query, positive_key, negative_keys)
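In the paired case, the query-negative similarities can be computed with a batched matrix product: each query becomes a (1, embedding_size) matrix and is multiplied against its own (embedding_size, num_negative) block of negatives. This is a minimal sketch assuming cosine similarity and a temperature of 0.1; `paired_info_nce` is a hypothetical helper, not this package's code.

```python
import torch
import torch.nn.functional as F


def paired_info_nce(query, positive_key, negative_keys, temperature=0.1):
    """Sketch of InfoNCE where each query has its own negatives (hypothetical helper).

    query, positive_key: (batch_size, embedding_size)
    negative_keys:       (batch_size, num_negative, embedding_size)
    """
    # Normalize so that dot products are cosine similarities (assumed choice).
    query, positive_key, negative_keys = (
        F.normalize(x, dim=-1) for x in (query, positive_key, negative_keys))
    # (batch_size, 1): similarity of each query to its own positive key.
    positive_logit = (query * positive_key).sum(dim=-1, keepdim=True)
    # (B, 1, D) @ (B, D, num_negative) -> (B, 1, num_negative) -> (B, num_negative)
    negative_logits = torch.bmm(query.unsqueeze(1), negative_keys.transpose(1, 2)).squeeze(1)
    logits = torch.cat([positive_logit, negative_logits], dim=1) / temperature
    # The positive key is always at column 0.
    labels = torch.zeros(query.shape[0], dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, labels)
```

If every row of `negative_keys` is a copy of one shared pool, this computes the same value as the unpaired formulation.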

Loss graph

Suppose we have initial mean vectors µ_q, µ_p, µ_n and a covariance matrix Σ = I/10. We can then plot the value of the InfoNCE loss by sampling from distributions with interpolated mean vectors. Given interpolation weights α and β, we draw the query samples from Q = N(µ_q, Σ), the positive samples from P_α = N(αµ_q + (1-α)µ_p, Σ) and the negative samples from N_β = N(βµ_q + (1-β)µ_n, Σ). Increasing α moves the positive distribution toward the query distribution, while increasing β moves the negative distribution toward it. Shown below is the value of the loss with inputs sampled from these distributions for different values of α and β.

[Figure: InfoNCE loss as a function of the interpolation weights α and β]
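The experiment described above can be sketched as follows: draw queries from Q, positives from P_α and negatives from N_β, then evaluate the loss on a grid of α and β. This is a sketch under stated assumptions, not the script behind the figure: the embedding size, sample counts, random initial means, temperature of 0.1, and the compact loss (cosine similarity with a shared pool of negatives, as in the unpaired mode) are illustrative choices, and `sample_inputs`, `unpaired_loss` and `loss_grid` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def sample_inputs(mu_q, mu_p, mu_n, alpha, beta, n, num_negative):
    """Draw queries, positives and negatives with covariance I/10 (hypothetical helper)."""
    std = 0.1 ** 0.5  # Σ = I/10, so each coordinate has standard deviation sqrt(0.1)
    d = mu_q.shape[0]
    query = mu_q + std * torch.randn(n, d)
    positive = alpha * mu_q + (1 - alpha) * mu_p + std * torch.randn(n, d)
    negative = beta * mu_q + (1 - beta) * mu_n + std * torch.randn(num_negative, d)
    return query, positive, negative


def unpaired_loss(query, positive, negative, temperature=0.1):
    """Compact InfoNCE with shared negatives (assumed cosine similarity, temperature 0.1)."""
    query, positive, negative = (F.normalize(x, dim=-1) for x in (query, positive, negative))
    logits = torch.cat([(query * positive).sum(-1, keepdim=True), query @ negative.T], dim=1)
    labels = torch.zeros(query.shape[0], dtype=torch.long)  # positive is column 0
    return F.cross_entropy(logits / temperature, labels)


def loss_grid(steps=5, dim=16, n=256, num_negative=256, seed=0):
    """Loss for each (alpha, beta) pair on an evenly spaced grid over [0, 1]."""
    torch.manual_seed(seed)
    mu_q, mu_p, mu_n = torch.randn(3, dim)  # assumed random initial means
    weights = torch.linspace(0, 1, steps)
    grid = torch.empty(steps, steps)
    for i, alpha in enumerate(weights):
        for j, beta in enumerate(weights):
            q, p, neg = sample_inputs(mu_q, mu_p, mu_n, alpha.item(), beta.item(), n, num_negative)
            grid[i, j] = unpaired_loss(q, p, neg)
    return weights, grid
```

Under these assumptions, `loss_grid()` returns the interpolation weights and a (steps, steps) tensor of losses, which could then be rendered with a plotting library such as matplotlib's `imshow`.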

Contributors

lee-plus-plus, relbers