tunz / transformer-pytorch
Transformer implementation in PyTorch.
Home Page: https://tunz.kr/post/4
License: MIT License
Hello! This is very nice work!
I was looking through your code and couldn't figure out why only the first training file (data_paths[0]) is loaded at line 146 of translation.py:
examples_train = torch.load(data_paths[0])
Thanks!
Hello, thanks for your work.
No matter what I input, I always get the sentence "i would like to thank you for your comments . <eos>".
What should I do?
I downloaded the two models from GitHub, but I don't know which directory to put them in. I get this error:
Traceback (most recent call last):
  File "/home/liying/transformer-pytorch-master/train.py", line 213, in <module>
    main()
  File "/home/liying/transformer-pytorch-master/train.py", line 148, in main
    problem.prepare(opt.problem, opt.data_dir, opt.max_length,
  File "/home/liying/transformer-pytorch-master/dataset/problem.py", line 10, in prepare
    from dataset import translation
  File "/home/liying/transformer-pytorch-master/dataset/translation.py", line 19, in <module>
    spacy_de = spacy.load('de_core_news_sm')
  File "/opt/conda/lib/python3.10/site-packages/spacy/__init__.py", line 51, in load
    return util.load_model(
  File "/opt/conda/lib/python3.10/site-packages/spacy/util.py", line 472, in load_model
    raise IOError(Errors.E050.format(name=name))
OSError: [E050] Can't find model 'de_core_news_sm'. It doesn't seem to be a Python package or a valid path to a data directory.
How can I fix this?
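spaCy pipelines such as de_core_news_sm are normally installed as Python packages rather than copied into a directory of this repository; the usual command is `python -m spacy download de_core_news_sm` (and likewise for any other spaCy model that translation.py loads). As the error message itself says, spacy.load also accepts a path to an unpacked model directory. A small sketch with hypothetical helper names that builds the download command for the running interpreter and turns the E050 error into an actionable hint:

```python
import sys


def spacy_download_command(model_name):
    """Command that installs a spaCy pipeline package for this interpreter.

    Same as running `python -m spacy download <model_name>` in a shell.
    """
    return [sys.executable, "-m", "spacy", "download", model_name]


def load_spacy_model(model_name):
    """Load a spaCy pipeline, adding an install hint if it is missing."""
    import spacy  # imported here so the helper above works without spaCy

    try:
        return spacy.load(model_name)
    except OSError as err:  # spaCy raises OSError (E050) for unknown models
        hint = " ".join(spacy_download_command(model_name))
        raise OSError(
            f"spaCy model {model_name!r} not found; install it with: {hint}"
        ) from err
```

Usage would be `spacy_de = load_spacy_model('de_core_news_sm')` in place of the direct spacy.load call.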
I am using torchtext 0.2.3 and get this error:
  File "/home/dm/ldl/transformer-pytorch/dataset/translation.py", line 173, in prepare
    pad_token=pad, lower=True, eos_token='')
TypeError: __init__() got an unexpected keyword argument 'is_target'
Maybe the version difference causes it.
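The TypeError means that the `data.Field` class in the installed torchtext does not accept the `is_target` keyword, so the repository appears to expect a newer torchtext release than 0.2.3; installing a version whose `Field` takes `is_target` is probably the simpler fix. To check which keywords the installed class rejects, a small helper based on `inspect.signature` can be used (the name `unsupported_kwargs` is hypothetical):

```python
import inspect


def unsupported_kwargs(func, **kwargs):
    """Return the keyword names in kwargs that `func` would reject."""
    params = list(inspect.signature(func).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return []  # func takes **kwargs, so any keyword is accepted
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return [name for name in kwargs if name not in accepted]
```

For example, `unsupported_kwargs(data.Field, is_target=True, lower=True)` after `from torchtext import data`; a non-empty result means the installed torchtext is too old for that call. Passing a class works because `inspect.signature` reports its `__init__` parameters without `self`.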
Hi, thank you for sharing the implementation!
I was just wondering if you could explain the loss computation where you use confidence and label smoothing. I know this is also done in the tensor2tensor repo, but I have a hard time understanding the concept there as well. I was reading up on Normalized Cross Entropy here, but neither the formula you use nor the one in tensor2tensor seems to match the formula in that article. Could you elaborate on how that formula is implemented?
Also, since we take both the correct and the incorrect classes into account, is this different from normal cross entropy? With standard cross entropy, the true class has probability 1 and every incorrect class has probability 0, so the loss depends only on the predicted probability of the true class; it doesn't matter how the remaining probability mass is distributed over the incorrect classes. With the NCE formula from that article, however, we have to take the incorrect classes into account as well, right?
import math

import torch
import torch.nn.functional as F


def get_loss(pred, ans, vocab_size, label_smoothing, pad):
    # pred: [N, vocab_size] logits, ans: [N] gold token indices.
    # took this "normalizing" from tensor2tensor. We subtract it for
    # readability. This makes no difference on learning.
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / float(vocab_size - 1)
    normalizing = -(
        confidence * math.log(confidence) + float(vocab_size - 1) *
        low_confidence * math.log(low_confidence + 1e-20))

    # Smoothed target: `confidence` on the gold index, `low_confidence` elsewhere.
    one_hot = torch.zeros_like(pred).scatter_(1, ans.unsqueeze(1), 1)
    one_hot = one_hot * confidence + (1 - one_hot) * low_confidence

    log_prob = F.log_softmax(pred, dim=1)
    xent = -(one_hot * log_prob).sum(dim=1)  # cross entropy per position
    xent = xent.masked_select(ans != pad)    # drop padding positions
    loss = (xent - normalizing).mean()
    return loss
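Reading get_loss, the tensor2tensor `normalizing` term appears to be the entropy of the smoothed target. Write V = vocab_size, ε = label_smoothing and q for the target built in `one_hot`: q_i = 1 - ε for the gold token and ε/(V - 1) for every other token. Then xent = -Σ_i q_i log p_i and normalizing = -Σ_i q_i log q_i = H(q), so xent - normalizing = Σ_i q_i log(q_i / p_i) = KL(q ‖ p). This is ≥ 0, is 0 only when p = q, and has the same gradient as xent because H(q) does not depend on the model. It also bears on the second question: unlike one-hot cross entropy, every log p_i carries weight ε/(V - 1), so how probability is spread over incorrect tokens does affect the loss. A minimal sketch of the same loss written with torch.nn.functional.kl_div (`label_smoothed_kl` is a hypothetical name; it assumes 2-D [N, V] logits and 0 < label_smoothing < 1):

```python
import torch
import torch.nn.functional as F


def label_smoothed_kl(pred, ans, vocab_size, label_smoothing, pad):
    """Mean KL(q || softmax(pred)) over non-pad positions.

    q is the label-smoothed target; assumes 0 < label_smoothing < 1.
    """
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab_size - 1)
    # Smoothed target distribution, one row per position.
    target = torch.full_like(pred, low_confidence)
    target.scatter_(1, ans.unsqueeze(1), confidence)
    log_prob = F.log_softmax(pred, dim=1)
    # kl_div takes log-probabilities as input and probabilities as target;
    # reduction="none" keeps per-element terms so we can sum over the vocab.
    kl = F.kl_div(log_prob, target, reduction="none").sum(dim=1)
    return kl.masked_select(ans != pad).mean()
```

Up to the 1e-20 guard inside the log, this should match get_loss, which is one way to convince yourself that the subtraction turns the cross entropy into a KL divergence.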
I'm training a chatbot with around 150k words in my vocabulary. In the first iterations each log_softmax entry is about -11, and the sum at each sentence position is around 3000 when I compute xent = -(one_hot * log_prob).sum(dim=1), so the average loss over all predictions is also around 3000. Does this sound reasonable? A loss of 3000 seems way too high.
Thanks in advance.
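As a rough sanity check: the smoothed target sums to 1 (confidence + (V - 1) · low_confidence = 1), so while the model's softmax is still close to uniform, every log_prob entry is about -ln V and -(one_hot * log_prob).sum(dim=1) should be about ln V per position, not many times that. For V = 150000 this is roughly 11.9, and after subtracting `normalizing` with label_smoothing = 0.1 (an example value) roughly 10.4. A per-position sum near 3000 therefore suggests that the weights multiplied with log_prob do not sum to 1 over the vocabulary, for example if the sum runs over a dimension other than the vocabulary one; that is a guess, since the training code isn't shown. A sketch of the estimate (`initial_loss_estimate` is a hypothetical helper; it assumes 0 <= label_smoothing < 1):

```python
import math


def initial_loss_estimate(vocab_size, label_smoothing):
    """Expected (xent, xent - normalizing) for uniform predictions.

    With uniform predictions every log-probability is -ln(vocab_size); since
    the smoothed target sums to 1, the cross entropy is ln(vocab_size).
    The second value subtracts the target entropy, as get_loss does.
    """
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab_size - 1)
    xent = math.log(vocab_size)
    entropy = -confidence * math.log(confidence)
    if label_smoothing > 0:  # 0 * log(0) is taken as 0
        entropy -= (vocab_size - 1) * low_confidence * math.log(low_confidence)
    return xent, xent - entropy
```

Calling `initial_loss_estimate(150000, 0.1)` gives the two reference values quoted above to compare against the first logged losses.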
python3 train.py --problem wmt32k --output_dir ./output --data_dir ./wmt32k_data
Traceback (most recent call last):
  File "train.py", line 211, in <module>
    main()
  File "train.py", line 156, in main
    opt.batch_size, device, opt)
  File "/home/vivien/Bureau/transformer-pytorch/dataset/problem.py", line 12, in prepare
    translation.prepare(max_length, batch_size, device, opt, data_dir)
  File "/home/vivien/Bureau/transformer-pytorch/dataset/translation.py", line 157, in prepare
    pad_token=pad, lower=True, eos_token='')
TypeError: __init__() got an unexpected keyword argument 'is_target'
transformer-pytorch/model/transformer.py, line 87 (commit e726667)
I noticed that, according to the paper, the dot products of the queries and keys are scaled by 1/√d_k before being passed to the softmax. I don't see this in the code; did I miss anything? Thank you!
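In the paper the attention scores are QKᵀ/√d_k (the square root of d_k, not d_k itself). Because the matrix product is linear, the factor can equally be applied to the queries before the product, so an implementation may scale there instead of right before the softmax. Line 87 isn't quoted here, so whether this code does so is only a hypothesis to check. A sketch (function names are illustrative) showing that both placements give the same output:

```python
import math

import torch


def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for tensors shaped [..., seq, d_k]."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def attention_prescaled_query(q, k, v):
    """Same result, with the 1/sqrt(d_k) factor folded into the queries."""
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q * scale, k.transpose(-2, -1))
    return torch.matmul(torch.softmax(scores, dim=-1), v)
```

So when reading the code, look for a multiplication of the query projection by d_k ** -0.5 as well as for a division of the scores.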
transformer-pytorch/model/transformer.py, line 282 (commit 5cf29c0)
I noticed that you slice along the second dimension but pad along the first dimension (padding at the top).
For example, if target_embedded is 100x100, after this operation its size is 101x99.
I am a little confused.
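Line 282 isn't quoted here, so this assumes it is the common "shift right" idiom F.pad(x, (0, 0, 1, 0))[:, :-1] applied to a [batch, seq_len, hidden] tensor. F.pad reads its pad tuple starting from the last dimension: (0, 0) leaves the hidden dimension alone, and (1, 0) adds one zero step at the start of dimension -2, the sequence dimension. With a batch dimension present, [:, :-1] slices that same dimension, so the shape is unchanged and every position moves one step to the right. The 101x99 result appears only when the same expression is applied to a 2-D tensor without a batch dimension. A sketch (`shift_right` is an illustrative name):

```python
import torch
import torch.nn.functional as F


def shift_right(x):
    """Shift a [batch, seq_len, hidden] tensor one step right along seq_len.

    F.pad's tuple starts at the LAST dimension: (0, 0) pads hidden by
    nothing, (1, 0) adds one zero step before the sequence dimension.
    """
    padded = F.pad(x, (0, 0, 1, 0))  # [batch, seq_len + 1, hidden]
    return padded[:, :-1]            # drop last step -> [batch, seq_len, hidden]
```

The first decoder position then sees a zero vector and position t sees the embedding of token t - 1, which is what teacher forcing needs.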
  File "decoder.py", line 99, in main
    t_self_mask = utils.create_trg_self_mask(targets.size()[1], device=targets.device)
UnboundLocalError: local variable 'targets' referenced before assignment
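decoder.py isn't shown here, so this is only a likely cause: Python raises UnboundLocalError when a function reads a local variable on a path where it was never assigned, for example when `targets` is set only inside a loop or an if branch that did not run (such as on empty input). A minimal hypothetical reproduction and a guarded version:

```python
def last_batch(batches):
    """Reproduces the error: `targets` is bound only inside the loop."""
    for batch in batches:
        targets = batch
    return targets  # raises UnboundLocalError when `batches` is empty


def last_batch_checked(batches):
    """Initializes `targets` first and fails with a clear message instead."""
    targets = None
    for batch in batches:
        targets = batch
    if targets is None:
        raise ValueError("no input was read, so `targets` was never set")
    return targets
```

Checking which branch in main assigns `targets` before line 99 should show whether this is what happens.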