
Efficient Training of Retrieval Models using Negative Cache


This repository contains a PyTorch implementation of the paper Efficient Training of Retrieval Models using Negative Cache. It is a training approach for dual encoders that uses a memory-efficient streaming negative cache.

The general idea, according to the authors, is to sample negatives from the cache at each iteration and combine them with Gumbel-Max sampling to approximate the cross-entropy loss. By design, the cache can store a large number of negatives in a memory-efficient way.
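
For intuition only (this is not code from this package): the Gumbel-Max trick draws a sample from softmax(scores) by adding independent Gumbel noise to the scores and taking the argmax, which is roughly how negatives can be drawn from cached scores without materializing the full softmax. A minimal PyTorch sketch:

import torch

def gumbel_max_sample(scores: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Gumbel-Max trick: argmax(scores + Gumbel noise) is distributed as softmax(scores).
    gumbel = -torch.log(-torch.log(torch.rand(num_samples, scores.shape[-1])))
    return torch.argmax(scores.unsqueeze(0) + gumbel, dim=-1)

# Example: draw 4 negative indices from a cache of 8 scored items.
sampled_indices = gumbel_max_sample(torch.randn(8), num_samples=4)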

The original implementation can be found here.



Installation

To install, run the following commands:

git clone git@github.com:marceljahnke/negative-cache.git
cd negative-cache
python -m pip install .

You can also install the package in editable mode:

python -m pip install -e .

Usage

The general usage is very similar to the TensorFlow version. This section explains usage on a single GPU; see DistributedDataParallel usage for multi-GPU training.

Set up the specs that describe the document feature dictionary. These define the feature keys and shapes of the items that need to be cached. The document features are the features the document_network uses to compute the embedding.

import torch

from negative_cache.negative_cache import CacheManager, FixedLenFeature
from negative_cache.handlers import CacheLossHandler
from negative_cache.losses import CacheClassificationLoss, DistributedCacheClassificationLoss

data_keys = ('document_feature_1', 'document_feature_2')
embedding_key = 'embedding'
# document_feature_*_size and embedding_size are placeholders for your feature
# and embedding dimensions.
specs = {
    'document_feature_1': FixedLenFeature(shape=[document_feature_1_size], dtype=torch.int32),
    'document_feature_2': FixedLenFeature(shape=[document_feature_2_size], dtype=torch.int32),
    'embedding': FixedLenFeature(shape=[embedding_size], dtype=torch.float32)
}

Set up the cache loss.

cache_manager = CacheManager(specs, cache_size=131072)
cache_loss = CacheClassificationLoss(
    embedding_key=embedding_key,
    data_keys=data_keys,
    score_transform=lambda score: 20.0 * score,  # Optional, applied to scores before loss.
    top_k=64  # Optional, restricts returned elements to the top_k highest scores.
)
handler = CacheLossHandler(
    cache_manager, cache_loss, embedding_key=embedding_key, data_keys=data_keys)

Calculate the cache loss using your query and document networks and data.

query_embeddings = query_network(query_data)
document_embeddings = document_network(document_data)
loss = handler.update_cache_and_compute_loss(
    document_network, 
    query_embeddings,
    document_embeddings, 
    document_data,
    writer # Optional, used to log additional information to tensorboard.
    )

You can call the handler with an optional tuple writer = (torch.utils.tensorboard.SummaryWriter, global_step) to log additional information, e.g. the interpretable loss and the staleness of the cache.
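
A minimal sketch of such a call, assuming the handler and networks from the previous snippets (the log directory and step counter are illustrative):

from torch.utils.tensorboard import SummaryWriter

summary_writer = SummaryWriter('runs/negative_cache')  # illustrative log directory
global_step = 0

query_embeddings = query_network(query_data)
document_embeddings = document_network(document_data)
loss = handler.update_cache_and_compute_loss(
    document_network,
    query_embeddings,
    document_embeddings,
    document_data,
    (summary_writer, global_step),  # optional (SummaryWriter, global_step) tuple
)
global_step += 1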

Special cases

If your document features consist of only one feature, pass it as a tuple containing a single item:

data_keys = ('document_feature_1',)

DistributedDataParallel usage

When using DistributedDataParallel, do not use a lambda function for score_transform; write a regular (module-level) function instead so that it is picklable.

def fn(scores):
    return 20 * scores

...
score_transform=fn

instead of

score_transform=lambda scores: 20.0 * scores,

You should also use DistributedCacheClassificationLoss instead of CacheClassificationLoss:

cache_loss = DistributedCacheClassificationLoss(
    embedding_key=embedding_key,
    data_keys=data_keys,
    score_transform=score_transformation,
    top_k=top_k,
)
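
For context, here is a rough per-rank sketch of how these pieces might fit together under DistributedDataParallel. It assumes the process group is already initialized and that local_rank, dataloader, optimizer, query_network, and document_network exist; it is a sketch, not a verbatim recipe from this repository:

from torch.nn.parallel import DistributedDataParallel as DDP

# Wrap the encoders; assumes the models have already been moved to the local device.
query_network = DDP(query_network, device_ids=[local_rank])
document_network = DDP(document_network, device_ids=[local_rank])

cache_loss = DistributedCacheClassificationLoss(
    embedding_key=embedding_key,
    data_keys=data_keys,
    score_transform=fn,  # module-level function, picklable
    top_k=64,
)
handler = CacheLossHandler(
    cache_manager, cache_loss, embedding_key=embedding_key, data_keys=data_keys)

for query_data, document_data in dataloader:
    query_embeddings = query_network(query_data)
    document_embeddings = document_network(document_data)
    loss = handler.update_cache_and_compute_loss(
        document_network, query_embeddings, document_embeddings, document_data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()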
