Comments (5)
Thanks for the kind words!
I think we should combine with #5 like you said.
Right now, progress is reported once per segment of the waveform (Demucs processes the audio in overlapping 7.8-second segments).
However, within each segment, the entire pass of Demucs inference has these layers:
- 4 encoder layers
- 5 crosstransformer layers
- 4 decoder layers
The layers don't all take the same amount of time (encoder is fastest, crosstransformer is slowest), but it might still be fine to treat them as "13 total layers" and call the progress callback more frequently, 13 times per segment (1 callback after each layer).
from demucs.cpp.
We can also pass a std::string log_message to the progress callback to address the desire for logging.
New unified logging + progress behavior in a branch of mine:
(system) sevagh@pop-os:~/repos/demucs.cpp/build$ ./demucs.cpp.main ../ggml-demucs/ggml-model-htdemucs-4s-f16.bin ../test/data/gspi_stereo.wav ./demucs-cpp-out-progress
demucs.cpp Main driver program
Input samples: 262144
Length in seconds: 5.94431
Number of channels: 2
load_demucs_model: loading model
Loading model_file...
Checking the magic of model_file
Model magic is Demucs 4-source
Loading demucs model...
Loading weights from model_file
Loaded model (533 tensors, 80.08 MB) in 0.209938 s
demucs_model_load returned true
Starting Demucs (4-source) inference
1., apply model w/ shift, offset: 4033
in split inference!
2., apply model w/ split, offset: 0, chunk shape: (2, 280161)
in segment inference!
(0.000%) 3., apply_model mix shape: (2, 343980)
(0.000%) buffers.z: 2, 2049, 336
(0.000%) buffers.x: 4, 2048, 336
(0.000%) Freq branch: normalized
(0.000%) Time branch: normalized
(1.923%) Time encoder 0
(3.846%) Freq encoder 0
(3.846%) Freq branch: applied frequency embedding
(5.769%) Time encoder 1
(7.692%) Freq encoder 1
(9.615%) Time encoder 2
(11.538%) Freq encoder 2
(13.462%) Time encoder 3
(15.385%) Freq encoder 3
(15.385%) Freq channels upsampled
(15.385%) Time channels upsampled
(15.385%) Applying crosstransformer
(15.385%) Freq (crosstransformer): norm + pos_embed
(15.385%) Time (crosstransformer): norm + pos_embed
(17.308%) Freq (crosstransformer): layer 0
(19.231%) Time (crosstransformer): layer 0
(21.154%) Freq (crosstransformer): layer 1
(23.077%) Time (crosstransformer): layer 1
(25.000%) Freq (crosstransformer): layer 2
(26.923%) Time (crosstransformer): layer 2
(28.846%) Freq (crosstransformer): layer 3
(30.769%) Time (crosstransformer): layer 3
(32.692%) Freq (crosstransformer): layer 4
(34.615%) Time (crosstransformer): layer 4
(34.615%) Crosstransformer finished
(34.615%) Freq channels downsampled
(34.615%) Time channels downsampled
(36.538%) Freq: decoder 0
(38.462%) Time: decoder 0
(40.385%) Freq: decoder 1
(42.308%) Time: decoder 1
(44.231%) Freq: decoder 2
(46.154%) Time: decoder 2
(48.077%) Freq: decoder 3
(50.000%) Time: decoder 3
(50.000%) Mask + istft
(50.000%) mix: 2, 343980
(50.000%) mix: 2, 343980
(50.000%) mix: 2, 343980
(50.000%) mix: 2, 343980
2., apply model w/ split, offset: 257985, chunk shape: (2, 22176)
in segment inference!
(50.000%) 3., apply_model mix shape: (2, 343980)
(50.000%) buffers.z: 2, 2049, 336
(50.000%) buffers.x: 4, 2048, 336
(50.000%) Freq branch: normalized
(50.000%) Time branch: normalized
(51.923%) Time encoder 0
(53.846%) Freq encoder 0
(53.846%) Freq branch: applied frequency embedding
(55.769%) Time encoder 1
...
(96.154%) Time: decoder 2
(98.077%) Freq: decoder 3
(100.000%) Time: decoder 3
(100.000%) Mask + istft
(100.000%) mix: 2, 343980
(100.000%) mix: 2, 343980
(100.000%) mix: 2, 343980
(100.000%) mix: 2, 343980
Writing wav file "./demucs-cpp-out-progress/target_0_drums.wav"
Encoder Status: 0
Writing wav file "./demucs-cpp-out-progress/target_1_bass.wav"
Encoder Status: 0
Writing wav file "./demucs-cpp-out-progress/target_2_other.wav"
Encoder Status: 0
Writing wav file "./demucs-cpp-out-progress/target_3_vocals.wav"
Encoder Status: 0
It's a bit noisy but easily customizable:
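// needs <iomanip>, <iostream>, <string>
demucscpp::ProgressCallback progressCallback =
    [](float progress, const std::string &log_message)
{
    // std::fixed + std::setprecision(3) gives the "(1.923%)" style seen in the log above
    std::cout << "(" << std::fixed << std::setprecision(3)
              << progress * 100.0f << "%) " << log_message << std::endl;
};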
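If the per-line output is too noisy, the callback can filter it. The following is a minimal sketch, not code from demucs.cpp: the `ProgressCallback` alias is an assumption inferred from the lambda above, and `make_quiet_callback` is a hypothetical helper that prints a line only when the progress value has advanced.

```cpp
#include <functional>
#include <iomanip>
#include <ostream>
#include <string>

// Assumed shape of the callback type, inferred from the lambda above;
// the real definition lives in the demucscpp namespace.
using ProgressCallback = std::function<void(float, const std::string &)>;

// Hypothetical helper: returns a callback that writes to `out` only when
// progress is strictly greater than the last value it printed.
inline ProgressCallback make_quiet_callback(std::ostream &out)
{
    float last = -1.0f; // below any valid progress, so the first call prints
    return [&out, last](float progress, const std::string &log_message) mutable
    {
        if (progress <= last)
        {
            return; // same percentage as before: drop the message
        }
        last = progress;
        out << "(" << std::fixed << std::setprecision(3)
            << progress * 100.0f << "%) " << log_message << "\n";
    };
}
```

Passing `std::cout` as `out` would give one line per progress step. The trade-off is that intermediate messages logged at an unchanged percentage (such as the upsampling and normalization steps) are dropped.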
What I did is this:
- Set progress for each segment of the waveform (this was previously the only, coarse-grained, progress tracker)
- Within each segment, split it into 26 sequential layers:
  - 8 encoders (4 time, 4 freq)
  - 10 crosstransformer layers (5 time, 5 freq)
  - 8 decoders (4 time, 4 freq)
- Some intermediate steps (channel upsampling, normalization, etc.) aren't major layers and don't count towards progress
- We end up with 26 finer-grained progress steps per segment
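The arithmetic behind the percentages in the log can be sketched as follows. This is an illustration under stated assumptions, not the actual demucs.cpp code: `overall_progress` and `kStepsPerSegment` are hypothetical names, and every step is assumed to carry equal weight.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical constant: 8 encoders + 10 crosstransformer layers + 8 decoders.
constexpr std::size_t kStepsPerSegment = 26;

// Hypothetical helper: maps (segment, completed steps in that segment)
// to overall progress in [0, 1], weighting every step equally.
inline float overall_progress(std::size_t segment_index,
                              std::size_t steps_done,
                              std::size_t num_segments)
{
    assert(num_segments > 0);
    assert(segment_index < num_segments);
    assert(steps_done <= kStepsPerSegment);

    const std::size_t done = segment_index * kStepsPerSegment + steps_done;
    const std::size_t total = num_segments * kStepsPerSegment;
    return static_cast<float>(done) / static_cast<float>(total);
}
```

With the two segments in the run above, each step is worth 1/52, or about 1.923%, so the first segment ends at 50% and the second at 100%, which matches the logged values.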
I also removed most of the cout messages from the main src/ library (only left some in model_apply and model_load). The only cout deep in the actual inference layers now comes through the ProgressCallback.
Let me know how this works for you.
Works great, thanks! Setting the thread callbacks from outside the threaded inference function call would be useful.