ChenxinAn-fdu commented on August 21, 2024

Oops! There seems to be a bug in the `torch_bleu` function; I only tested the code with `ngram = 2`. Thank you very much for your suggestions, and I will update the code ASAP.

jinulee-v commented on August 21, 2024

FYI, here is a tested implementation of `n_gram_precision` that I just wrote; it also applies the brevity penalty. Take a look!

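To obtain the full BLEU score, combine the results for `n_gram` in `range(1, 5)`. Standard BLEU takes their geometric mean (the exponentiated average of the log precisions) rather than the arithmetic mean; because each result already includes the brevity penalty, the geometric mean applies that penalty exactly once.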
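A minimal sketch of that combination step in plain Python, assuming you already have one brevity-penalized score per order n = 1..4 (for example, one `n_gram_precision` call per order, converted to floats). Since every score carries the same brevity penalty, their geometric mean equals the penalty times the geometric mean of the raw precisions. The helper name `combine_ngram_scores` is hypothetical.

```python
import math


def combine_ngram_scores(scores):
    """Geometric mean of per-order scores (hypothetical helper).

    scores: list of floats, one brevity-penalized precision per n-gram order.
    Returns 0.0 if any order has zero precision, as unsmoothed BLEU does.
    """
    if not scores:
        raise ValueError("need at least one n-gram order")
    if any(s <= 0.0 for s in scores):
        return 0.0
    # exp(mean(log s)) is the geometric mean
    return math.exp(sum(math.log(s) for s in scores) / len(scores))


# Precisions for n = 1..4, each multiplied by the same brevity penalty of 0.8
bp = 0.8
scores = [bp * p for p in (0.9, 0.6, 0.4, 0.25)]
print(combine_ngram_scores(scores))  # equals bp * geometric mean of the precisions
```

Returning 0.0 for a zero-precision order mirrors unsmoothed BLEU; smoothed variants adjust zero counts before taking logs instead.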

import torch


def form_ngram(input_tensor, n):
    # Minimal stand-in for the repo's `form_ngram` helper (assumed behavior;
    # drop it if you already import the original):
    # batch x sample_num x seq_len -> batch x sample_num x (seq_len - n + 1) x n
    return input_tensor.unfold(-1, n, 1)


def n_gram_precision(ref_tensor, sys_tensor, pad_id, n_gram=4):
    """
    Calculates n-gram precision with brevity penalty.

    ref_tensor: batch x seq_len1 (right-padded with pad_id)
    sys_tensor: batch x sample_num x seq_len2 (right-padded with pad_id)
    returns:    batch x sample_num
    """
    # Determine sample count (= beam size) and the effective n-gram order
    sample_num = sys_tensor.size(1)
    n = min(n_gram, ref_tensor.size(-1), sys_tensor.size(-1))

    # Padding masks: 1.0 for real tokens, 0.0 for padding
    ref_padding = (ref_tensor != pad_id).float()
    sys_padding = (sys_tensor != pad_id).float()

    # N-gram masks: the n-gram starting at position i is kept only if it lies
    # entirely inside the non-padded prefix, i.e. i < length - n + 1
    ref_pos = torch.arange(ref_tensor.size(-1), device=ref_tensor.device)
    ref_ngram_mask = torch.where(
        ref_pos < ref_padding.sum(dim=-1, keepdim=True) - n + 1,
        ref_padding, torch.zeros_like(ref_padding),
    )[:, :ref_tensor.size(-1) - n + 1]  # batch x N1
    sys_pos = torch.arange(sys_tensor.size(-1), device=sys_tensor.device)
    sys_ngram_mask = torch.where(
        sys_pos < sys_padding.sum(dim=-1, keepdim=True) - n + 1,
        sys_padding, torch.zeros_like(sys_padding),
    )[:, :, :sys_tensor.size(-1) - n + 1]  # batch x sample_num x N2

    # Get n-grams; n-grams touching padding are already excluded by the masks,
    # so token ids need no extra masking. Repeat ref to match sys along samples.
    ref_tensor = ref_tensor[:, None, :].repeat(1, sample_num, 1)
    ref_ngram = form_ngram(ref_tensor, n)  # batch x sample_num x N1 x n
    sys_ngram = form_ngram(sys_tensor, n)  # batch x sample_num x N2 x n

    # Similarity matrix: [b, s, i, j] = 1 if sys n-gram i equals ref n-gram j
    # (exact integer comparison instead of testing an L2 norm against 0.0)
    sim_matrix = (sys_ngram.unsqueeze(3) == ref_ngram.unsqueeze(2)).all(dim=-1).float()
    sim_matrix = sim_matrix * sys_ngram_mask.unsqueeze(3) * ref_ngram_mask[:, None, None, :]
    # Count the sys n-grams that match at least one ref n-gram (counts are not clipped)
    matches = sim_matrix.max(dim=-1).values.sum(dim=-1)  # batch x sample_num

    # Brevity penalty: exp(1 - r / c) if the candidate is not longer than the reference
    ref_len = ref_padding.sum(dim=-1, keepdim=True)  # batch x 1
    sys_len = sys_padding.sum(dim=-1)  # batch x sample_num
    bp = torch.exp(1 - ref_len / sys_len.clamp(min=1))
    bp = torch.where(ref_len >= sys_len, bp, torch.ones_like(bp))

    # Clamp the denominator so candidates shorter than n tokens score 0 instead of NaN
    return matches / sys_ngram_mask.sum(dim=-1).clamp(min=1) * bp  # batch x sample_num
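To cross-check a tensorized version like the one above on small inputs, for example at orders other than n = 2, a plain-Python reference for a single reference/candidate pair can help. This sketch assumes the token lists are already stripped of padding and, like the tensor code, counts a candidate n-gram as a hit whenever it occurs anywhere in the reference, without clipping repeated counts (standard BLEU's modified precision clips them, so the two differ when n-grams repeat). The name `reference_ngram_precision` is hypothetical.

```python
import math


def reference_ngram_precision(ref_tokens, sys_tokens, n):
    """Unclipped n-gram precision times brevity penalty for one pair (hypothetical helper).

    ref_tokens, sys_tokens: lists of token ids with padding already removed.
    """
    ref_ngrams = {tuple(ref_tokens[i:i + n]) for i in range(len(ref_tokens) - n + 1)}
    sys_ngrams = [tuple(sys_tokens[i:i + n]) for i in range(len(sys_tokens) - n + 1)]
    if not sys_ngrams:
        return 0.0  # candidate shorter than n tokens
    # A candidate n-gram is a hit if it appears anywhere in the reference (no clipping)
    hits = sum(1 for g in sys_ngrams if g in ref_ngrams)
    precision = hits / len(sys_ngrams)

    # Brevity penalty: exp(1 - r / c) if the candidate is not longer than the reference
    r, c = len(ref_tokens), len(sys_tokens)
    bp = math.exp(1 - r / c) if r >= c else 1.0
    return precision * bp


ref = [5, 6, 7, 8, 9]
cand = [5, 6, 7, 3]
for n in range(1, 5):
    print(n, reference_ngram_precision(ref, cand, n))
```

With `ref = [1, 2]` and `cand = [1, 1]`, this returns 1.0 for n = 1, whereas clipped BLEU precision would give 0.5.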

ChenxinAn-fdu commented on August 21, 2024

I have updated the code! Thank you again for your effort!!
