
baoy-nlp / cnat


Non-autoregressive Translation by Learning Target Categorical Codes

License: MIT License

Python 100.00%
pytorch non-autoregressive-translation conditional-random-fields vector-quantization

cnat's Issues

fairseq TypeError: can't convert cuda:1 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Hi, when training the model, it failed with the following error:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/search/CNAT/fairseq/fairseq/distributed_utils.py", line 270, in distributed_main
    main(args, **kwargs)
  File "/search/CNAT/train.py", line 112, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/search/CNAT/train.py", line 206, in train
    args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
  File "/search/CNAT/train.py", line 238, in validate_and_save
    valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
  File "/search/CNAT/train.py", line 295, in validate
    trainer.valid_step(sample)
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/search/CNAT/fairseq/fairseq/trainer.py", line 764, in valid_step
    logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
  File "/search/CNAT/fairseq/fairseq/trainer.py", line 1068, in _reduce_and_log_stats
    self.task.reduce_metrics(logging_outputs, self.get_criterion())
  File "/search/CNAT/fairseq/fairseq/tasks/translation.py", line 379, in reduce_metrics
    metrics.log_scalar("_bleu_counts", np.array(counts))
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/tensor.py", line 621, in __array__
    return self.numpy()
TypeError: can't convert cuda:1 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

The relevant code in fairseq:
https://github.com/pytorch/fairseq/blob/83e615d66905b8ca7483122a37da1a85f13f4b8e/fairseq/tasks/translation.py#L366

A newer version of fairseq fixes this bug:

https://github.com/pytorch/fairseq/blob/1a1380e5a8b0cce49090676d95044626f208c48c/fairseq/tasks/translation.py#L400
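The error message itself names the remedy: copy the tensor to host memory with `.cpu()` before NumPy converts it. Below is a minimal sketch of that idea for the BLEU statistics summed in `reduce_metrics`. It assumes the per-order keys are named `_bleu_counts_{i}` and `_bleu_totals_{i}` as in fairseq's translation task; the helper names `sum_logs_to_host` and `collect_bleu_stats` are hypothetical, and this is not the exact patch in either linked version.

```python
EVAL_BLEU_ORDER = 4  # assumed: BLEU n-gram orders 1..4


def sum_logs_to_host(logging_outputs, key):
    """Sum one metric over all logging outputs.

    If the sum is a torch tensor (it may live on a GPU such as cuda:1), it is
    copied to host memory with .cpu() so that np.array() can convert it.
    Duck typing on .cpu() keeps this sketch importable without torch.
    """
    result = sum(log.get(key, 0) for log in logging_outputs)
    if hasattr(result, "cpu"):
        result = result.cpu()
    return result


def collect_bleu_stats(logging_outputs):
    """Gather per-order BLEU counts and totals as host-side values."""
    counts = [sum_logs_to_host(logging_outputs, f"_bleu_counts_{i}")
              for i in range(EVAL_BLEU_ORDER)]
    totals = [sum_logs_to_host(logging_outputs, f"_bleu_totals_{i}")
              for i in range(EVAL_BLEU_ORDER)]
    return counts, totals
```

Because every value is already on the CPU when `collect_bleu_stats` returns, passing `counts` to `np.array(...)` no longer triggers the device error.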

invalid choice: 'nat'

Hi, thanks for your great work!
I tried to run train.py, but I got the following error:
train.py: error: argument --task: invalid choice: 'nat' (choose from 'audio_pretraining', 'cross_lingual_lm', 'denoising', 'language_modeling', 'legacy_masked_lm', 'masked_lm', 'multilingual_denoising', 'multilingual_masked_lm', 'translation', 'multilingual_translation', 'semisupervised_translation', 'sentence_prediction', 'sentence_ranking', 'speech_to_text', 'translation_from_pretrained_bart', 'translation_from_pretrained_xlm', 'translation_lev', 'translation_multi_simple_epoch', 'dummy_lm', 'dummy_masked_lm', 'dummy_mt')

Environment: PyTorch 1.8, fairseq 0.10.2
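One possible cause, not confirmed in this thread: fairseq passes the names in its task registry as the `choices` of `--task`, and a task defined outside fairseq (such as CNAT's `nat` task in `latent_nat/nat_task.py`) is only registered once its module is imported, which is what `--user-dir` does in the training command quoted in the `KeyError: 'prior_ret'` issue. The stdlib-only sketch below models that mechanism with a hypothetical registry dict; `register_task`, `build_parser`, and `try_parse` here are illustrative stand-ins, not fairseq's own functions.

```python
import argparse


def register_task(registry, name):
    """Decorator that records a task class under `name` (illustrative only)."""
    def wrapper(cls):
        registry[name] = cls
        return cls
    return wrapper


def build_parser(registry):
    """Build a parser whose --task choices are the currently registered names."""
    parser = argparse.ArgumentParser(prog="train.py")
    # The choices are read here, so a task that has not been registered yet
    # is rejected as an invalid choice.
    parser.add_argument("--task", choices=sorted(registry))
    return parser


def try_parse(argv, registry):
    """Return the parsed --task value, or None if argparse rejects it."""
    try:
        return build_parser(registry).parse_args(argv).task
    except SystemExit:  # argparse exits after printing "invalid choice: ..."
        return None
```

With a registry holding only built-in names, `try_parse(["--task", "nat"], registry)` returns `None` and argparse prints an invalid-choice error much like the one above; after `register_task(registry, "nat")` is applied to a task class (the step importing the user directory would perform), the same call returns `"nat"`.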

KeyError: 'prior_ret'

Hi, when training CNAT, I get the following error:

if "VQ" in inner_states[GlobalNames.PRI_RET]:
KeyError: 'prior_ret'

Details:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/search/CNAT/fairseq/fairseq/distributed_utils.py", line 270, in distributed_main
    main(args, **kwargs)
  File "/search/CNAT/train.py", line 112, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/search/CNAT/train.py", line 190, in train
    log_output = trainer.train_step(samples)
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/search/CNAT/fairseq/fairseq/trainer.py", line 486, in train_step
    ignore_grad=is_dummy_batch,
  File "/search/CNAT/latent_nat/nat_task.py", line 154, in train_step
    return super().train_step(sample, model, criterion, optimizer, update_num, ignore_grad)
  File "/search/CNAT/fairseq/fairseq/tasks/translation_lev.py", line 178, in train_step
    loss, sample_size, logging_output = criterion(model, sample)
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/search/CNAT/latent_nat/awesome_nat_loss.py", line 55, in forward
    outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/search/CNAT/fairseq/fairseq/legacy_distributed_data_parallel.py", line 85, in forward
    return self.module(*inputs, **kwargs)
  File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/search/CNAT/latent_nat/cnat.py", line 46, in forward
    losses.update(self._compute_vq_loss(inner_states))
  File "/search/CNAT/latent_nat/cnat.py", line 57, in _compute_vq_loss
    if "VQ" in inner_states[GlobalNames.PRI_RET]:
KeyError: 'prior_ret'

My full training command is:
python3 train.py ${DATA_BIN} \
--user-dir ${USER_DIR} \
--save-dir $CHECKPOINT \
--ddp-backend=no_c10d \
--task nat \
--criterion awesome_nat_loss \
--arch cnat_wmt14 \
--self-attn-cls shaw \
--block-cls highway \
--max-rel-positions 4 \
--enc-self-attn-cls shaw \
--enc-block-cls highway \
--share-rel-embeddings \
--share-decoder-input-output-embed \
--mapping-func interpolate \
--mapping-use output \
--noise full_mask \
--apply-bert-init \
--optimizer adam \
--lr 0.0007 \
--lr-scheduler inverse_sqrt \
--warmup-updates 10000 \
--warmup-init-lr 1e-07 \
--min-lr 1e-09 \
--weight-decay 0.0 \
--dropout 0.1 \
--encoder-learned-pos \
--decoder-learned-pos \
--pred-length-offset \
--length-loss-factor 0.1 \
--label-smoothing 0.0 \
--log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 4096 \
--update-freq 1 \
--save-interval-updates 500 \
--keep-best-checkpoints 5 \
--no-epoch-checkpoints \
--keep-interval-updates 5 \
--max-update 300000 \
--num-workers 0 \
--eval-bleu \
--eval-bleu-detok moses \
--eval-bleu-remove-bpe \
--best-checkpoint-metric bleu \
--maximize-best-checkpoint-metric \
--iter-decode-max-iter 0 \
--iter-decode-eos-penalty 0 \
--left-pad-source False \
--batch-size-valid 128 \
--latent-factor 0.5 \
--num-codes 64 \
--vq-ema \
--crf-cls BCRF \
--crf-num-head 4 \
--latent-layers 5 \
--vq-schedule-ratio 0.5 \
--find-unused-parameters
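The failing check in `cnat.py` indexes `inner_states` with `GlobalNames.PRI_RET` directly, so it raises a bare `KeyError: 'prior_ret'` whenever the forward pass did not store an entry under that key. While debugging, a variant that reports which keys *are* present can help narrow down which code path skipped it. A minimal, hypothetical sketch (the name `has_vq_prior` and the constant `PRI_RET` are illustrative, not CNAT's code, and this does not fix the underlying cause):

```python
PRI_RET = "prior_ret"  # assumed value of GlobalNames.PRI_RET, taken from the KeyError


def has_vq_prior(inner_states):
    """Hypothetical, more informative version of the failing check in cnat.py.

    Returns True if the prior's outputs contain a "VQ" entry. If no prior
    outputs were stored, raise a KeyError that also lists the keys present.
    """
    if PRI_RET not in inner_states:
        present = ", ".join(sorted(map(str, inner_states))) or "<none>"
        raise KeyError(f"{PRI_RET!r} missing from inner_states; present keys: {present}")
    return "VQ" in inner_states[PRI_RET]
```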

Question about data preprocessing

Hi.

You mentioned that:

IWSLT14 German-English & WMT14 English-German: we mostly follow the instruction of the Fairseq.

So I assume you used this script, and I noticed that it has an --icml17 flag. Did you use this flag when you preprocessed the WMT14 English-German dataset?
Sorry for the trivial question; I just want to make sure my data preprocessing matches yours.
