Comments (2)
I was only able to reproduce the bug on the ndarray backend; it seems to work on the tch backend. You can see the test on the branch: fix/max_dim_gather
Hmm, odd that it didn't reproduce with tch! I had whittled down the example to be minimal, and indeed, that one doesn't cause a panic on tch for me either. Perhaps this is two bugs in a trenchcoat pretending to be one!
Here's a specific snippet which does crash for me:

```rust
use burn::backend::libtorch::LibTorchDevice;
use burn::backend::{Autodiff, LibTorch};
use burn::tensor::{Data, Int, Tensor};

let a: Vec<f32> = vec![-0.35060948, -0.6759874, -1.2398422, -0.55234957];
let b = [2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2];
let b: Tensor<Autodiff<LibTorch>, 2, Int> =
    Tensor::from_data(Data::from(b.as_slice()), &LibTorchDevice::default()).reshape([5, 4]);
let a = Tensor::from_data(Data::from(a.as_slice()), &LibTorchDevice::default())
    .reshape([1, 4])
    .require_grad();
// Repeat the [1, 4] row 5 times along dim 0 -> [5, 4].
let grammar: Tensor<_, 2> = a.clone().repeat(0, 5);
let loss = grammar.gather(1, b);
// max_dim(0) keeps the reduced dim, giving [1, 4], which is broadcast against [5, 4].
let loss = loss.clone().max_dim(0) + loss;
let loss = loss.sum();
let g = loss.backward();
```
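When comparing backends, it may help to know what gradient this snippet should produce for `a`. Below is a hand-rolled sketch in plain Rust (no burn dependency) of the same forward and backward pass. It assumes `gather(1, b)` means `out[i][j] = x[i][b[i][j]]`, that the `[1, 4]` result of `max_dim(0)` is broadcast over all 5 rows before the sum, and that the max gradient goes to the first maximal row on ties. `reference_grad` is a hypothetical helper, not part of burn. In this particular input, every tie in a column is between equal indices into `a`, so the result does not depend on the tie-breaking rule.

```rust
// Reference forward/backward (hypothetical helper, not burn API) for:
//   grammar = a.repeat(0, 5)         // [5, 4], every row equals `a`
//   g       = grammar.gather(1, b)   // g[i][j] = grammar[i][b[i][j]]
//   loss    = (g.max_dim(0) + g).sum()
// Returns d(loss)/d(a), shape [4].
fn reference_grad(a: [f32; 4], b: [[usize; 4]; 5]) -> [f32; 4] {
    // Forward: gathered values (every row of `grammar` is `a`).
    let mut g = [[0.0f32; 4]; 5];
    for i in 0..5 {
        for j in 0..4 {
            g[i][j] = a[b[i][j]];
        }
    }

    // d(loss)/d(g): 1 from the direct `+ g` term ...
    let mut dg = [[1.0f32; 4]; 5];
    // ... plus 5 for the argmax row of each column, because the [1, 4]
    // max is broadcast over all 5 rows before the sum.
    // Assumption: ties send the gradient to the first maximal row.
    for j in 0..4 {
        let mut best = 0;
        for i in 1..5 {
            if g[i][j] > g[best][j] {
                best = i;
            }
        }
        dg[best][j] += 5.0;
    }

    // Backward through gather (scatter-add by index) and through
    // repeat (sum the gradients of the 5 copies of `a`).
    let mut da = [0.0f32; 4];
    for i in 0..5 {
        for j in 0..4 {
            da[b[i][j]] += dg[i][j];
        }
    }
    da
}

fn main() {
    let a = [-0.35060948, -0.6759874, -1.2398422, -0.55234957];
    // `b` from the snippet, reshaped to [5, 4].
    let b = [[2, 2, 2, 3], [2, 2, 3, 2], [2, 2, 3, 2], [2, 2, 2, 3], [2, 3, 2, 2]];
    println!("{:?}", reference_grad(a, b)); // prints [0.0, 0.0, 20.0, 20.0]
}
```

Only indices 2 and 3 appear in `b`, so the first two gradient entries are zero; any backend that finishes `backward()` on this input would be expected to match these values.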