
Comments (8)

LaurentMazare avatar LaurentMazare commented on June 22, 2024 1

This one is different: you can only index into a tensor with an integer tensor, so u32 f16 makes sense but f16 f16 wouldn't, as the index cannot be a float.
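The distinction can be sketched outside of candle with a toy gather function (a minimal illustration, not candle's actual implementation): the element type is generic, but the index type is fixed to an integer.

```rust
// Toy sketch of index_select: the values can be any element type
// (f16, bf16, f32, ...), but the indices must be integers, since
// they name positions in the tensor.
fn index_select<T: Copy>(values: &[T], indices: &[u32]) -> Vec<T> {
    indices.iter().map(|&i| values[i as usize]).collect()
}

fn main() {
    // In candle the values could be f16; plain f32 stands in here.
    let values = [10.0f32, 20.0, 30.0];
    let picked = index_select(&values, &[2, 0]);
    println!("{:?}", picked); // [30.0, 10.0]
}
```

There is simply no sensible meaning for a float-typed index, which is why the `u32 f16` kernel exists but a `f16 f16` one never would.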

from candle.

LaurentMazare avatar LaurentMazare commented on June 22, 2024 1

Yeah, there is no real limitation for this; I've made #2036 for mamba. It works with bf16 but not with f16 (which is somewhat expected: models trained in f32 or bf16 are likely to break with f16). On my RTX 4080, speed only increases from 320 tokens/s to 360 tokens/s, so I wouldn't consider it a big improvement.


LaurentMazare avatar LaurentMazare commented on June 22, 2024

Is it possible that it happens only when retrieving the results back with to_vec<f32> or equivalent? With a var store, the conversion should be handled for you; when retrieving results, we error out rather than converting values back, hence you have to call to_dtype beforehand.
A good way to find where the issue is coming from is to set RUST_BACKTRACE=1; you probably also want the release-with-debug profile (or the equivalent in your project) so that line numbers are properly reported.
(Just guessing here; happy to take a more in-depth look if you can provide a simple repro.)
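For reference, candle's own Cargo.toml defines such a profile; in your own project the equivalent would look something like this (the profile name is just a convention):

```toml
[profile.release-with-debug]
inherits = "release"
debug = true
```

You can then run with `RUST_BACKTRACE=1 cargo run --profile release-with-debug ...` to get optimized builds whose backtraces still carry meaningful line numbers.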


jorgeantonio21 avatar jorgeantonio21 commented on June 22, 2024

Thank you @LaurentMazare! The issue was, I believe, that I was not converting the input tensor to a dtype other than f32. I refactored the code from

for &token in tokens.iter() {
    let input = Tensor::new(&[token], &self.device)?;
    let logits = self.model.forward(&input, &mut state)?;

    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}

to

for &token in tokens.iter() {
    // to_dtype takes a DType, not a Device
    let input = Tensor::new(&[token], &self.device)?.to_dtype(self.dtype)?;
    let logits = self.model.forward(&input, &mut state)?;

    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}

However, running the latter code on my MacBook (with Metal features), I get the following error:

Candle error: Metal contiguous index_select BF16 BF16 not implemented

Is it the case that the current Metal kernels do not support types other than f32?


LaurentMazare avatar LaurentMazare commented on June 22, 2024

Most Metal ops should support f32, f16, and bf16; this one was missing somehow, so I added it in #2035. That said, my MacBook doesn't support bf16, so I wasn't able to really test it, but hopefully that will work for you.


jorgeantonio21 avatar jorgeantonio21 commented on June 22, 2024

Thanks a lot for the PR! Unfortunately, I also have the same issue with other dtypes, including f16:

Candle error: Metal contiguous index_select F16 F16 not implemented


jorgeantonio21 avatar jorgeantonio21 commented on June 22, 2024

I see, right. It seems, though, that many of these models do not have support for f16 or bf16. Without erroneously converting the indices to f16, I am getting this error:

dtype mismatch in mul, lhs: F16, rhs: F32.

I am running these experiments on mamba and falcon, and from the implementations it seems these models do not support dtypes other than f32 (the mamba state is hardcoded to f32 precision, and in falcon the mask is also hardcoded to f32 precision).

I wonder if it would be possible to allow other precision types (including f16 and bf16) for these models?


jorgeantonio21 avatar jorgeantonio21 commented on June 22, 2024

This is interesting: on my MacBook Pro it works with f16, but not with bf16. Thanks for the PR @LaurentMazare; it would be great to have this for the Llama and Falcon models, too.

