Comments (8)
This one is different: you can only index into a tensor with an integer tensor, so u32 f16 makes sense but f16 f16 wouldn't, as the index cannot be a float.
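For illustration, here is a minimal sketch of that distinction using candle-core on CPU (the shapes and values are made up): the values being gathered can be f16, but the index tensor has to keep an integer dtype.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // The values being gathered can be f16 (or bf16/f32)...
    let values = Tensor::arange(0f32, 6f32, &device)?
        .reshape((3, 2))?
        .to_dtype(DType::F16)?;
    // ...but the index tensor must keep an integer dtype such as u32.
    let ids = Tensor::new(&[2u32, 0], &device)?;
    // "u32 f16" indexing: supported. An f16 ids tensor here could not
    // meaningfully address rows of `values`.
    let picked = values.index_select(&ids, 0)?;
    println!("{picked}");
    Ok(())
}
```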
---
Yeah, there is no real limitation for this; I've made #2036 for mamba. It works with bf16 but not with f16 (which is somewhat expected: models trained in f32 or bf16 are likely to break with f16). On my RTX 4080, speed increases only slightly, from 320 token/s to 360 token/s, so I wouldn't consider it a big improvement.
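For context, in the candle examples the dtype is typically chosen once, when the weights are loaded. A sketch of that pattern follows; the file name and the cuda/bf16 heuristic are assumptions, not something from this thread:

```rust
use candle_core::{DType, Device};
use candle_nn::VarBuilder;

fn load_weights(device: &Device) -> candle_core::Result<VarBuilder<'static>> {
    // Assumption: prefer bf16 where it is likely supported, fall back to f32.
    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
    // `model.safetensors` is a placeholder path.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], dtype, device)?
    };
    Ok(vb)
}
```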
---
Is it possible that it's only when retrieving the results back with `to_vec<f32>` or equivalent? With a var store, the conversion should be handled for you. When retrieving results, we error out when converting values back, hence you have to call `to_dtype` beforehand.

A good way to find where the issue is coming from is to enable `RUST_BACKTRACE=1`; you probably also want the `release-with-debug` profile (or the equivalent in your project) so that line numbers are reported properly.

(Just trying to guess here; happy to take a more in-depth look if you can provide a simple repro.)
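A minimal sketch of the retrieval pattern described above (shapes made up): convert to the dtype you want to read back before calling the `to_vec*` helpers.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let logits = Tensor::randn(0f32, 1f32, (1, 4), &device)?.to_dtype(DType::F16)?;
    // Reading an f16 tensor back as f32 directly errors out; convert first.
    let values: Vec<Vec<f32>> = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
    println!("{values:?}");
    Ok(())
}
```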
---
Thank you @LaurentMazare! The issue was, I believe, that I was not converting the input tensor to a dtype other than f32. I refactored the code from
```rust
for &token in tokens.iter() {
    let input = Tensor::new(&[token], &self.device)?;
    let logits = self.model.forward(&input, &mut state)?;
    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}
```
to
```rust
for &token in tokens.iter() {
    // `self.dtype` is assumed to hold the target DType; `to_dtype` takes a
    // DType, not a Device. Note that this converts the (integer) token tensor
    // itself, which is what triggers the index_select error below.
    let input = Tensor::new(&[token], &self.device)?.to_dtype(self.dtype)?;
    let logits = self.model.forward(&input, &mut state)?;
    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}
```
However, running the latter code on my MacBook (with the Metal feature enabled) I get the following error:

`Candle error: Metal contiguous index_select BF16 BF16 not implemented`

Is it the case that the current Metal kernels do not support dtypes other than f32?
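For what it's worth, the failure can likely be reproduced in isolation along these lines (a sketch, assuming a Metal-enabled build; not taken from the thread): converting the token tensor to bf16 means the embedding lookup ends up calling index_select with bf16 indices on bf16 weights, i.e. the missing "BF16 BF16" kernel.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?; // requires the `metal` feature
    let embeddings = Tensor::zeros((10, 4), DType::BF16, &device)?;
    // The erroneous conversion: indices become bf16 instead of staying u32.
    let ids = Tensor::new(&[1u32, 3], &device)?.to_dtype(DType::BF16)?;
    // Expected to fail with "Metal contiguous index_select BF16 BF16 not implemented".
    let _rows = embeddings.index_select(&ids, 0)?;
    Ok(())
}
```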
---
Most Metal ops should support f32, f16, and bf16; this one was somehow missing, so I added it in #2035. That said, my MacBook doesn't support bf16, so I wasn't able to really test it, but hopefully that will work for you.
---
Thanks a lot for the PR! Unfortunately, I also get the same issue with other dtypes, including f16:

`Candle error: Metal contiguous index_select F16 F16 not implemented`
---
I see, right. It seems, though, that many of these models do not support f16 or bf16. Without erroneously converting the indices to f16, I am getting this error:

`dtype mismatch in mul, lhs: F16, rhs: F32`

I am running these experiments on mamba and falcon, and from the implementation it seems these models do not support dtypes other than f32 (the mamba state is hardcoded to f32 precision, and in falcon the mask is also hardcoded to f32 precision). I wonder if it would be possible to allow other precision types for these models (including f16 and bf16)?
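To make that concrete, the kind of change needed is small. Here is a hypothetical sketch (the `apply_mask` helper and its names are placeholders, not the actual mamba/falcon code): cast the hardcoded-f32 tensor to the activations' dtype before the multiply.

```rust
use candle_core::Tensor;

// Hypothetical helper: multiplying an f32 mask into f16/bf16 activations
// raises "dtype mismatch in mul"; casting the mask first avoids it.
fn apply_mask(xs: &Tensor, mask_f32: &Tensor) -> candle_core::Result<Tensor> {
    let mask = mask_f32.to_dtype(xs.dtype())?; // follow F16/BF16/F32 as needed
    xs.broadcast_mul(&mask)
}
```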
---
This is interesting: on my MacBook Pro it works with f16, but not with bf16. Thanks for the PR @LaurentMazare; it would be great to have this for the Llama and Falcon models, too.