
Changes required while using wav2vec2.0 (CLSRIL-23.pt) features for CTC/Attention-based training, about espnet (CLOSED)

mukherjeesougata commented on May 27, 2024
What changes need to be made while using wav2vec2.0 (CLSRIL-23.pt) features for CTC/Attention-based training?

Comments (6)

mukherjeesougata commented on May 27, 2024

It seems to be working; at least the first epoch of training has completed without any error after changing the input_size of the preencoder to 768, since CLSRIL-23 is a base model.
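
For reference, a minimal sketch of that change in the training config, assuming the linear preencoder (espnet2/asr/preencoder/linear.py) that appears in the tracebacks below; output_size: 80 matches the Linear layer shown there, and the rest of the recipe stays unchanged:

preencoder: linear
preencoder_conf:
    input_size: 768    # hidden size of the wav2vec2 base architecture used by CLSRIL-23
    output_size: 80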


sw005320 commented on May 27, 2024

@simpleoier, can you answer it for me?
We have several examples of using SSLs without S3PRL.


simpleoier commented on May 27, 2024

Hi @mukherjeesougata, ESPnet supports some fairseq-based models. From a quick look at this CLSRIL-23 model, it seems to be based on fairseq and to use the wav2vec2 architecture. You can probably try the following config, which uses the wav2vec2 encoder:

encoder: wav2vec2
encoder_conf:
    w2v_url: https://storage.googleapis.com/vakyansh-open-models/pretrained_models/clsril-23/CLSRIL-23.pt
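
A slightly fuller sketch of this encoder option is below; the extra parameter names (w2v_dir_path, output_size, freeze_finetune_updates) are assumptions based on ESPnet's FairSeqWav2Vec2Encoder and may differ across ESPnet versions:

encoder: wav2vec2
encoder_conf:
    w2v_url: https://storage.googleapis.com/vakyansh-open-models/pretrained_models/clsril-23/CLSRIL-23.pt
    w2v_dir_path: ./downloads           # assumed: local directory where the checkpoint is cached
    output_size: 256                    # assumed: output projection dimension fed to CTC/decoder
    freeze_finetune_updates: 10000      # assumed: keep the wav2vec2 weights frozen for the first N updates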

Using s3prl is also feasible.

frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: wav2vec2_url # Note: If the upstream is changed, please change the input_size in the preencoder.
        path_or_url: https://storage.googleapis.com/vakyansh-open-models/pretrained_models/clsril-23/CLSRIL-23.pt
    download_dir: ./hub


mukherjeesougata commented on May 27, 2024

I am trying to use CLSRIL-23.pt as a frontend and not as an encoder, so I have tried using s3prl.

frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: wav2vec2_url # Note: If the upstream is changed, please change the input_size in the preencoder.
        path_or_url: https://storage.googleapis.com/vakyansh-open-models/pretrained_models/clsril-23/CLSRIL-23.pt
    download_dir: ./hub

But I am getting the following error:

Traceback (most recent call last):
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/sougata/espnet/espnet2/bin/asr_train.py", line 23, in <module>
    main()
  File "/home/sougata/espnet/espnet2/bin/asr_train.py", line 19, in main
    ASRTask.main(cmd=cmd)
  File "/home/sougata/espnet/espnet2/tasks/abs_task.py", line 1117, in main
    cls.main_worker(args)
  File "/home/sougata/espnet/espnet2/tasks/abs_task.py", line 1227, in main_worker
    model = cls.build_model(args=args)
  File "/home/sougata/espnet/espnet2/tasks/asr.py", line 529, in build_model
    frontend = frontend_class(**args.frontend_conf)
  File "/home/sougata/espnet/espnet2/asr/frontend/s3prl.py", line 47, in __init__
    upstream = S3PRLUpstream(
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/s3prl/nn/upstream.py", line 117, in __init__
    self.upstream = getattr(hub, name)(**upstream_conf)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/s3prl/upstream/wav2vec2/hubconf.py", line 73, in wav2vec2_url
    return wav2vec2_custom(*args, **kwargs)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/s3prl/upstream/wav2vec2/hubconf.py", line 65, in wav2vec2_custom
    return _UpstreamExpert(ckpt, **kwargs)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/s3prl/upstream/interfaces.py", line 30, in __call__
    instance = super().__call__(*args, **kwargs)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/s3prl/upstream/wav2vec2/expert.py", line 23, in __init__
    model, task_cfg = load_converted_model(ckpt)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/s3prl/upstream/wav2vec2/convert.py", line 31, in load_converted_model
    raise ValueError(
ValueError: /home/sougata/espnet/egs2/clsril23Mini-Librispeech/asr1/hub/b6d981d1d795e6875d9d3f294c078def6ac231d27519c2cdf5723a1c4df8d90b.CLSRIL-23.pt is not a valid checkpoint since the required key: task_cfg is missing
# Accounting: time=88 threads=1
# Ended (code 1) at Wed Mar 20 13:52:38 IST 2024, elapsed time 88 seconds


simpleoier commented on May 27, 2024

Hi @mukherjeesougata, models trained with fairseq need to be converted. Sorry, I didn't recall this step.

You can refer to the s3prl conversion script here and convert the checkpoint:

python /home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/s3prl/upstream/wav2vec2/convert.py \
    /home/sougata/espnet/egs2/clsril23Mini-Librispeech/asr1/hub/b6d981d1d795e6875d9d3f294c078def6ac231d27519c2cdf5723a1c4df8d90b.CLSRIL-23.pt \
   --output_dir SOMEWHEREYOUWANTTOPUT
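
If useful, a rough sanity check on the converted file is to confirm that the key the earlier ValueError complained about is now present (the path below is a placeholder for whatever --output_dir you chose):

import torch

# The earlier error said the required key "task_cfg" was missing, so the
# converted checkpoint should now contain it.
ckpt = torch.load("SOMEWHEREYOUWANTTOPUT/CLSRIL-23.pt", map_location="cpu")
print(sorted(ckpt.keys()))
assert "task_cfg" in ckpt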

Then, in the config, set path_or_url to the converted checkpoint path.
This works in my environment, and the training starts.


mukherjeesougata commented on May 27, 2024

I have converted the checkpoint using this script and made the required changes, which are shown below:

frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: wav2vec2_local  # Note: If the upstream is changed, please change the input_size in the preencoder.
        path_or_url: /home/sougata/Pretrained_model/converted/CLSRIL-23/CLSRIL-23.pt

I didn't face any errors during the conversion. But during the training stage, i.e. after running the command

./asr.sh --stage 11 --stop_stage 11 --train_set train_nodev --valid_set train_dev --test_sets 'train_dev test' --ngpu 1 --asr_config conf/train_asr_conformer_scctc.yaml --feats_normalize uttmvn

I am facing the following error:

Traceback (most recent call last):
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/sougata/espnet/espnet2/bin/asr_train.py", line 23, in <module>
    main()
  File "/home/sougata/espnet/espnet2/bin/asr_train.py", line 19, in main
    ASRTask.main(cmd=cmd)
  File "/home/sougata/espnet/espnet2/tasks/abs_task.py", line 1117, in main
    cls.main_worker(args)
  File "/home/sougata/espnet/espnet2/tasks/abs_task.py", line 1430, in main_worker
    cls.trainer.run(
  File "/home/sougata/espnet/espnet2/train/trainer.py", line 304, in run
    all_steps_are_invalid = cls.train_one_epoch(
  File "/home/sougata/espnet/espnet2/train/trainer.py", line 588, in train_one_epoch
    retval = model(**batch)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sougata/espnet/espnet2/asr/espnet_model.py", line 237, in forward
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  File "/home/sougata/espnet/espnet2/asr/espnet_model.py", line 390, in encode
    feats, feats_lengths = self.preencoder(feats, feats_lengths)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sougata/espnet/espnet2/asr/preencoder/linear.py", line 31, in forward
    output = self.linear_out(self.dropout(input))
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sougata/anaconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (12152x768 and 1024x80)
# Accounting: time=12 threads=1
# Ended (code 1) at Wed Mar 20 21:41:41 IST 2024, elapsed time 12 seconds
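
In this error, mat1 holds the s3prl features (feature dimension 768) and mat2 comes from the preencoder's Linear(1024, 80) layer, i.e. the preencoder input_size in the config is still 1024 while CLSRIL-23 (a wav2vec2 base model) produces 768-dimensional features; setting input_size to 768, as described in the first comment of this thread, resolves it. To confirm the feature dimension directly, a rough sketch is below, assuming the S3PRLUpstream interface from s3prl.nn that appears in the tracebacks (argument and attribute names may vary across s3prl versions):

import torch
from s3prl.nn import S3PRLUpstream

# Load the converted checkpoint through the same upstream used in the config.
upstream = S3PRLUpstream(
    "wav2vec2_local",
    path_or_url="/home/sougata/Pretrained_model/converted/CLSRIL-23/CLSRIL-23.pt",
)
upstream.eval()

with torch.no_grad():
    wavs = torch.randn(1, 16000)         # one second of dummy audio at 16 kHz
    wavs_len = torch.LongTensor([16000])
    hidden_states, _ = upstream(wavs, wavs_len)

# Each element of hidden_states is (batch, frames, feature_dim); feature_dim
# should be 768 here, which is the value to use for the preencoder input_size.
print(hidden_states[-1].shape)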
