bilstm_attn's Introduction
bilstm_attn's People
Forkers
liweiweng jonecherry jecktion jiangchenglin521 whatyouknow123 bxy123456 li1tian fangego tiaotiaosong swirlingcloud li-study jeniyat bangguowei zhaoyuanyuan2011 yunpengfu yuanjianggit whs1111 yanghaihuo jasdasdf linhongxiang zqf0722 aduoge lxy1993 chihuanbin hdu-jimlau qianrenjian lichao88 andrealee kdongyi zoe-zhanghan xiaolinpeter amygo123 mmountains 16xhli22 1895-art qzqzzw lxm0708 canye520 yycsu aiedward jiajunqiu hhucs emperorlu wonqiao ningshiqi baixue1 github-sci andrew2019github ztz695183179 hlyu-hit travel678 caojiabao xxd0626 lxy613
bilstm_attn's Issues
Which version of PyTorch does this code use?
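Some things I believe are bugs
In model.py:
1. You cannot directly reshape([-1, self.sequence_length]). You should reshape([self.sequence_length, -1]) and then swap the dimensions with permute, or swap the dimensions at the very start. For why this bug has not shown any symptoms, see point 2.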
def attention_net(self, lstm_output):
    #print(lstm_output.size()) = (sequence_length, batch_size, hidden_size*layer_size)
    output_reshape = torch.Tensor.reshape(lstm_output, [-1, self.hidden_size*self.layer_size])
    #print(output_reshape.size()) = (sequence_length * batch_size, hidden_size*layer_size)
    attn_tanh = torch.tanh(torch.mm(output_reshape, self.w_omega))
    #print(attn_tanh.size()) = (sequence_length * batch_size, attention_size)
    attn_hidden_layer = torch.mm(attn_tanh, torch.Tensor.reshape(self.u_omega, [-1, 1]))
    #print(attn_hidden_layer.size()) = (sequence_length * batch_size, 1)
    # Bug: a direct reshape([-1, self.sequence_length]) is wrong here. Use
    # reshape([self.sequence_length, -1]) followed by permute, or swap the
    # dimensions at the start. The lines below contain similar errors and
    # need matching changes.
    exps = torch.Tensor.reshape(torch.exp(attn_hidden_layer), [-1, self.sequence_length])
    #print(exps.size()) = (batch_size, sequence_length)
    alphas = exps / torch.Tensor.reshape(torch.sum(exps, 1), [-1, 1])
    #print(alphas.size()) = (batch_size, sequence_length)
    alphas_reshape = torch.Tensor.reshape(alphas, [-1, self.sequence_length, 1])
    #print(alphas_reshape.size()) = (batch_size, sequence_length, 1)
    state = lstm_output.permute(1, 0, 2)
    #print(state.size()) = (batch_size, sequence_length, hidden_size*layer_size)
    attn_output = torch.sum(state * alphas_reshape, 1)  # alphas_reshape is always 0.0625
    #print(attn_output.size()) = (batch_size, hidden_size*layer_size)
    return attn_output
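2. The types of u_omega and w_omega are wrong: they should be Parameter objects, and they should not be initialized to 0. As written, u_omega and w_omega stay at 0 forever because the optimizer never updates them, so every attention score is exp(0) = 1 and alphas_reshape is always the uniform value 0.0625.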
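The two points above can be illustrated with a minimal sketch. It is not the repository's code: the class name `AttentionSketch`, the argument `hidden_dim` (standing in for hidden_size*layer_size), the 0.1 scaling of the random initial values, and the sequence length of 16 (inferred from 0.0625 = 1/16) are all assumptions made for illustration. The sketch registers the weights as `nn.Parameter`, swaps the dimensions before computing scores so that no reshape across the batch axis is needed, and then shows on a tiny tensor why the direct reshape mixes batch items.

```python
import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Hypothetical stand-in for the attention step discussed in this issue."""

    def __init__(self, hidden_dim, attention_size):
        super().__init__()
        # nn.Parameter registers the tensors with the module, so
        # model.parameters() passes them to the optimizer.
        # Small random values (an assumed choice) replace the zero init.
        self.w_omega = nn.Parameter(torch.randn(hidden_dim, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: (sequence_length, batch_size, hidden_dim)
        state = lstm_output.permute(1, 0, 2)            # (batch, seq, hidden)
        attn_tanh = torch.tanh(torch.matmul(state, self.w_omega))  # (batch, seq, attn)
        scores = torch.matmul(attn_tanh, self.u_omega)  # (batch, seq)
        alphas = torch.softmax(scores, dim=1)           # normalize over time steps
        attn_output = torch.sum(state * alphas.unsqueeze(-1), dim=1)  # (batch, hidden)
        return attn_output, alphas


# Why the direct reshape is wrong: after flattening a (seq, batch, ...) tensor,
# row index t * batch + b holds time step t of batch item b.
seq_len, batch = 3, 2
flat = torch.arange(seq_len * batch)
wrong = flat.reshape(-1, seq_len)                # each row mixes batch items
right = flat.reshape(seq_len, -1).permute(1, 0)  # each row is one batch item

# With all-zero scores the softmax is uniform over the (assumed) 16 time steps.
uniform = torch.softmax(torch.zeros(1, 16), dim=1)
```

Because the weights are uniform whenever the scores are all zero, the mis-ordered reshape has no visible effect until w_omega and u_omega actually train, which is consistent with the remark in point 1.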
Bugs in the code in main.py
When I tried to run your code, I found some bugs.
The bugs are in main.py:
In the train function, the return value should be total_loss / training_data.sents_size, not total_loss[0] / training_data.sents_size.
Similarly, in the evaluate function, the return value should be eval_loss / _size, not eval_loss[0].
My environment:
Python 3.6
PyTorch 1.1.0
torchvision 0.3.0
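A likely reason for the report above is that a loss returned by a PyTorch criterion is a 0-dim tensor in PyTorch 1.1.0 and later, so indexing it with [0] fails. The sketch below shows one common way to accumulate such a loss as a plain Python float using `.item()`. How main.py actually accumulates total_loss is not shown in the issue, so the variable names, the toy batch, and the divisor here are assumptions.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Toy batch (hypothetical): 4 examples, 3 classes.
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])

loss = criterion(logits, targets)  # 0-dim tensor, so loss[0] is not valid

total_loss = 0.0
total_loss += loss.item()          # .item() converts the 0-dim tensor to a float

num_examples = 4                   # stands in for training_data.sents_size
average_loss = total_loss / num_examples
```

Accumulating floats rather than tensors also avoids keeping the autograd graph of every batch alive across the epoch.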
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)
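The error message says that the embedding lookup received 32-bit integer indices where it expected 64-bit ones. The issue does not show the calling code, so the following is only a sketch of one possible workaround under that assumption: cast the index tensor with `.long()` before it reaches the embedding layer. The layer sizes and index values are made up for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical embedding layer: vocabulary of 10 tokens, 4-dim vectors.
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)

# Indices stored as 32-bit integers, i.e. a torch.IntTensor.
indices = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)

# Cast to 64-bit integers (torch.LongTensor) before the lookup.
vectors = embedding(indices.long())
```

If the indices come from a NumPy array, converting the array to int64 before building the tensor would have the same effect.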