candle-flash-attn-v1's People

Contributors

olivierdehaene

Forkers

yinqiwen

candle-flash-attn-v1's Issues

non-exhaustive patterns: `&Storage::Metal(_)` not covered

When I run cargo build, I get the following error:

error[E0004]: non-exhaustive patterns: `&Storage::Metal(_)` not covered
  --> /root/.cargo/git/checkouts/candle-flash-attn-v1-b43982c1dfc19b4b/62b75f1/src/lib.rs:53:31
   |
53 |         let seqlens_k = match &*seqlens_k {
   |                               ^^^^^^^^^^^ pattern `&Storage::Metal(_)` not covered
   |
note: `Storage` defined here
  --> /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.3.1/src/storage.rs:11:5
   |
8  | pub enum Storage {
   | ----------------
...
11 |     Metal(MetalStorage),
   |     ^^^^^ not covered
   = note: the matched value is of type `&Storage`
help: ensure that all possible cases are being handled by adding a match arm with a wildcard pattern or an explicit pattern as shown
   |
55 ~             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
56 ~             &Storage::Metal(_) => todo!(), // Should be i32!
   |

For more information about this error, try `rustc --explain E0004`.
error: could not compile `candle-flash-attn-v1` (lib) due to 2 previous error
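
For readers unfamiliar with E0004, the failure above can be reproduced in miniature. The sketch below is only an illustration: the `Storage` enum here is a hypothetical stand-in holding plain vectors (candle's real variants wrap backend-specific storage types), and `cuda_slice` is a made-up helper name. It shows the situation the compiler describes, where a dependency gains a new enum variant (`Metal`) and an older `match` that names only the previous variants stops compiling, together with the wildcard-arm fix that the compiler's help text suggests.

```rust
// Hypothetical stand-in for candle's `Storage` enum, for illustration only.
#[allow(dead_code)]
#[derive(Debug)]
enum Storage {
    Cpu(Vec<u32>),
    Cuda(Vec<u32>),
    Metal(Vec<u32>), // variant that a newer release of the dependency adds
}

// A match that lists only `Cpu` and `Cuda` fails with E0004 once `Metal`
// exists. A catch-all arm makes the match exhaustive again.
fn cuda_slice(storage: &Storage) -> Result<&[u32], String> {
    match storage {
        Storage::Cuda(data) => Ok(data.as_slice()),
        // Report unsupported backends as an error rather than using `todo!()`.
        other => Err(format!("expected CUDA storage, got {:?}", other)),
    }
}

fn main() {
    let on_cuda = Storage::Cuda(vec![0, 3, 6, 9]);
    let on_metal = Storage::Metal(vec![0, 3, 6, 9]);
    println!("{:?}", cuda_slice(&on_cuda));
    println!("{:?}", cuda_slice(&on_metal));
}
```

In the real project the match lives inside the dependency, so this kind of change would have to be made in a fork or patched copy of candle-flash-attn-v1. The other likely route, stated here as an assumption rather than a confirmed fix, is to build against a candle-core release whose `Storage` enum matches what the crate was written for.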

My code is:

use candle_ext::{candle::{D, DType, Device, Result, Tensor}, TensorExt, F};

use candle_flash_attn_v1::flash_attn_varlen;

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
    let k = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
    let v = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
    let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;

    let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, Some(1.0))?;
    println!("{:?}", o);

    let dims = q.dims().to_vec();
    let (batch_size, seq_len, _n_heads, _d) = (dims[0], dims[1], dims[2], dims[3]);
    // Cumulative sequence lengths. The crate reads these as u32 (see `as_cuda_slice::<u32>` in the error above).
    let seqlens_q = Tensor::arange_step(0u32, (batch_size as u32 + 1) * seq_len as u32, seq_len as u32, &device)?;
    let seqlens_k = Tensor::arange_step(0u32, (batch_size as u32 + 1) * seq_len as u32, seq_len as u32, &device)?;
    let o1 = flash_attn_varlen(&q, &k, &v, &seqlens_q, &seqlens_k, seq_len , seq_len, 1.0, true);
    println!("{:?}", o1);
    Ok(())
}

How can I fix it? I just want to compare the results of these two attention calculation methods. Are they the same?
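
On the comparison itself: once both calls succeed, one simple way to check whether the two outputs agree is to flatten them and look at the largest element-wise difference. The sketch below uses only the standard library and works on plain `f32` slices; `max_abs_diff` and `all_close` are hypothetical helper names, and the input values are made up. It assumes the tensors have already been copied to the host and flattened (in candle, probably something like `flatten_all()` followed by `to_vec1::<f32>()`, possibly after a dtype conversion). Fused attention kernels typically run in half precision, so an exact match should not be expected, only agreement within a tolerance.

```rust
// Largest element-wise absolute difference between two flattened outputs.
// Returns None when the lengths differ, since the outputs are then not comparable.
// Note: f32::max skips NaN, so NaN values should be checked separately.
fn max_abs_diff(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() {
        return None;
    }
    Some(
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max),
    )
}

// True when both slices have the same length and differ by at most `tol` everywhere.
fn all_close(a: &[f32], b: &[f32], tol: f32) -> bool {
    matches!(max_abs_diff(a, b), Some(d) if d <= tol)
}

fn main() {
    // Made-up values standing in for the two attention outputs.
    let reference = [0.125f32, -0.5, 0.75];
    let flash = [0.1251f32, -0.4998, 0.7503];
    println!("max abs diff: {:?}", max_abs_diff(&reference, &flash));
    println!("close at 1e-3: {}", all_close(&reference, &flash, 1e-3));
}
```

The tolerance is a judgment call: a value around 1e-3 to 1e-2 is a common starting point for half-precision results, but the right threshold depends on the dtype and the magnitude of the values.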
