
sevagh commented on June 9, 2024

Thanks for the kind words!

I think we should combine this with #5, like you said.

Right now, progress is reported once per segment of the waveform (Demucs splits the waveform into overlapping 7.8-second segments).

Within each segment, however, a full Demucs inference pass goes through these layers:

  • 4 encoder layers
  • 5 crosstransformer layers
  • 4 decoder layers

The layers don't all take the same amount of time (the encoders are fastest and the crosstransformer is slowest), but it might still be fine to treat them as "13 total layers" and call the progress callback more often: 13 times per segment, once after each layer.

from demucs.cpp.

sevagh commented on June 9, 2024

We can also pass a std::string log_message to the progress callback to address the request for logging.

sevagh commented on June 9, 2024

New unified logging + progress behavior in a branch of mine:

(system) sevagh@pop-os:~/repos/demucs.cpp/build$ ./demucs.cpp.main ../ggml-demucs/ggml-model-htdemucs-4s-f16.bin ../test/data/gspi_stereo.wav ./demucs-cpp-out-progress
demucs.cpp Main driver program
Input samples: 262144
Length in seconds: 5.94431
Number of channels: 2
load_demucs_model: loading model
Loading model_file...
Checking the magic of model_file
Model magic is Demucs 4-source
Loading demucs model...
Loading weights from model_file
Loaded model (533 tensors,  80.08 MB) in 0.209938 s
demucs_model_load returned true
Starting Demucs (4-source) inference
1., apply model w/ shift, offset: 4033
in split inference!
2., apply model w/ split, offset: 0, chunk shape: (2, 280161)
in segment inference!
(0.000%) 3., apply_model mix shape: (2, 343980)
(0.000%) buffers.z: 2, 2049, 336
(0.000%) buffers.x: 4, 2048, 336
(0.000%) Freq branch: normalized
(0.000%) Time branch: normalized
(1.923%) Time encoder 0
(3.846%) Freq encoder 0
(3.846%) Freq branch: applied frequency embedding
(5.769%) Time encoder 1
(7.692%) Freq encoder 1
(9.615%) Time encoder 2
(11.538%) Freq encoder 2
(13.462%) Time encoder 3
(15.385%) Freq encoder 3
(15.385%) Freq channels upsampled
(15.385%) Time channels upsampled
(15.385%) Applying crosstransformer
(15.385%) Freq (crosstransformer): norm + pos_embed
(15.385%) Time (crosstransformer): norm + pos_embed
(17.308%) Freq (crosstransformer): layer 0
(19.231%) Time (crosstransformer): layer 0
(21.154%) Freq (crosstransformer): layer 1
(23.077%) Time (crosstransformer): layer 1
(25.000%) Freq (crosstransformer): layer 2
(26.923%) Time (crosstransformer): layer 2
(28.846%) Freq (crosstransformer): layer 3
(30.769%) Time (crosstransformer): layer 3
(32.692%) Freq (crosstransformer): layer 4
(34.615%) Time (crosstransformer): layer 4
(34.615%) Crosstransformer finished
(34.615%) Freq channels downsampled
(34.615%) Time channels downsampled
(36.538%) Freq: decoder 0
(38.462%) Time: decoder 0
(40.385%) Freq: decoder 1
(42.308%) Time: decoder 1
(44.231%) Freq: decoder 2
(46.154%) Time: decoder 2
(48.077%) Freq: decoder 3
(50.000%) Time: decoder 3
(50.000%) Mask + istft
(50.000%) mix: 2, 343980
(50.000%) mix: 2, 343980
(50.000%) mix: 2, 343980
(50.000%) mix: 2, 343980
2., apply model w/ split, offset: 257985, chunk shape: (2, 22176)
in segment inference!
(50.000%) 3., apply_model mix shape: (2, 343980)
(50.000%) buffers.z: 2, 2049, 336
(50.000%) buffers.x: 4, 2048, 336
(50.000%) Freq branch: normalized
(50.000%) Time branch: normalized
(51.923%) Time encoder 0
(53.846%) Freq encoder 0
(53.846%) Freq branch: applied frequency embedding
(55.769%) Time encoder 1
...
(96.154%) Time: decoder 2
(98.077%) Freq: decoder 3
(100.000%) Time: decoder 3
(100.000%) Mask + istft
(100.000%) mix: 2, 343980
(100.000%) mix: 2, 343980
(100.000%) mix: 2, 343980
(100.000%) mix: 2, 343980
Writing wav file "./demucs-cpp-out-progress/target_0_drums.wav"
Encoder Status: 0
Writing wav file "./demucs-cpp-out-progress/target_1_bass.wav"
Encoder Status: 0
Writing wav file "./demucs-cpp-out-progress/target_2_other.wav"
Encoder Status: 0
Writing wav file "./demucs-cpp-out-progress/target_3_vocals.wav"
Encoder Status: 0

It's a bit noisy, but easily customizable:

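    #include <iomanip>
    #include <iostream>
    #include <string>

    // std::fixed + std::setprecision(3) print e.g. "(1.923%) Time encoder 0"
    demucscpp::ProgressCallback progressCallback =
        [](float progress, const std::string &log_message)
    {
        std::cout << "(" << std::fixed << std::setprecision(3)
                  << progress * 100.0f << "%) " << log_message << std::endl;
    };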

What I did is this:

  1. Set progress for each segment of the waveform (this was previously the only coarse-grained progress tracker)
  2. Within each segment, split inference into 26 sequential layers:
    • 8 encoders (4 time, 4 freq)
    • 10 crosstransformer layers (5 time, 5 freq)
    • 8 decoders (4 time, 4 freq)
    Intermediate steps that aren't real layers (channel upsampling/downsampling, normalization, etc.) don't count towards progress.
  3. We end up with 26 finer-grained progress steps per segment

I also removed most of the cout messages from the main src/ library, leaving only a few in model_apply and model_load. Any output from deep inside the actual inference layers now goes through the ProgressCallback.

sevagh commented on June 9, 2024

Let me know how this works for you.

olilarkin commented on June 9, 2024

Works great, thanks! Being able to set the thread callbacks from outside the threaded inference function call would be useful.