Comments (9)
- When I print `loss` after line 175, I get the following output:
Variable containing:
92.9371
[torch.cuda.FloatTensor of size 1 (GPU 0)]
I guess you are using PyTorch 0.4.0, but this code was written for 0.2.0. You can install the matching PyTorch version and try again.
The `to_scalar` function is only used to record the current loss and print it at https://github.com/ZhixiuYe/HSCRF-pytorch/blob/master/train.py#L199, so it is fine to delete line 177.
from hscrf-pytorch.
Unfortunately, the rest of the code I am using is pytorch 0.4.0, so I can't mix and match and need to port the SCRF code. For some guidance, could you print some more sizes so I can see what I need to fix?
My tag size is 4 since I am not doing NER, and am trying to only get binary labels, so my tags are (no tag, tag, start, end). Using test data with batch size 10, I am getting the following sizes:
def forward(self, feats, mask_word, tags, mask_tag):
    self.batch_size = feats.size(0)
    self.sent_len = feats.size(1)
    # feats: (10 x 48 x 256)
    # mask_words (10)
    # tags: (10, 40, 4)
    # mask_tag: (10, 40)
    feats = self.dense(feats)
    self.SCRF_scores = self.HSCRF_scores(feats)
    # self.SCRF_scores: (10, 48, 48, 4, 4)
    forward_score = self.get_logloss_denominator(self.SCRF_scores, mask_word)
    # forward_score: (1)
    numerator = self.get_logloss_numerator(tags, self.SCRF_scores, mask_tag)
    # numerator: (209)
    loss = (forward_score - numerator.sum()) / self.batch_size
    # loss: (209)
    return loss
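Read as a formula, the last three statements of `forward` compute a batch-averaged negative log-likelihood. The notation below is an assumption made for illustration and is not taken from the repository: $B$ is the batch size, $\log Z(x_b)$ is the per-sentence quantity that `get_logloss_denominator` sums over the batch, and $\phi(x_b, s)$ is one gold entry gathered from `SCRF_scores` by `get_logloss_numerator`.

```latex
\mathcal{L} \;=\; \frac{1}{B}\left( \sum_{b=1}^{B} \log Z(x_b) \;-\; \sum_{b=1}^{B} \sum_{s \in \mathrm{gold}(b)} \phi(x_b, s) \right)
```

Both sums reduce to a scalar, so the resulting loss should hold a single element.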
Here are the two functions annotated:
def get_logloss_numerator(self, goldfactors, scores, mask):
    # mask: (10, 40)
    batch_size = scores.size(0)  # 10
    sent_len = scores.size(1)  # 48
    tagset_size = scores.size(3)  # 4
    goldfactors = goldfactors[:, :, 0]*sent_len*tagset_size*tagset_size + goldfactors[:,:,1]*tagset_size*tagset_size+goldfactors[:,:,2]*tagset_size+goldfactors[:,:,3]
    # goldfactors: (10, 40)
    factorexprs = scores.view(batch_size, -1)
    # factorexprs: (10, 36864)
    val = torch.gather(factorexprs, 1, goldfactors)
    # val: (10, 40)
    numerator = val.masked_select(mask)
    # numerator: (209)
    return numerator
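The `goldfactors` line in `get_logloss_numerator` turns four indices into one offset into the flattened `scores` tensor, which is why `torch.gather` can pick the gold entries from `factorexprs`. A minimal pure-Python sketch of that arithmetic follows; the helper names `flat_index` and `unflatten` are hypothetical and only illustrate the row-major layout that `scores.view(batch_size, -1)` relies on.

```python
def flat_index(i, j, a, b, sent_len, tagset_size):
    """Row-major offset of entry [i, j, a, b] in a
    (sent_len, sent_len, tagset_size, tagset_size) array."""
    return ((i * sent_len + j) * tagset_size + a) * tagset_size + b


def unflatten(k, sent_len, tagset_size):
    """Inverse of flat_index: recover (i, j, a, b) from the offset."""
    k, b = divmod(k, tagset_size)
    k, a = divmod(k, tagset_size)
    i, j = divmod(k, sent_len)
    return i, j, a, b


# Sizes reported in the thread: sent_len = 48, tagset_size = 4.
sent_len, tagset_size = 48, 4
total = sent_len * sent_len * tagset_size * tagset_size
last = flat_index(47, 47, 3, 3, sent_len, tagset_size)
```

With these sizes `total` equals 36864, matching the second dimension of `factorexprs`, and the largest valid offset is `total - 1`.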
def get_logloss_denominator(self, scores, mask):
    logalpha = Variable(torch.FloatTensor(self.batch_size, self.sent_len+1, self.tagset_size).fill_(-10000.)).cuda()
    # logalpha: (10, 49, 4)
    logalpha[:, 0, self.start_id] = 0.
    istarts = [0] * self.ALLOWED_SPANLEN + range(self.sent_len - self.ALLOWED_SPANLEN+1)
    # len(istarts): 49
    for i in range(1, self.sent_len+1):
        tmp = scores[:, istarts[i]:i, i-1] + \
            logalpha[:, istarts[i]:i].unsqueeze(3).expand(self.batch_size, i - istarts[i], self.tagset_size, self.tagset_size)
        tmp = tmp.transpose(1, 3).contiguous().view(self.batch_size, self.tagset_size, (i-istarts[i])*self.tagset_size)
        max_tmp, _ = torch.max(tmp, dim=2)
        tmp = tmp - max_tmp.view(self.batch_size, self.tagset_size, 1)
        logalpha[:, i] = max_tmp + torch.log(torch.sum(torch.exp(tmp), dim=2))
    mask = mask.unsqueeze(1).unsqueeze(1).expand(self.batch_size, 1, self.tagset_size)
    # mask: (10,1,4)
    alpha = torch.gather(logalpha, 1, mask).squeeze(1)
    # alpha: (10,4)
    return alpha[:,self.stop_id].sum()  # return: (1)
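Inside the loop of `get_logloss_denominator`, the `max_tmp` lines subtract the maximum before exponentiating and add it back afterwards. This is the usual log-sum-exp stabilisation. The sketch below shows the same idea on a plain list of floats; the function name `logsumexp` is chosen here for illustration and is not a helper from the repository.

```python
import math


def logsumexp(values):
    """Compute log(sum(exp(v))) without overflow or underflow.

    Subtracting the maximum keeps every exponent at or below zero, so
    math.exp never overflows and at least one term equals exp(0) = 1.
    """
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# -10000.0 mirrors the fill value used for unreachable states in logalpha.
both_masked = logsumexp([-10000.0, -10000.0])
one_masked = logsumexp([-10000.0, 0.0])
```

Evaluating `math.log(math.exp(-10000.0) + math.exp(-10000.0))` directly would fail, because both exponentials underflow to zero; the shifted form avoids that.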
=======================><=========================
Edit: As it turns out, I summed the wrong tensor; the sizes are all correct. I am now getting a lot of `leaf variable has been moved into the graph interior` errors, caused by the indexing and in-place overwriting of values in these functions. Did you encounter these errors when you built the model? How did you address them?
from hscrf-pytorch.
- In the following code, `loss` should obviously have size 1.
forward_score = self.get_logloss_denominator(self.SCRF_scores, mask_word)
# forward_score: (1)
numerator = self.get_logloss_numerator(tags, self.SCRF_scores, mask_tag)
# numerator: (209)
loss = (forward_score - numerator.sum()) / self.batch_size
# loss: (209)
- As for `leaf variable has been moved into the graph interior`: I guess this is because in PyTorch 0.4.0 the `Variable` class has been merged into `Tensor`. I am not very familiar with PyTorch 0.4.0, so I don't know the details.
from hscrf-pytorch.
I managed to refactor this into `torch.cat` operations, so the error is resolved. I now run into a problem that I can't quite understand from your code: your `HSCRF_scores` function only computes the likelihoods for positive labels, but keeps O/start/end at -1e5 (by setting them to values in the `m30000` tensor). Where in your SCRF code do you actually compute the probability that a tag is O?
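For readers porting the same code, the refactor mentioned above (replacing in-place writes into `logalpha` with concatenation) might look roughly like the sketch below. It is not the poster's actual change. It assumes a PyTorch release that provides `torch.logsumexp`, it assumes `scores` is indexed as `[batch, span_start, span_end, previous_tag, current_tag]`, and the function name `forward_logalpha` and the parameter `max_span` (standing in for `ALLOWED_SPANLEN`) are hypothetical.

```python
import torch


def forward_logalpha(scores, start_id, max_span):
    """Forward recursion that never writes into a preallocated tensor.

    scores: (batch, sent_len, sent_len, tags, tags), assumed to be indexed
            as [batch, span_start, span_end, previous_tag, current_tag].
    Returns log-alpha with shape (batch, sent_len + 1, tags).
    """
    batch, sent_len, _, tags, _ = scores.shape
    # Position 0: only the start tag is reachable. This tensor is created
    # outside the autograd graph, so filling it in place is harmless.
    init = scores.new_full((batch, tags), -10000.0)
    init[:, start_id] = 0.0
    rows = [init]
    for i in range(1, sent_len + 1):
        lo = max(0, i - max_span)                # earliest allowed span start
        prev = torch.stack(rows[lo:i], dim=1)    # (batch, n, tags)
        tmp = scores[:, lo:i, i - 1] + prev.unsqueeze(3)  # (batch, n, prev, cur)
        # Merge span start and previous tag, then log-sum-exp over both.
        rows.append(torch.logsumexp(tmp.reshape(batch, -1, tags), dim=1))
    return torch.stack(rows, dim=1)


scores = torch.randn(2, 5, 5, 3, 3, requires_grad=True)
logalpha = forward_logalpha(scores, start_id=0, max_span=3)
logalpha[:, -1].sum().backward()  # gradients flow back to `scores`
```

Each step appends a new row to a Python list and the rows are stacked once at the end, so no tensor that takes part in the graph is modified after it has been used.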
from hscrf-pytorch.
First of all, you can refer to the paper "Semi-Markov Conditional Random Fields for Information Extraction" for some details about semi-Markov CRFs.
Actually, `HSCRF_scores` computes `scores`, whose shape is `(self.batch_size, self.sent_len, self.sent_len, self.tagset_size, self.tagset_size)`. These values correspond to g^k(j, x, s) in that paper, not to likelihoods.
from hscrf-pytorch.
Thanks for the link to the paper. It might be helpful to annotate your code with the corresponding equations to help code understanding. I still don't get why O is never scored. Eq. (2) in your linked paper defines `g^k` in terms of `y_j` and `y_{j-1}`, but the code is only scoring the different tags.
from hscrf-pytorch.
I get it!
At the line `if span == 0:`, I calculate the score of `O`. I assume that the score of `O` is calculated only when its length is one; when its length is more than one, we don't calculate its score.
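The constraint described here can be illustrated by listing which (start, end, tag) segments would receive a real score. This is a simplified sketch under stated assumptions, not the repository's logic: spans are inclusive on both ends, `max_span` is a hypothetical cap on span length, and the function name `allowed_segments` is invented for this example.

```python
def allowed_segments(sent_len, max_span, tags, outside_tag="O"):
    """List (start, end, tag) segments, with inclusive ends, that get scored.

    Ordinary tags may cover any span of up to max_span tokens, while the
    outside tag is only allowed on single-token spans.
    """
    segments = []
    for start in range(sent_len):
        for end in range(start, min(sent_len, start + max_span)):
            for tag in tags:
                if tag == outside_tag and end != start:
                    continue  # multi-token O spans are never scored
                segments.append((start, end, tag))
    return segments


segments = allowed_segments(sent_len=3, max_span=2, tags=["PER", "O"])
o_segments = [s for s in segments if s[2] == "O"]
```

In the thread's code the excluded combinations are not dropped but filled with a large negative score, which has the same effect on the forward sum.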
from hscrf-pytorch.
I see - but even when I only print the result of the code for span length 0,
tmp = torch.cat((self.transition[:, :validtag_size].unsqueeze(0).unsqueeze(0) + emb_x[:, 0, :, :validtag_size].unsqueeze(2),
                 m10000,
                 self.transition[:, -2:].unsqueeze(0).unsqueeze(0) + emb_x[:, 0, :, -2:].unsqueeze(2)), 3)
scores[:, diag0, diag0] = tmp
every entry looks like this:
[[ 6.1834e-01, -1.0000e+04, -4.2706e-01, 2.8736e-01],
[-4.2289e-01, -1.0000e+04, -4.3145e-02, -1.0890e+00],
[-5.2040e-01, -1.0000e+04, -3.2427e-01, -1.1558e+00],
[-6.0971e-01, -1.0000e+04, 2.9183e-01, 5.1828e-01]],
I only have one tag, so the first entry is for that tag, the one for O is not calculated at all, and the last two are START and STOP.
from hscrf-pytorch.
I added some annotations: c7142f2#diff-e90865298a808f704cff7317a658876e
These four entries are a tag (like PER), STOP, START and O, respectively.
from hscrf-pytorch.