Comments (3)
Oops! There seems to be a bug in the torch_bleu function. I only tested the code with ngram = 2. Thank you very much for your suggestions; I will update the code ASAP.
FYI, here is a tested sample of n_gram_precision that I just wrote; it computes the n-gram precision together with the brevity penalty.
Take a look!
To obtain the full BLEU score, you may take the average of the results for n_gram in range(1, 5).
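For reference, standard BLEU combines the n-gram precisions with a geometric mean and applies the brevity penalty once, rather than averaging them arithmetically. Below is a minimal sketch of that combination step, assuming the per-order precisions have already been computed without the penalty. The name `combine_bleu` and the example numbers are hypothetical, not part of the repository.

```python
import math

def combine_bleu(precisions, ref_len, sys_len):
    """Hypothetical helper: geometric mean of n-gram precisions times brevity penalty."""
    # An empty hypothesis or any zero precision gives a score of zero.
    if not precisions or sys_len == 0 or min(precisions) <= 0.0:
        return 0.0
    # Brevity penalty: 1 if the hypothesis is longer than the reference,
    # otherwise exp(1 - ref_len / sys_len).
    bp = 1.0 if sys_len > ref_len else math.exp(1.0 - ref_len / sys_len)
    # Uniform weights 1/N over the N n-gram orders.
    log_mean = sum(math.log(p) for p in precisions) / len(precisions)
    return bp * math.exp(log_mean)

# Made-up precisions for n = 1..4, reference length 12, hypothesis length 10.
print(combine_bleu([0.8, 0.6, 0.4, 0.2], ref_len=12, sys_len=10))
```

Whether the arithmetic average suggested above or this geometric mean is preferable depends on what the score is used for; the sketch only shows the conventional definition.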
import torch

# Note: form_ngram is not defined in this snippet; it must be available in scope.
def n_gram_precision(ref_tensor, sys_tensor, pad_id, n_gram=4):
    """
    Calculates n-gram precision with brevity penalty.
    ref_tensor: batch x seq_len1
    sys_tensor: batch x sample_num x seq_len2
    """
    # Determine batch size, sample count (= beam size), n-gram order
    bsz, sample_num = sys_tensor.size(0), sys_tensor.size(1)
    n = min(min(n_gram, ref_tensor.size(-1)), sys_tensor.size(-1))

    # Generate masks: 1 for every position where a full, non-padded n-gram starts
    ref_padding = (~(ref_tensor == pad_id)).float()
    ref_ngram_mask = torch.arange(0, ref_padding.size(1), device=ref_padding.device) * torch.ones_like(ref_padding)
    ref_ngram_mask = torch.where(
        ref_ngram_mask < (torch.sum(ref_padding, dim=-1, keepdim=True) - n + 1),
        ref_padding, torch.zeros_like(ref_padding)
    )[:, :ref_ngram_mask.size(-1) - n + 1]
    sys_padding = (~(sys_tensor == pad_id)).float()
    sys_ngram_mask = torch.arange(0, sys_padding.size(-1), device=sys_padding.device) * torch.ones_like(sys_padding)
    sys_ngram_mask = torch.where(
        sys_ngram_mask < (torch.sum(sys_padding, dim=-1, keepdim=True) - n + 1),
        sys_padding, torch.zeros_like(sys_padding)
    )[:, :, :sys_ngram_mask.size(-1) - n + 1]

    # Get n-grams
    ref_tensor = ref_tensor * ref_padding  # mask out paddings
    sys_tensor = sys_tensor * sys_padding
    ref_tensor = ref_tensor[:, None, :].repeat(1, sample_num, 1)  # readjust ref size to match sys
    input_tensor1_ngram = form_ngram(ref_tensor, n).float()
    input_tensor2_ngram = form_ngram(sys_tensor, n).float()  # batch x sample_num x seq_len-(n-1) x n

    # Calculate similarity matrix
    sim_matrix = (torch.norm(  # A zero L2 distance means the n-gram in `sys` is present in `ref`
        input_tensor2_ngram.unsqueeze(3) - input_tensor1_ngram.unsqueeze(2),
        p=2, dim=-1
    ) == 0.0).to(torch.float)
    sim_matrix *= sys_ngram_mask.unsqueeze(3) * ref_ngram_mask.unsqueeze(1).unsqueeze(2)
    # Count each sys n-gram once if it matches any ref n-gram
    sim_matrix = torch.sum(torch.max(sim_matrix, dim=-1).values, dim=-1)

    # Brevity penalty: exp(1 - ref_len / sys_len) when sys is not longer than ref, else 1
    ref_len = torch.sum(ref_padding, dim=-1, keepdim=True)
    sys_len = torch.sum(sys_padding, dim=-1)
    bp = torch.exp(1 - (ref_len / sys_len))
    bp = torch.where(ref_len >= sys_len, bp, torch.ones_like(bp))

    return sim_matrix / torch.sum(sys_ngram_mask, dim=-1) * bp  # batch x sample_num
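A plain-Python reference for a single sentence pair can help sanity-check the tensor version on small inputs. The sketch below reflects my reading of the function above: each system n-gram counts once if it occurs anywhere in the reference (no clipping of repeated n-grams), and the result is scaled by the brevity penalty. The name `ngram_precision_reference` and the token ids are hypothetical, and it clamps n by the unpadded lengths, which may differ slightly from the tensor code on very short inputs.

```python
import math

def ngram_precision_reference(ref, hyp, pad_id, n_gram=4):
    """Hypothetical single-pair reference: unclipped n-gram precision times brevity penalty."""
    # Drop padding tokens before forming n-grams.
    ref = [t for t in ref if t != pad_id]
    hyp = [t for t in hyp if t != pad_id]
    n = min(n_gram, len(ref), len(hyp))
    if n == 0:
        return 0.0
    ref_ngrams = {tuple(ref[i:i + n]) for i in range(len(ref) - n + 1)}
    hyp_ngrams = [tuple(hyp[i:i + n]) for i in range(len(hyp) - n + 1)]
    # A hypothesis n-gram matches if it appears anywhere in the reference.
    matches = sum(1 for gram in hyp_ngrams if gram in ref_ngrams)
    # Penalize hypotheses that are not longer than the reference.
    bp = math.exp(1.0 - len(ref) / len(hyp)) if len(ref) >= len(hyp) else 1.0
    return matches / len(hyp_ngrams) * bp

# Made-up token ids with pad_id = 0: one of the two hypothesis bigrams matches.
print(ngram_precision_reference([1, 2, 3, 4, 0, 0], [1, 2, 4, 0], pad_id=0, n_gram=2))
```

Comparing this value against one row of the tensor output for the same tokens is a quick way to catch masking mistakes.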
I have updated the code! Thank you again for your effort!