wecarsoniv / beta-divergence-metrics

PyTorch implementations of the beta divergence loss.

License: BSD 3-Clause "New" or "Revised" License

Topics: pytorch, beta-divergence, loss-functions, loss, kullback-leibler-divergence, kl-divergence, mean-square-error, mean-squared-error, nmf, nmf-decomposition, non-negative-matrix-factorization, itakura-saito-divergence, objective-functions, distance-measures, distance-metric, distance-metrics, divergence, divergences, torch, numpy


Beta-Divergence Loss Implementations

This repository contains Python implementations of the beta-divergence loss, compatible with both NumPy and PyTorch.

Dependencies

This library is written in Python and requires Python to run (version >= 3.9 recommended). In addition to a working PyTorch installation, it relies on NumPy.

Installation

To install the latest stable release, use pip:

$ pip install beta-divergence-metrics

Usage

The numpybd.loss module contains two beta-divergence function implementations that operate on NumPy arrays: a general beta-divergence between two arrays, and a beta-divergence specific to non-negative matrix factorization (NMF). Similarly, the torchbd.loss module contains two beta-divergence loss class implementations that operate on PyTorch tensors. The implementations can be imported as follows:

# Import beta-divergence loss implementations
from numpybd.loss import *
from torchbd.loss import *
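
Alternatively, the implementations used in the examples below can be imported by name. This is a minimal sketch, assuming only the function and class names that appear later in this README:

# Import only the implementations used in the examples below
from numpybd.loss import beta_div, nmf_beta_div
from torchbd.loss import BetaDivLoss, NMFBetaDivLoss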

Beta-divergence between two NumPy arrays

To calculate the beta-divergence between a NumPy array a and a target or reference array b, use the beta_div loss function. The beta_div loss function can be used as follows:

# Calculate beta-divergence loss between array a and target array b
loss_val = beta_div(a, b, beta=0, reduction='mean')
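
A more complete, self-contained sketch is shown below. The random strictly positive arrays and the passing of a and b as the first two positional arguments are illustrative assumptions, not documented behavior:

# Hypothetical example inputs: beta-divergences assume non-negative (here
# strictly positive) entries, so random positive arrays are used
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((16, 8)) + 0.1  # input array
b = rng.random((16, 8)) + 0.1  # target or reference array

# Itakura-Saito divergence (beta=0), averaged over elements; assumes a and b
# are accepted as the first two positional arguments
loss_val = beta_div(a, b, beta=0, reduction='mean')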

Beta-divergence between two PyTorch tensors

To calculate the beta-divergence between a PyTorch tensor a and a target or reference tensor b, use the BetaDivLoss loss class. It can be instantiated and used as follows:

# Instantiate beta-divergence loss object
loss_func = BetaDivLoss(beta=0, reduction='mean')

# Calculate beta-divergence loss between tensor a and target tensor b
loss_val = loss_func(input=a, target=b)
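
A self-contained sketch follows; the tensor shapes and the random strictly positive values are illustrative assumptions:

# Hypothetical example tensors: strictly positive values keep the
# Itakura-Saito divergence (beta=0) well-defined
import torch

torch.manual_seed(0)
a = torch.rand(16, 8) + 0.1  # input tensor
b = torch.rand(16, 8) + 0.1  # target or reference tensor

# Instantiate the loss and compute the mean beta-divergence between a and b
loss_func = BetaDivLoss(beta=0, reduction='mean')
loss_val = loss_func(input=a, target=b)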

NMF beta-divergence between NumPy array of data and data reconstruction

To calculate the NMF-specific beta-divergence between a NumPy data matrix X and the matrix product of a scores matrix H and a components matrix W, use the nmf_beta_div loss function. The nmf_beta_div loss function can be used as follows:

# Calculate beta-divergence loss between data matrix X (target or
# reference matrix) and matrix product of H and W
loss_val = nmf_beta_div(X=X, H=H, W=W, beta=0, reduction='mean')
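
A self-contained sketch follows. The matrix shapes (X approximately equal to H @ W, with H holding scores and W holding components) and the random positive values are illustrative assumptions based on the "matrix product of H and W" description above:

# Hypothetical example matrices: X is (n_samples, n_features),
# H is (n_samples, n_components), W is (n_components, n_features),
# so that the reconstruction is the matrix product H @ W
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_components = 64, 32, 8
X = rng.random((n_samples, n_features)) + 0.1  # strictly positive data matrix
H = rng.random((n_samples, n_components))      # scores matrix
W = rng.random((n_components, n_features))     # components matrix

# Mean Itakura-Saito divergence between X and its reconstruction H @ W
loss_val = nmf_beta_div(X=X, H=H, W=W, beta=0, reduction='mean')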

NMF beta-divergence between PyTorch tensor of data and data reconstruction

To calculate the NMF-specific beta-divergence between a PyTorch data tensor X and the matrix product of a scores matrix H and a components matrix W, use the NMFBetaDivLoss loss class. It can be instantiated and used as follows:

# Instantiate NMF beta-divergence loss object
loss_func = NMFBetaDivLoss(beta=0, reduction='mean')

# Calculate beta-divergence loss between data matrix X (target or
# reference matrix) and matrix product of H and W
loss_val = loss_func(X=X, H=H, W=W)
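
Because the loss is differentiable with respect to H and W, it can drive a simple gradient-based NMF fit. The sketch below is one possible setup; the shapes, the softplus reparameterization used to keep H and W non-negative, and the optimizer settings are all illustrative assumptions, not part of this package:

# Minimal sketch: fit NMF by gradient descent using NMFBetaDivLoss
import torch

torch.manual_seed(0)
n_samples, n_features, n_components = 64, 32, 8
X = torch.rand(n_samples, n_features) + 0.1  # strictly positive data matrix

# Unconstrained parameters; softplus keeps H and W non-negative
H_raw = torch.randn(n_samples, n_components, requires_grad=True)
W_raw = torch.randn(n_components, n_features, requires_grad=True)

loss_func = NMFBetaDivLoss(beta=0, reduction='mean')
optimizer = torch.optim.Adam([H_raw, W_raw], lr=1e-2)

for step in range(500):
    optimizer.zero_grad()
    H = torch.nn.functional.softplus(H_raw)
    W = torch.nn.functional.softplus(W_raw)
    loss_val = loss_func(X=X, H=H, W=W)
    loss_val.backward()
    optimizer.step()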

Choosing beta value

When instantiating a beta-divergence loss object, the value of beta should be chosen based on the data type and application. For NMF applications, a beta value of 0 (Itakura-Saito divergence) is recommended. Integer values of beta correspond to the following divergences and loss functions:

  • beta = 0: Itakura-Saito divergence
  • beta = 1: Kullback-Leibler divergence
  • beta = 2: mean squared error (MSE)
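
For reference, the standard element-wise definition of the beta-divergence used in the NMF literature is given below (the exact scaling and reduction conventions in this package may differ); the divergence between two arrays is then obtained by summing, or with reduction='mean' averaging, the element-wise terms:

d_\beta(x \mid y) =
\begin{cases}
\dfrac{1}{\beta(\beta - 1)} \left( x^{\beta} + (\beta - 1)\, y^{\beta} - \beta\, x\, y^{\beta - 1} \right), & \beta \in \mathbb{R} \setminus \{0, 1\} \\[1ex]
x \log \dfrac{x}{y} - x + y, & \beta = 1 \ \text{(Kullback-Leibler)} \\[1ex]
\dfrac{x}{y} - \log \dfrac{x}{y} - 1, & \beta = 0 \ \text{(Itakura-Saito)}
\end{cases}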

Issue Tracking and Reports

Please use the GitHub issue tracker associated with this repository for issue tracking, filing bug reports, and asking general questions about the package or project.
