
Graph-PIT: Generalized permutation invariant training for continuous separation of arbitrary numbers of speakers


This repository contains a PyTorch implementation of the Graph-PIT objective proposed in the paper "Graph-PIT: Generalized permutation invariant training for continuous separation of arbitrary numbers of speakers" (INTERSPEECH 2021) [1], and the optimized variants from the paper "Speeding up permutation invariant training for source separation" (14th ITG Conference on Speech Communication, 2021) [2].

The optimized uPIT code used in [2] can be found in padertorch and in the example notebook runtimes.ipynb.

Installation

You can install this package from GitHub:

pip install git+https://github.com/fgnt/graph_pit.git

Or in editable mode if you want to make modifications:

git clone https://github.com/fgnt/graph_pit.git
cd graph_pit
pip install -e .

This will install the basic dependencies of the package. If you want to run the example or the tests, install their requirements with:

git clone https://github.com/fgnt/graph_pit.git
cd graph_pit
pip install -e '.[example]' # Installs example requirements
pip install -e '.[test]'    # Installs test requirements
pip install -e '.[all]'     # Installs all requirements

Usage

The Graph-PIT losses in this repository require a list of utterance signals and segment boundaries (tuples of start and end times). There are two different implementations:

  • graph_pit.loss.unoptimized contains the original Graph-PIT loss as proposed in [1], and
  • graph_pit.loss.optimized contains the optimized Graph-PIT loss variants from [2].

The default (unoptimized) Graph-PIT loss from [1] can be used as follows:

import torch
from graph_pit import graph_pit_loss

# Create three target utterances and two estimated signals
targets = [torch.rand(100), torch.rand(200), torch.rand(150)]   # List of target utterance signals
segment_boundaries = [(0, 100), (150, 350), (300, 450)]     # One (start, end) pair per utterance; end - start equals each target's length
estimate = torch.rand(2, 500)   # The estimated separated streams

# Compute the loss with the unoptimized loss function, here MSE as an example
loss = graph_pit_loss(
    estimate, targets, segment_boundaries,
    torch.nn.functional.mse_loss
)

# Example for using the optimized sa-SDR loss from [2]
from graph_pit.loss.optimized import optimized_graph_pit_source_aggregated_sdr_loss
loss = optimized_graph_pit_source_aggregated_sdr_loss(
    estimate, targets, segment_boundaries,
    # assignment_solver can be one of:
    #  - 'optimal_brute_force'
    #  - 'optimal_branch_and_bound'
    #  - 'optimal_dynamic_programming' <- fastest
    #  - 'dfs'
    #  - 'greedy_cop'
    assignment_solver='optimal_dynamic_programming'
)

This unoptimized loss variant works with any loss function loss_fn, but it is in many cases quite slow (see [2]). The optimized variant from [2] is available in graph_pit.loss.optimized for the source-aggregated SDR. You can define your own optimized Graph-PIT losses by sub-classing graph_pit.loss.optimized.OptimizedGraphPITLoss and implementing the property similarity_matrix and the method compute_f, as sketched below.
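
The base class's exact interface is not spelled out in this README, so the following is only a rough sketch: the shape of similarity_matrix, the signature of compute_f, and the attribute names estimate, targets, and segment_boundaries are assumptions for illustration; check the OptimizedGraphPITLoss source for the real contract.

import torch
from graph_pit.loss.optimized import OptimizedGraphPITLoss


class ToyDotProductGraphPITLoss(OptimizedGraphPITLoss):
    """Toy sketch of a custom optimized Graph-PIT loss."""

    @property
    def similarity_matrix(self):
        # Assumption: one similarity score per (estimated stream,
        # target utterance) pair, shape (num_estimates, num_targets),
        # evaluated on the target's segment.
        similarity = torch.zeros(self.estimate.shape[0], len(self.targets))
        for t, (target, (start, stop)) in enumerate(
                zip(self.targets, self.segment_boundaries)
        ):
            similarity[:, t] = self.estimate[:, start:stop] @ target
        return similarity

    def compute_f(self, aggregated_similarity):
        # Assumption: maps the aggregated similarity of the best
        # assignment to the final loss value (lower is better).
        return -aggregated_similarity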

Advanced Usage

Each loss variant has three interfaces:

  • function: A simple functional interface, as used above.
  • class: A (data)class that computes the loss for one pair of estimates and targets and exposes all intermediate states (e.g., the intermediate signals or the best coloring). This simplifies testing (you can check intermediate signals, mock things, etc.) and extension (you can sub-class it and override parts of the computation).
  • torch.nn.Module: A module wrapper around the class interface. It allows usage as a Module, so that loss_fn can itself be a trainable module and the loss shows up in the model's print representation.

This is an example of the class interface GraphPITLoss to get access to the best coloring and target sum signals:

import torch
from graph_pit import GraphPITLoss

# Create three target utterances and two estimated signals
targets = [torch.rand(100), torch.rand(200), torch.rand(150)]
segment_boundaries = [(0, 100), (150, 350), (300, 450)]
estimate = torch.rand(2, 500)

# Compute loss
loss = GraphPITLoss(
    estimate, targets, segment_boundaries,
    torch.nn.functional.mse_loss
)
print(loss.loss)
print(loss.best_coloring)   # This is the coloring that minimizes the loss
print(loss.best_target_sum) # This is the target sum signal (\tilde{s})

This is an example of the torch.nn.Module variant:

import torch
from graph_pit.loss import GraphPITLossModule, ThresholdedSDRLoss


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = GraphPITLossModule(
            loss_fn=ThresholdedSDRLoss(max_sdr=20, epsilon=1e-6)
        )
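
The module can then be called inside a training step. The exact forward signature of GraphPITLossModule is not shown in this README; the call below assumes it mirrors the functional interface, which is an assumption for illustration:

model = MyModel()
estimate = torch.rand(2, 500)
targets = [torch.rand(100), torch.rand(200), torch.rand(150)]
segment_boundaries = [(0, 100), (150, 350), (300, 450)]

# Assumption: the module forwards its arguments to the wrapped
# loss_fn in the same order as the functional interface.
loss = model.loss(estimate, targets, segment_boundaries)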

Examples

There are two examples in graph_pit.examples:

  • tasnet: An example training script for a DPRNN-based TasNet model trained with Graph-PIT using padertorch
  • runtimes.ipynb: A Jupyter notebook comparing the runtimes of different uPIT and Graph-PIT variants. It creates plots similar to those in [2].

Cite this work / References

If you use this code, please cite the papers:

  • [1] "Graph-PIT: Generalized permutation invariant training for continuous separation of arbitrary numbers of speakers": https://arxiv.org/abs/2107.14446
  • [2] "Speeding up permutation invariant training for source separation"

@inproceedings{vonneumann21_GraphPIT,
  author={Thilo von Neumann and Keisuke Kinoshita and Christoph Boeddeker and Marc Delcroix and Reinhold Haeb-Umbach},
  title={{Graph-PIT: Generalized Permutation Invariant Training for Continuous Separation of Arbitrary Numbers of Speakers}},
  booktitle={Proc. Interspeech 2021},
  year={2021},
  pages={3490--3494},
  doi={10.21437/Interspeech.2021-1177}
}
@inproceedings{vonneumann21_SpeedingUp,
  author={Thilo von Neumann and Christoph Boeddeker and Keisuke Kinoshita and Marc Delcroix and Reinhold Haeb-Umbach},
  title={{Speeding Up Permutation Invariant Training for Source Separation}},
  booktitle={Speech Communication; 14th ITG Conference},
  year={2021},
  pages={1--5}
}
@inproceedings{kinoshita22_interspeech,
  author={Keisuke Kinoshita and Thilo von Neumann and Marc Delcroix and Christoph Boeddeker and Reinhold Haeb-Umbach},
  title={{Utterance-by-utterance overlap-aware neural diarization with Graph-PIT}},
  booktitle={Proc. Interspeech 2022},
  year={2022},
  pages={1486--1490},
  doi={10.21437/Interspeech.2022-11408}
}


Known Issues

Empty sequence when getting optimal assignment

When using the OptimalDynamicProgramming assignment solver, some inputs raise an error on this line:

return min(candidates.items(), key=lambda x: x[1][1])[1][0]

ValueError: min() arg is an empty sequence

This is probably because the adjacency list must be sorted at the beginning of this function when the neighbors are computed (as mentioned in the comment), which suggests that the segment_boundaries must be provided in sorted order. Perhaps this could be mentioned in the usage example? A possible workaround is sketched below.
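
If that is indeed the cause, a minimal workaround is to sort the targets and their boundaries jointly by segment start time before computing the loss (a sketch, reusing the variable names from the usage example above):

# Sort targets and segment boundaries jointly by segment start time.
order = sorted(range(len(segment_boundaries)), key=lambda i: segment_boundaries[i])
targets = [targets[i] for i in order]
segment_boundaries = [segment_boundaries[i] for i in order]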
