
asteroid-team / asteroid-filterbanks


Asteroid's filterbanks 🚀

Home Page: https://asteroid-team.github.io/

License: MIT License

Python 100.00%
python3 pytorch deep-learning audio audio-processing filterbanks asteroid-filterbanks

asteroid-filterbanks's People

Contributors

cameronmaske, faroit, iver56, jonashaag, mpariente

asteroid-filterbanks's Issues

script_if_tracing breaks onnxruntime

It appears that ONNX Runtime does not handle the script_if_tracing decorator that was introduced for backwards compatibility. Removing the decorator makes everything work; with the decorator in place, running the exported model fails with a Reshape error.

If the decorator is only needed to support torch 1.6.0, perhaps it should be applied only for that version of torch?
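One way to act on that suggestion is to pick the decorator at import time based on the installed PyTorch version. The sketch below is a hypothetical helper, not part of asteroid-filterbanks: parse_major_minor and decorate_only_for are illustrative names. It returns the given decorator only when the version matches and an identity decorator otherwise.

```python
import re
from typing import Callable, Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '1.6.0', '1.6.0+cu101' or '2.1.0a0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def decorate_only_for(version: str, target: Tuple[int, int], decorator: Callable) -> Callable:
    """Hypothetical helper: return `decorator` if `version` matches `target`, else a no-op."""
    if parse_major_minor(version) == target:
        return decorator

    def identity(fn: Callable) -> Callable:
        # Leave the function untouched on every other version.
        return fn

    return identity
```

For example, decorate_only_for(torch.__version__, (1, 6), script_if_tracing) would keep the current behaviour on torch 1.6 and leave the function undecorated elsewhere, where script_if_tracing stands for whichever decorator the package currently applies.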

Here's a reproducible example that you can run on Colab:

    %pip install asteroid_filterbanks
    %pip install onnx
    %pip install onnxruntime
    %pip install torch

    import numpy as np
    import onnxruntime as ort
    import torch
    import torch.onnx
    from asteroid_filterbanks import torch_stft_fb
    from asteroid_filterbanks.enc_dec import Encoder

    # Square root of a periodic Hann window of length 512
    window = np.hanning(512 + 1)[:-1] ** 0.5
    fb = torch_stft_fb.TorchSTFTFB(
        n_filters=512,
        kernel_size=512,
        center=True,
        stride=256,
        window=window,
    )
    encoder = Encoder(fb)

    nb_samples = 1
    nb_channels = 2
    nb_timesteps = 11111
    example = torch.rand((nb_samples, nb_channels, nb_timesteps))

    # The forward pass in PyTorch works
    out = encoder(example)

    torch.onnx.export(
        encoder,
        example,
        "test.onnx",
        export_params=True,
        opset_version=16,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        verbose=True,
    )

    # Running the exported model in ONNX Runtime fails
    ort_sess = ort.InferenceSession("test.onnx")
    outputs = ort_sess.run(None, {'input': example.numpy()})

Produces:

---------------------------------------------------------------------------

RuntimeException                          Traceback (most recent call last)

<ipython-input-29-acc756eada00> in <module>()
     37 
     38 ort_sess = ort.InferenceSession("test.onnx")
---> 39 outputs = ort_sess.run(None, {'input': example.numpy()})

/usr/local/lib/python3.7/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    198             output_names = [output.name for output in self._outputs_meta]
    199         try:
--> 200             return self._sess.run(output_names, input_feed, run_options)
    201         except C.EPFail as err:
    202             if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_115' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,2,44}, requested shape:{1,1,514,44}
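For orientation, the shapes in this error can be checked by hand. The sketch below computes the expected encoder output for the repro's parameters. It assumes, as the 514 in the error message suggests, that TorchSTFTFB with n_filters=512 produces n_filters + 2 features (real and imaginary parts of 257 frequency bins), and that center=True pads kernel_size // 2 samples on each side as torch.stft does. The function name is hypothetical.

```python
def expected_stft_encoder_shape(batch, channels, n_samples, n_filters, kernel_size, stride, center=True):
    """Hypothetical helper: (batch, channels, n_feats, n_frames) of a multichannel STFT encoder.

    Assumes n_feats = n_filters + 2 (real and imaginary parts of n_filters // 2 + 1 bins)
    and torch.stft-style centering, which pads kernel_size // 2 samples on each side.
    """
    n_feats = n_filters + 2
    pad = kernel_size // 2 if center else 0
    n_padded = n_samples + 2 * pad
    # Number of full windows of length kernel_size that fit with hop `stride`
    n_frames = 1 + (n_padded - kernel_size) // stride
    return batch, channels, n_feats, n_frames


print(expected_stft_encoder_shape(1, 2, 11111, n_filters=512, kernel_size=512, stride=256))
# (1, 2, 514, 44)
```

With these numbers the expected output is (1, 2, 514, 44), i.e. 45,232 elements. The tensor reaching the failing Reshape node, {1,2,44}, has 88 elements, and the requested {1,1,514,44} needs 22,616, so neither matches the other or the expected output.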

Decoder length argument doesn't apply to 2D and 3D tensors

The forward pass of a Decoder takes an optional length argument. Trimming the signal to that length works for 1D tensors but fails for 2D and 3D tensors, because wav[:length] slices the first dimension rather than the last (time) dimension.
Since 2D and 3D inputs are supported, this is a bug, and unfortunately no test covers it.

    if length is not None:
        length = min(length, wav.shape[-1])
        return wav[:length]
    return wav

The potential fix is a one-liner:

return wav[:length] -> return wav[..., :length]
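The fix can be checked on NumPy arrays, which follow the same basic slicing semantics as PyTorch tensors here. trim_to_length is a hypothetical stand-alone version of the excerpt above with the proposed change applied.

```python
import numpy as np


def trim_to_length(wav, length=None):
    """Trim the last (time) axis of `wav` to at most `length` samples."""
    if length is not None:
        length = min(length, wav.shape[-1])
        # Ellipsis keeps all leading (batch/channel) axes and slices only time.
        return wav[..., :length]
    return wav


wav = np.zeros((2, 3, 10))           # (batch, channels, time)
print(wav[:5].shape)                 # old slicing hits the batch axis: (2, 3, 10)
print(trim_to_length(wav, 5).shape)  # fixed slicing trims time: (2, 3, 5)
```

With the old slicing, a batch of two is "trimmed" to five items along the batch axis, which leaves the array unchanged and silently skips the length adjustment.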

Question on Hilbert transform

Thanks for this project, very impressive contribution! I have a question about the Hilbert transform in asteroid-filterbanks/asteroid_filterbanks/analytic_free_fb.py:

    ft_f = rfft(self._filters, 1, normalized=True)
    hft_f = conj(ft_f)
    hft_f = irfft(hft_f, 1, normalized=True, signal_sizes=(self.kernel_size,))
    return torch.cat([self._filters, hft_f], dim=0)

As far as I know, the Hilbert transform is performed like this:

  1. Take the real part of the signal;
  2. Rotate the phase of the signal by 90°;
  3. Analytic signal = real + i * (rotated signal).

From the code, it looks like conj is used to perform the rotation. Is this correct?
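For reference, the 90° rotation in step 2 can be sketched with NumPy's real FFT: multiplying every positive-frequency bin by -j shifts each sinusoid's phase by -90°, turning a cosine into a sine. This is a generic discrete (periodic) Hilbert transform, not the asteroid-filterbanks implementation, and the function names are hypothetical.

```python
import numpy as np


def hilbert_transform(x):
    """Discrete (periodic) Hilbert transform of a real 1D signal via the real FFT."""
    n = x.shape[-1]
    spectrum = np.fft.rfft(x)
    # Multiply positive-frequency bins by -j: a -90 degree phase shift (cos -> sin).
    rotated = -1j * spectrum
    # DC, and the Nyquist bin for even n, cannot be rotated within a real signal; drop them.
    rotated[0] = 0
    if n % 2 == 0:
        rotated[-1] = 0
    return np.fft.irfft(rotated, n=n)


def analytic_signal(x):
    """Analytic signal: real part plus i times the Hilbert transform."""
    return x + 1j * hilbert_transform(x)


t = np.arange(64)
x = np.cos(2 * np.pi * 3 * t / 64)
print(np.allclose(hilbert_transform(x), np.sin(2 * np.pi * 3 * t / 64)))  # True
print(np.allclose(np.abs(analytic_signal(x)), 1.0))                       # True
```

As a point of comparison, if conj is a plain elementwise complex conjugate, conjugating the spectrum of a real signal and transforming back yields the circularly time-reversed signal x[(-n) mod N] rather than a 90° phase shift; whether that applies to the code above depends on what its conj helper actually does.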

Support ONNX export of selected filterbank modules

One of the benefits of 1D-conv-based filterbanks is that they can be exported for deployment more easily.

Testing TorchSTFTFB shows that ONNX export does not currently work, and the error message does not make clear where the failure originates.

Example: tracing the encoder and exporting it to ONNX:

    import torch
    import torch.onnx
    from asteroid_filterbanks.enc_dec import Encoder
    from asteroid_filterbanks import torch_stft_fb

    nb_samples = 1
    nb_channels = 2
    nb_timesteps = 11111

    example = torch.rand((nb_samples, nb_channels, nb_timesteps))

    fb = torch_stft_fb.TorchSTFTFB(n_filters=512, kernel_size=512)
    enc = Encoder(fb)
    torch_out = enc(example)
    # Export the model
    torch.onnx.export(
        enc,
        example,
        "umx.onnx",
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        verbose=True
    )

This results in:

Traceback (most recent call last):
  File "onnx.py", line 28, in <module>
    verbose=False
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 230, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 91, in export
    use_external_data_format=use_external_data_format)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 639, in _export
    dynamic_axes=dynamic_axes)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 421, in _model_to_graph
    dynamic_axes=dynamic_axes, input_names=input_names)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 203, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 263, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 968, in _run_symbolic_function
    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 263, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 979, in _run_symbolic_function
    operator_export_type)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 888, in _find_symbolic_in_registry
    return sym_registry.get_registered_op(op_name, domain, opset_version)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/symbolic_registry.py", line 111, in get_registered_op
    raise RuntimeError(msg)
RuntimeError: Exporting the operator prim_Uninitialized to ONNX opset version 10 is not supported. Please open a bug to request ONNX export support for the missing operator.

Mixed precision support

This does not appear to work with mixed precision. Any idea what it would take to add support for it?
