This project forked from abhuse/polyloss-pytorch

Polyloss Pytorch Implementation

License: MIT License

PolyLoss in PyTorch

PolyLoss implementation in PyTorch, as described in:
[Leng et al. 2022] PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions

Both Poly-Cross-Entropy and Poly-Focal losses are provided.
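
For reference, the Poly-1 forms of the two losses, as given in the paper, add a single correction term to the base loss. Here $P_t$ is the predicted probability of the target class, $\epsilon$ is the PolyLoss epsilon and $\gamma$ is the focal loss gamma; setting $\epsilon = 0$ recovers the base loss:

```latex
\begin{aligned}
L_{\text{Poly-1 CE}} &= L_{\text{CE}} + \epsilon\,(1 - P_t)
                      = -\log P_t + \epsilon\,(1 - P_t) \\
L_{\text{Poly-1 FL}} &= L_{\text{FL}} + \epsilon\,(1 - P_t)^{\gamma + 1}
                      = -(1 - P_t)^{\gamma}\log P_t + \epsilon\,(1 - P_t)^{\gamma + 1}
\end{aligned}
```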

Examples

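import torch
# assumption: polyloss.py from this repository is on the import path
from polyloss import Poly1CrossEntropyLoss, Poly1FocalLoss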

# Poly1 Cross-Entropy Loss
# classification task
batch_size = 10
num_classes = 5
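# requires_grad=True stands in for model outputs, so that out.backward() below works
logits = torch.rand([batch_size, num_classes], requires_grad=True)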
labels = torch.randint(0, num_classes, [batch_size])
loss = Poly1CrossEntropyLoss(num_classes=num_classes, 
                             reduction='mean')
out = loss(logits, labels)
out.backward()
# optimizer.step()


# Poly1 Focal Loss
## Case 1. labels hold class ids
# batch_size, num_classes, height, width
B, num_classes, H, W = 2, 3, 4, 7
logits = torch.rand([B, num_classes, H, W])
labels = torch.randint(0, num_classes, [B, H, W])

# optional class-wise weights; shape must be broadcastable to [B, num_classes, H, W]
# weight positive examples of class id 2 five times more heavily
pos_weight = torch.tensor([1., 1., 5.]).reshape([1, num_classes, 1, 1])

loss = Poly1FocalLoss(num_classes=num_classes,
                      reduction='mean',
                      label_is_onehot=False,
                      pos_weight=pos_weight)

out = loss(logits, labels)
# out.backward()
# optimizer.step()


## Case 2. labels are one-hot encoded (or multi-hot, for multi-label tasks)
# batch_size, num_classes, height, width
B, num_classes, H, W = 2, 3, 4, 7
logits = torch.rand([B, num_classes, H, W])
# random 0/1 (multi-hot) labels, same shape as logits
labels = torch.randint(0, 2, [B, num_classes, H, W]).float()

# optional class-wise weights; shape must be broadcastable to [B, num_classes, H, W]
# weight positive examples of class id 2 five times more heavily
pos_weight = torch.tensor([1., 1., 5.]).reshape([1, num_classes, 1, 1])
# weight tensor shape [1, num_classes, 1, 1] is broadcastable to [B, num_classes, H, W]

loss = Poly1FocalLoss(num_classes=num_classes,
                      reduction='mean',
                      label_is_onehot=True,
                      pos_weight=pos_weight)

out = loss(logits, labels)
# out.backward()
# optimizer.step()
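
The two cases above differ only in label layout. As a minimal sketch in plain PyTorch (not the library's internal code; variable names are illustrative), class-id labels of shape [B, H, W] can be converted to the one-hot layout of Case 2 like this:

```python
import torch
import torch.nn.functional as F

B, num_classes, H, W = 2, 3, 4, 7
class_ids = torch.randint(0, num_classes, [B, H, W])

# F.one_hot appends the class dimension last: [B, H, W, num_classes]
onehot = F.one_hot(class_ids, num_classes=num_classes)
# move the class dimension next to the batch dimension: [B, num_classes, H, W]
onehot = onehot.permute(0, 3, 1, 2).float()

print(onehot.shape)
```

Passing `onehot` with `label_is_onehot=True` should then give the same result as passing `class_ids` with `label_is_onehot=False`, assuming the library performs an equivalent conversion internally.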

Parameters

Poly1CrossEntropyLoss

  • num_classes, (int) - Number of classes
  • epsilon, (float), (Default=1.0) - PolyLoss epsilon
  • reduction, (str), (Default='none') - apply reduction to the output, one of: none | sum | mean
  • weight, (torch.Tensor), (Default=None) - manual rescaling weight for each class, passed to Cross-Entropy loss
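
To make the role of these parameters concrete, here is a minimal functional sketch of Poly-1 cross-entropy for logits of shape [N, C] and class-id labels of shape [N]. The function name `poly1_cross_entropy` is hypothetical and this is not necessarily how the class in this repository is written; it only follows the formula from the paper (cross-entropy plus `epsilon * (1 - P_t)`).

```python
import torch
import torch.nn.functional as F


def poly1_cross_entropy(logits, labels, epsilon=1.0, reduction="none", weight=None):
    """Hypothetical sketch: Poly-1 cross-entropy for logits [N, C], labels [N]."""
    # P_t: softmax probability assigned to the target class of each sample
    probs = F.softmax(logits, dim=-1)
    pt = probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    # per-sample cross-entropy; `weight` rescales each class as in F.cross_entropy
    ce = F.cross_entropy(logits, labels, weight=weight, reduction="none")
    poly1 = ce + epsilon * (1.0 - pt)
    # assumption: 'mean' is a plain average over samples, not a weight-normalized one
    if reduction == "mean":
        return poly1.mean()
    if reduction == "sum":
        return poly1.sum()
    return poly1


torch.manual_seed(0)
logits = torch.rand(10, 5, requires_grad=True)
labels = torch.randint(0, 5, (10,))
loss = poly1_cross_entropy(logits, labels, reduction="mean")
loss.backward()
```

With `epsilon=0` the sketch reduces to ordinary cross-entropy, which is a quick way to sanity-check it.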

Poly1FocalLoss

  • num_classes, (int) - Number of classes
  • epsilon, (float), (Default=1.0) - PolyLoss epsilon
  • alpha, (float), (Default=0.25) - Focal loss alpha
  • gamma, (float), (Default=2.0) - Focal loss gamma
  • reduction, (str), (Default='none') - apply reduction to the output, one of: none | sum | mean
  • weight, (torch.Tensor), (Default=None) - manual rescaling weight given to the loss of each batch element, passed to underlying binary_cross_entropy loss (*)
  • pos_weight, (torch.Tensor), (Default=None) - weight of positive examples, passed to underlying binary_cross_entropy loss (*)
  • label_is_onehot, (bool), (Default=False) - set to True if labels are one-hot (or multi-hot) encoded

* See the formulas on the PyTorch documentation page for BCEWithLogitsLoss to understand how the weight (w_n) and pos_weight (p_c) parameters enter the loss function and how they affect the loss. Detailed explanation coming soon. Further discussion can be found in the two threads linked from the upstream README.
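
As a rough illustration of the footnote above, the sketch below writes out the per-element formula from the BCEWithLogitsLoss documentation by hand, `-w_n * (p_c * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))`, and compares it with `F.binary_cross_entropy_with_logits`. The tensor shapes and weight values are made up for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 7
logits = torch.randn(B, C, H, W)
targets = torch.randint(0, 2, (B, C, H, W)).float()

# hypothetical weights, both broadcastable to [B, C, H, W]
weight = torch.rand(B, 1, 1, 1)                                  # w_n: one factor per batch element
pos_weight = torch.tensor([1., 1., 5.]).reshape(1, C, 1, 1)      # p_c: one factor per class

# formula written out by hand
log_p = F.logsigmoid(logits)         # log(sigmoid(x))
log_not_p = F.logsigmoid(-logits)    # log(1 - sigmoid(x))
manual = -weight * (pos_weight * targets * log_p + (1 - targets) * log_not_p)

builtin = F.binary_cross_entropy_with_logits(
    logits, targets, weight=weight, pos_weight=pos_weight, reduction="none")

print(torch.allclose(manual, builtin, atol=1e-5))
```

Note that `pos_weight` scales only the positive (y = 1) term, while `weight` scales the whole element-wise loss.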

Requirements

  • Python 3.6+
  • PyTorch 1.1+

Contributors

  • abhuse
