Comments (10)

xenova commented on May 31, 2024

Sure, here's the code for it:

/**
 * Perform mean pooling of the last hidden state.
 * @param {Tensor} last_hidden_state Tensor of shape [batchSize, seqLength, embedDim]
 * @param {Tensor} attention_mask Tensor of shape [batchSize, seqLength]
 * @returns {Tensor} Returns a new Tensor of shape [batchSize, embedDim].
 */
export function mean_pooling(last_hidden_state, attention_mask) {
    // last_hidden_state: [batchSize, seqLength, embedDim]
    // attention_mask:    [batchSize, seqLength]
    let [batchSize, seqLength, embedDim] = last_hidden_state.dims;
    let shape = [batchSize, embedDim];
    let returnedData = new last_hidden_state.data.constructor(batchSize * embedDim);

    let outIndex = 0;
    for (let i = 0; i < batchSize; ++i) {
        let offset = i * embedDim * seqLength;

        for (let k = 0; k < embedDim; ++k) {
            let sum = 0;
            let count = 0;

            let attnMaskOffset = i * seqLength;
            let offset2 = offset + k;
            // Pool over all tokens in the sequence
            for (let j = 0; j < seqLength; ++j) {
                // Index into the attention mask; padding tokens (attn = 0) are excluded
                let attn = Number(attention_mask.data[attnMaskOffset + j]);
                count += attn;
                sum += last_hidden_state.data[offset2 + j * embedDim] * attn;
            }

            returnedData[outIndex++] = sum / count;
        }
    }

    // `Tensor` is transformers.js' own tensor class
    return new Tensor(
        last_hidden_state.type,
        returnedData,
        shape,
    );
}

Nothing too fancy... and it assumes certain dimensions for the input.
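For a quick sanity check of the math above, here's a standalone sketch that performs the same masked mean pooling on plain nested arrays (a hypothetical helper for illustration only; it sidesteps the library's Tensor class and flat data buffers):

```javascript
// Masked mean pooling over plain nested arrays.
// hiddenStates: [batchSize][seqLength][embedDim], attentionMask: [batchSize][seqLength]
function meanPool(hiddenStates, attentionMask) {
  return hiddenStates.map((seq, i) => {
    const embedDim = seq[0].length;
    const pooled = new Array(embedDim).fill(0);
    let count = 0;
    seq.forEach((token, j) => {
      const attn = attentionMask[i][j];
      count += attn; // number of unmasked tokens
      for (let k = 0; k < embedDim; ++k) pooled[k] += token[k] * attn;
    });
    return pooled.map(x => x / count);
  });
}

// Batch of 1, two tokens, the second masked out: only token [1, 2] contributes.
console.log(meanPool([[[1, 2], [3, 4]]], [[1, 0]])); // → [[1, 2]]
```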

from transformers.js.

mahi83 commented on May 31, 2024

You saved me hours! :P

xenova commented on May 31, 2024

The most likely reason is quantization of the models. The model weights are reduced in precision from 32-bit floats to 8-bit integers, shrinking the model by a factor of ~4 (very important for usage on a website).

However, if you are okay with loading the full model, you can export the model yourself without quantizing it, and this should produce the exact same outputs. The provided conversion script uses Hugging Face's optimum library under the hood, and its exports generally match the original accuracy quite well.
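To see why quantization shifts the outputs slightly, here's a toy sketch of symmetric 8-bit linear quantization (illustrative only; optimum's actual quantization scheme differs in its details). Every weight is rounded to one of 255 levels, and that rounding error propagates into the embeddings:

```javascript
// Toy symmetric 8-bit linear quantization: map floats into [-127, 127].
function quantize(weights) {
  const scale = Math.max(...weights.map(Math.abs)) / 127;
  const q = Int8Array.from(weights, w => Math.round(w / scale));
  return { q, scale };
}

// Recover approximate floats; each value carries a rounding error of up to scale/2.
function dequantize({ q, scale }) {
  return Float32Array.from(q, v => v * scale);
}

const w = [0.1234, -0.9876, 0.5];
const restored = dequantize(quantize(w));
// restored ≈ w, but not exactly equal — hence the slightly different embeddings
```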

Here's an example command (without --quantize)

python ./scripts/convert.py --model_id sentence-transformers/all-MiniLM-L6-v2  --from_hub --task default

and then just point to the location of your model (see readme)

xenova commented on May 31, 2024

That said, while trying to do this on my end, I did run into an issue where the pooled value wasn't being returned (most likely due to the newest version of optimum, which removed some of those nodes). So, I will implement the mean pooling myself (see here)

That should fix everything :) (since you will be able to use the original model, which is only 80MB, so, nothing too problematic)

ekolve commented on May 31, 2024

Thanks for the quick reply!

I ran convert.py and generated an unquantized model and modified my code to look like this:

global.self = global;

const { pipeline, env } = require("@xenova/transformers");
env.onnx.wasm.numThreads = 1;
env.remoteModels = false;
env.localURL = "transformers.js/models/onnx/unquantized";

(async () => {
    let embedder = await pipeline('embeddings', 'sentence-transformers/all-MiniLM-L6-v2');
    let sentences = [
        'The quick brown fox jumps over the lazy dog.'
    ];
    let output = await embedder(sentences);
    console.log(output[0].length);
})();

and I now get this error:

TypeError: Cannot read properties of undefined (reading 'data')
    at Function._call (node-transformers/node_modules/@xenova/transformers/src/pipelines.js:286:51)

which corresponds to this line in pipelines.js:

        let embeddings = reshape(embeddingsTensor.data, embeddingsTensor.dims);

I confirmed that the path to the model was correct by changing localURL to an invalid path, and transformers.js responded with "File not found". So it appears that transformers.js is finding the exported model.

xenova commented on May 31, 2024

Yep, you did everything correct.

That's the exact error message I got just now (it's caused by the version of optimum used to export the model). Busy fixing now! :)

xenova commented on May 31, 2024

Okay - these changes should fix it: 851815b

This also fixes the original issue of the outputs being different, since I wasn't correctly performing mean pooling + normalization the way the Python library does. The outputs should now be much closer to Python's (when quantized), and nearly identical (when unquantized).
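For reference, the normalization step that follows mean pooling is plain L2 normalization: each pooled embedding is divided by its Euclidean norm so sentence vectors land on the unit sphere. A minimal sketch (this mirrors what sentence-transformers does, not the library's exact Tensor implementation):

```javascript
// Divide a vector by its L2 norm so the result has length 1.
function l2Normalize(vec) {
  const norm = Math.sqrt(vec.reduce((s, x) => s + x * x, 0));
  return vec.map(x => x / norm);
}

const unit = l2Normalize([3, 4]);
// → [0.6, 0.8], whose L2 norm is 1
```

A nice side effect: once embeddings are unit-length, cosine similarity reduces to a plain dot product.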

Before I make a full release, do you mind testing on your side to see if it functions correctly? I believe you can install an npm package directly from GitHub


Also: for efficiency, I updated it to return a Tensor (instead of nested JavaScript arrays). To get the list back, just call .tolist() on the tensor.
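Conceptually, .tolist() just reshapes the tensor's flat data buffer into nested arrays according to its dims. A sketch of the idea (illustrative only; the real implementation lives on transformers.js' Tensor class):

```javascript
// Recursively reshape a flat typed array into nested arrays of shape `dims`.
function toNestedList(data, dims) {
  if (dims.length === 1) return Array.from(data.slice(0, dims[0]));
  const [head, ...rest] = dims;
  const stride = rest.reduce((a, b) => a * b, 1); // elements per sub-array
  return Array.from({ length: head }, (_, i) =>
    toNestedList(data.slice(i * stride, (i + 1) * stride), rest));
}

console.log(toNestedList(new Float32Array([1, 2, 3, 4, 5, 6]), [2, 3]));
// → [[1, 2, 3], [4, 5, 6]]
```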

ekolve commented on May 31, 2024

Just tested this out and everything looks great! Thank you! I generated embeddings for two sentences in both JS and Python, calculated the cosine similarity in each library, and the results were identical.
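For anyone wanting to reproduce this comparison, cosine similarity between two embedding vectors can be computed with a small helper like this (a hypothetical function for illustration, not part of transformers.js):

```javascript
// Cosine similarity: dot(a, b) / (|a| * |b|), in [-1, 1].
function cosineSimilarity(a, b) {
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; ++i) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

cosineSimilarity([1, 0], [1, 0]); // → 1 (identical direction)
cosineSimilarity([1, 0], [0, 1]); // → 0 (orthogonal)
```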

xenova commented on May 31, 2024

Awesome! I'll push a new release.

v1.3.1 is now live with the changes :) https://www.npmjs.com/package/@xenova/transformers

Thanks again for reporting!

mahi83 commented on May 31, 2024

> That said, while trying to do this on my end, I did run into an issue where the pooled value wasn't being returned (most likely due to the newest version of optimum, which removed some of those nodes). So, I will implement the mean pooling myself (see here)
>
> That should fix everything :) (since you will be able to use the original model, which is only 80MB, so, nothing too problematic)

Do you have the implementation of this mean pooling in js?
