fangs22 / onlineminingtripletloss

This project forked from negation/onlineminingtripletloss

PyTorch conversion of https://omoindrot.github.io/triplet-loss

Home Page: https://www.jrishaug.com/OnlineMiningTripletLoss

License: MIT License

Python 27.99% Makefile 1.14% Jupyter Notebook 70.87%

onlineminingtripletloss's Introduction

online_triplet_loss

A PyTorch conversion of the excellent post on the same topic in TensorFlow. It is simply an implementation of triplet loss with online mining of candidate triplets, as used in semi-supervised learning.
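For readers new to the idea, the quantity being minimised for a single (anchor, positive, negative) triplet can be sketched in plain Python. This is only an illustration of the standard triplet loss formula, max(d(a, p) - d(a, n) + margin, 0), and not the package's code; the helper names `euclidean` and `triplet_loss` are made up for this example.

```python
import math

def euclidean(u, v):
    # Plain Euclidean distance between two equal-length vectors.
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    # Hypothetical helper: zero once the negative is at least `margin`
    # farther from the anchor than the positive is.
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)

anchor = [0.0, 0.0]
near = [3.0, 4.0]   # distance 5 from the anchor
far = [6.0, 8.0]    # distance 10 from the anchor

easy = triplet_loss(anchor, near, far, margin=1.0)  # positive is closer: 5 - 10 + 1 < 0
hard = triplet_loss(anchor, far, near, margin=1.0)  # positive is farther: 10 - 5 + 1
print(easy, hard)
```

Online mining means the triplets are picked from each batch during training rather than being prepared in advance.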

Install

pip install online_triplet_loss

Then import with: from online_triplet_loss.losses import *

PS: Requires PyTorch version 1.1.0 or above.

How to use

In these examples I use a really large margin, since the embedding space is so small. More realistic margins seem to be between 0.1 and 2.0.

import torch
from torch import nn
from online_triplet_loss.losses import batch_hard_triplet_loss

# Toy model: an embedding table mapping each of 10 labels to a 10-dim vector
model = nn.Embedding(10, 10)
labels = torch.randint(high=10, size=(5,))  # our five labels

embeddings = model(labels)
print('Labels:', labels)
print('Embeddings:', embeddings)
# Arguments are (labels, embeddings); the margin is deliberately huge here
loss = batch_hard_triplet_loss(labels, embeddings, margin=100)
print('Loss:', loss)
loss.backward()
Labels: tensor([6, 1, 3, 6, 6])
Embeddings: tensor([[-1.1335,  0.3364, -3.0174, -0.8732, -0.9301,  1.3619,  0.3746,  0.0457,
          0.0180, -0.4500],
        [ 1.0757, -0.8420, -0.7630, -0.0746,  1.1545,  0.4017,  0.5587,  1.7947,
          0.1992, -2.2288],
        [ 0.2646,  1.2383,  0.1949,  0.5743, -0.8460, -0.9929, -2.0350,  0.2095,
          0.2129, -0.4855],
        [-1.1335,  0.3364, -3.0174, -0.8732, -0.9301,  1.3619,  0.3746,  0.0457,
          0.0180, -0.4500],
        [-1.1335,  0.3364, -3.0174, -0.8732, -0.9301,  1.3619,  0.3746,  0.0457,
          0.0180, -0.4500]], grad_fn=<EmbeddingBackward>)
Loss: tensor(95.1271, grad_fn=<MeanBackward0>)
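
To make the batch-hard strategy more concrete, here is a rough sketch of what it computes: for each anchor, take the farthest same-label sample and the closest different-label sample, then average the hinge over the batch. This is an illustrative re-implementation under my own assumptions, not the package's actual code; `batch_hard_sketch` is a hypothetical name, and only the (labels, embeddings, margin) argument order is borrowed from the example above.

```python
import torch

def batch_hard_sketch(labels, embeddings, margin):
    # Pairwise Euclidean distances, shape (B, B).
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = diff.pow(2).sum(dim=-1).sqrt()

    same = labels.unsqueeze(0) == labels.unsqueeze(1)   # same[i, j]: labels match
    eye = torch.eye(len(labels), dtype=torch.bool)
    pos_mask = same & ~eye                              # same label, different sample
    neg_mask = ~same                                    # different label

    # Hardest positive: the farthest sample sharing the anchor's label.
    hardest_pos = (dist * pos_mask).max(dim=1).values
    # Hardest negative: the closest sample with another label
    # (non-negatives are pushed to +inf so they never win the min).
    hardest_neg = dist.masked_fill(~neg_mask, float('inf')).min(dim=1).values

    return torch.relu(hardest_pos - hardest_neg + margin).mean()

toy_labels = torch.tensor([0, 0, 1, 1])
toy_embeddings = torch.tensor([[0.0], [1.0], [3.0], [7.0]])
toy_loss = batch_hard_sketch(toy_labels, toy_embeddings, margin=1.0)
print(toy_loss)
```

In this toy batch only the anchor at 3.0 violates the margin (its positive is 4 away, its nearest negative 2 away), so it alone contributes to the mean.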
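from online_triplet_loss.losses import batch_all_triplet_loss

# Reuses `model` and `labels` from the previous example
embeddings = model(labels)
print('Labels:', labels)
print('Embeddings:', embeddings)
# Returns the loss and the fraction of valid triplets with a positive loss
loss, fraction_pos = batch_all_triplet_loss(labels, embeddings, squared=False, margin=100)
print('Loss:', loss)
loss.backward()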
Labels: tensor([6, 1, 3, 6, 6])
Embeddings: tensor([[-1.1335,  0.3364, -3.0174, -0.8732, -0.9301,  1.3619,  0.3746,  0.0457,
          0.0180, -0.4500],
        [ 1.0757, -0.8420, -0.7630, -0.0746,  1.1545,  0.4017,  0.5587,  1.7947,
          0.1992, -2.2288],
        [ 0.2646,  1.2383,  0.1949,  0.5743, -0.8460, -0.9929, -2.0350,  0.2095,
          0.2129, -0.4855],
        [-1.1335,  0.3364, -3.0174, -0.8732, -0.9301,  1.3619,  0.3746,  0.0457,
          0.0180, -0.4500],
        [-1.1335,  0.3364, -3.0174, -0.8732, -0.9301,  1.3619,  0.3746,  0.0457,
          0.0180, -0.4500]], grad_fn=<EmbeddingBackward>)
tensor(94.9947, grad_fn=<DivBackward0>) tensor(1.)
Loss: tensor(94.9947, grad_fn=<DivBackward0>)
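
The batch-all strategy can be sketched in the same spirit: build every valid (anchor, positive, negative) triplet in the batch, average the loss over the triplets that are still violating the margin, and report what fraction of valid triplets those are. Again, this is a hedged illustration and not the package's implementation; `batch_all_sketch` and the small epsilon used to avoid division by zero are my own choices.

```python
import torch

def batch_all_sketch(labels, embeddings, margin):
    # Pairwise Euclidean distances, shape (B, B).
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = diff.pow(2).sum(dim=-1).sqrt()

    # triplet[i, j, k] = d(i, j) - d(i, k) + margin
    triplet = dist.unsqueeze(2) - dist.unsqueeze(1) + margin

    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool)
    # Valid triplet: j is a different sample with i's label, k has another label.
    valid = (same & ~eye).unsqueeze(2) & (~same).unsqueeze(1)

    losses = torch.relu(triplet) * valid
    num_positive = (losses > 1e-16).sum()   # triplets still violating the margin
    num_valid = valid.sum()

    loss = losses.sum() / (num_positive + 1e-16)
    fraction_positive = num_positive / (num_valid + 1e-16)
    return loss, fraction_positive

toy_labels = torch.tensor([0, 0, 1, 1])
toy_embeddings = torch.tensor([[0.0], [1.0], [3.0], [7.0]])
toy_loss, toy_fraction = batch_all_sketch(toy_labels, toy_embeddings, margin=1.0)
print(toy_loss, toy_fraction)
```

With the very large margin used in the example above, every valid triplet violates the margin, which is why the reported fraction there is 1.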

Contributors

negation, dependabot[bot], f-fl0