beijita-yegucheng / nlp-loss-pytorch

This project forked from shuxinyin/nlp-loss-pytorch

Implementations of several losses for class-imbalanced data, such as Focal Loss, Dice Loss, DSC Loss, and GHM Loss.

License: MIT License

Shell 0.31% Python 99.69%

nlp-loss-pytorch's Introduction

Implementations of several losses for class-imbalanced NLP tasks, such as Focal Loss, Dice Loss, DSC Loss, and GHM Loss, together with adversarial training methods such as FGM, FGSM, PGD, and FreeAT.

Loss Summary

This repository implements the following losses for imbalanced data:

| Loss Name | Paper | Notes |
| --- | --- | --- |
| Weighted CE Loss | UNet Architectures in Multiplanar Volumetric Segmentation -- Validated on Three Knee MRI Cohorts | |
| Focal Loss | Focal Loss for Dense Object Detection | |
| Dice Loss | V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation | |
| DSC Loss | Dice Loss for Data-imbalanced NLP Tasks | |
| GHM Loss | Gradient Harmonized Single-stage Detector | |
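
To make the idea behind Focal Loss concrete, here is a minimal sketch of the multi-class formulation from the Focal Loss paper, FL = -(1 - p_t)^gamma * log(p_t), averaged over the batch. The function name `focal_loss_sketch` is hypothetical, the per-class alpha weighting is omitted, and this is not necessarily how the repository's own implementation is written.

```python
import torch
import torch.nn.functional as F


def focal_loss_sketch(logits, targets, gamma=2.0):
    """Hypothetical helper: mean of -(1 - p_t)^gamma * log(p_t) over the batch."""
    log_probs = F.log_softmax(logits, dim=-1)                      # (batch, num_class)
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class
    pt = log_pt.exp()
    # Well-classified samples (p_t close to 1) are down-weighted by (1 - p_t)^gamma.
    return (-((1.0 - pt) ** gamma) * log_pt).mean()


torch.manual_seed(0)
logits = torch.randn(8, 5, requires_grad=True)   # (batch_size, num_class)
targets = torch.randint(0, 5, size=(8,))         # (batch_size,)

loss = focal_loss_sketch(logits, targets, gamma=2.0)
loss.backward()

# With gamma = 0 the modulating factor is 1, so the loss reduces to cross-entropy.
ce = F.cross_entropy(logits, targets)
fl0 = focal_loss_sketch(logits, targets, gamma=0.0)
```

Larger values of `gamma` shift more of the loss toward hard, misclassified samples, which is why the loss is used for imbalanced data.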

How to use?

You can find all the loss usage information in test_loss.py.

Here is a simple demo of usage:

import torch
from unbalanced_loss.focal_loss import MultiFocalLoss

batch_size, num_class = 64, 10
Loss_Func = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean')

logits = torch.rand(batch_size, num_class, requires_grad=True)  # (batch_size, num_classes)
targets = torch.randint(0, num_class, size=(batch_size,))  # (batch_size, )

loss = Loss_Func(logits, targets)
loss.backward()
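
For comparison, the Weighted CE Loss baseline can be sketched with PyTorch's built-in `nn.CrossEntropyLoss` and its `weight` argument. The class counts below are made-up illustrative numbers, and the inverse-frequency weighting is an assumption rather than the scheme the repository necessarily uses.

```python
import torch
import torch.nn as nn

num_class = 3
# Hypothetical label counts per class in a training set (illustrative only).
class_counts = torch.tensor([900.0, 90.0, 10.0])

# Inverse-frequency heuristic: n_samples / (n_classes * count); rarer classes get larger weights.
weights = class_counts.sum() / (num_class * class_counts)
loss_func = nn.CrossEntropyLoss(weight=weights, reduction='mean')

logits = torch.randn(16, num_class, requires_grad=True)  # (batch_size, num_class)
targets = torch.randint(0, num_class, size=(16,))        # (batch_size,)

loss = loss_func(logits, targets)
loss.backward()
```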

Adversarial Training Summary

Here is a summary of the adversarial training methods implemented.
You can find more details in adversarial_training/README.md.

| Adversarial Training | Paper | Notes |
| --- | --- | --- |
| FGM | Fast Gradient Method | |
| FGSM | Fast Gradient Sign Method | |
| PGD | Towards Deep Learning Models Resistant to Adversarial Attacks | |
| FreeAT | Free Adversarial Training | |
| FreeLB | Free Large Batch Adversarial Training | |
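
As a rough illustration of how FGM is usually applied to NLP models, the sketch below perturbs the embedding weights by epsilon * g / ||g||_2 after a clean backward pass, accumulates the adversarial gradients, and restores the embeddings before the optimizer step. `TinyClassifier` and `fgm_step` are hypothetical stand-ins written for this example; the code in adversarial_training/ may be organized differently.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in model: embedding -> mean pooling -> linear."""

    def __init__(self, vocab_size=100, dim=16, num_class=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.fc = nn.Linear(dim, num_class)

    def forward(self, token_ids):
        return self.fc(self.embedding(token_ids).mean(dim=1))


def fgm_step(model, token_ids, labels, optimizer, epsilon=1.0):
    """One FGM step: clean backward, perturb embeddings, adversarial backward, restore."""
    loss_fn = nn.CrossEntropyLoss()
    emb = model.embedding.weight

    optimizer.zero_grad()
    loss = loss_fn(model(token_ids), labels)
    loss.backward()                                   # gradients on the clean batch

    backup = emb.data.clone()
    norm = torch.norm(emb.grad)
    if norm != 0 and not torch.isnan(norm):
        emb.data.add_(epsilon * emb.grad / norm)      # r_adv = epsilon * g / ||g||_2

    adv_loss = loss_fn(model(token_ids), labels)
    adv_loss.backward()                               # adds adversarial gradients to the clean ones
    emb.data.copy_(backup)                            # restore the original embeddings

    optimizer.step()
    return loss.item(), adv_loss.item()


torch.manual_seed(0)
model = TinyClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
token_ids = torch.randint(0, 100, size=(4, 8))        # (batch_size, seq_len)
labels = torch.randint(0, 2, size=(4,))               # (batch_size,)

before = model.embedding.weight.detach().clone()
clean_loss, adv_loss = fgm_step(model, token_ids, labels, optimizer)
```

FGSM differs mainly in using the sign of the gradient instead of the normalized gradient, and PGD repeats the perturbation for k steps while projecting back into an epsilon-ball.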

How to use?

You can find a simple demo for BERT classification in test_bert.py.

Here is a simple usage demo. Adapt the train function in PGD.py to your model's inputs; adversarial training can then be used as follows:

import torch
from model import bert_classification
from adversarial_training.PGD import PGD

# model = your_model()
model = bert_classification()
AT_Model = PGD(model)
# torch.optim.AdamW replaces transformers.AdamW, which is deprecated in recent transformers releases
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# token, segment, mask, and label come from one batch of your data loader.
# Adapt train_bert in PGD.py to match your model's inputs.
outputs, loss = AT_Model.train_bert(token, segment, mask, label, optimizer)

Adversarial Training Results Comparison

| Adversarial Training | Time Cost (s/epoch) | best_acc |
| --- | --- | --- |
| Normal (no attack) | 23.77 | 0.773 |
| FGSM | 45.95 | 0.7936 |
| FGM | 47.28 | 0.8008 |
| PGD (k=3) | 87.50 | 0.7963 |
| FreeAT (k=3) | 93.26 | 0.7896 |
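
To read the trade-off more directly, the short script below computes each method's slowdown and accuracy gain relative to the normal (no attack) baseline. It uses only the numbers reported in the table; the variable names are illustrative.

```python
# (time cost in s/epoch, best accuracy) as reported in the results table
results = {
    "Normal": (23.77, 0.773),
    "FGSM": (45.95, 0.7936),
    "FGM": (47.28, 0.8008),
    "PGD(k=3)": (87.50, 0.7963),
    "FreeAT(k=3)": (93.26, 0.7896),
}

base_time, base_acc = results["Normal"]

# Slowdown factor and absolute accuracy gain versus the baseline.
summary = {
    name: (round(t / base_time, 2), round(acc - base_acc, 4))
    for name, (t, acc) in results.items()
    if name != "Normal"
}

for name, (slowdown, gain) in summary.items():
    print(f"{name:12s} {slowdown:.2f}x time, {gain:+.4f} accuracy")
```

On these numbers, FGM gives the largest accuracy gain at roughly twice the baseline epoch time, while PGD and FreeAT with k=3 cost close to four times the baseline.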
