

clarkkev commented on June 22, 2024

scatter_nd sums values that share an index, but we want to pick a single value per index. If all the values for an index are the same (e.g., when scattering a mask tensor of 1s), dividing the summed values by the number of occurrences at that index fixes the issue. That's what the code you linked to implements.
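A minimal sketch of the behavior described above, using plain Python in place of TensorFlow. The helper names scatter_sum and scatter_mean are hypothetical, and the 1-D list output is a simplification of tf.scatter_nd's general shape handling. It shows that dividing by the occurrence count repairs a scattered mask of 1s, but only blends the values when a duplicated index receives different values.

```python
# Hypothetical pure-Python stand-in for tf.scatter_nd on a 1-D output
# (an illustration, not the ELECTRA code): updates sharing an index are summed.
def scatter_sum(indices, updates, size):
    out = [0] * size
    for i, u in zip(indices, updates):
        out[i] += u
    return out


def scatter_mean(indices, updates, size):
    # Divide each summed value by how many updates landed on that index.
    # Positions that received no update stay 0 (no division by zero).
    sums = scatter_sum(indices, updates, size)
    counts = scatter_sum(indices, [1] * len(indices), size)
    return [s / c if c else 0 for s, c in zip(sums, counts)]


positions = [1, 3, 3]  # position 3 was sampled twice
print(scatter_sum(positions, [1, 1, 1], 5))   # [0, 1, 0, 2, 0]  (mask value 2 is wrong)
print(scatter_mean(positions, [1, 1, 1], 5))  # [0, 1.0, 0, 1.0, 0]  (fixed)

# Different values at the duplicated index get averaged instead of picked.
# 2054 and 2003 are hypothetical token ids.
print(scatter_mean([3, 3], [2054, 2003], 5)[3])  # 2028.5
```

The last line is the case the count division cannot fix: the result is neither of the two values that were scattered.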

However, this does NOT fully fix duplicate mask positions. It prevents errors from token ids overflowing past the vocab size, but if (1) the same position is sampled twice and (2) the generator samples a different token for each occurrence, the replaced token will be an effectively "random" token obtained by averaging the two sampled token ids. I didn't bother fixing this when developing ELECTRA because it affects only a small fraction of masked positions. There is an easy fix, though: replace the sampling step (masked_lm_positions = tf.random.categorical... in pretrain_helpers.mask) with

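# Gumbel-top-k: add Gumbel noise to the logits, then take the indices of
# the N largest perturbed values, i.e. N distinct positions.
masked_lm_positions = tf.math.top_k(
    sample_logits + sample_gumbel(
        modeling.get_shape_list(sample_logits)), N)[1]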

This samples mask positions without replacement (the Gumbel-top-k trick), so no position can be picked twice.
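A pure-Python sketch of the Gumbel-top-k trick, for illustration only: gumbel and gumbel_top_k are hypothetical helpers, not functions from the ELECTRA repository, and random.choices stands in for the with-replacement behavior of tf.random.categorical. Adding independent Gumbel noise to each logit and keeping the k largest perturbed values yields k distinct indices, distributed as a sample without replacement from softmax(logits).

```python
import math
import random


# Hypothetical helper (not from the ELECTRA codebase): standard Gumbel noise
# -log(-log(U)), with U kept strictly inside (0, 1) so both logs are defined.
def gumbel(rng):
    u = rng.uniform(1e-12, 1 - 1e-12)
    return -math.log(-math.log(u))


# Hypothetical helper: perturb every logit independently, then keep the
# indices of the k largest perturbed values (the analogue of
# tf.math.top_k(...)[1]). Indices are distinct by construction.
def gumbel_top_k(logits, k, rng):
    perturbed = [x + gumbel(rng) for x in logits]
    order = sorted(range(len(logits)), key=lambda i: perturbed[i], reverse=True)
    return order[:k]


rng = random.Random(0)
logits = [0.0] * 10  # uniform weights over 10 candidate positions
with_replacement = rng.choices(range(10), k=4)       # may repeat a position
without_replacement = gumbel_top_k(logits, 4, rng)   # never repeats
print(with_replacement, without_replacement)
```

Sorting the full list is O(n log n); a top-k routine such as tf.math.top_k avoids that, but the sampling logic is the same.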

from electra.

zheyuye commented on June 22, 2024

That is clear. Thank you for the explanation.

