mbrukman / grokking-deep-learning-rs

This project forked from suyash/grokking-deep-learning-rs

Grokking Deep Learning Examples implemented in Rust

License: Apache License 2.0

Rust 100.00%

grokking-deep-learning-rs's Introduction

Grokking Deep Learning Rust

The exercises from the @iamtrask book Grokking Deep Learning, implemented in Rust.

This crate isn't published, because ideally you'd work through the exercises on your own, but if you insist, it can be added from Git:

cargo add grokking_deep_learning_rs --git https://github.com/suyash/grokking-deep-learning-rs

This crate is structured as a library: the core library describes some common primitives used throughout, and the individual chapters are implemented as examples. To run the exercises from a particular chapter, for example chapter 12:

cargo run --example chapter12

Currently this uses rulinalg for matrix operations, which uses a Rust implementation of dgemm and provides a 3x performance improvement over naive ijk multiplication (see the included benchmark). However, it still isn't as fast as NumPy because it isn't multi-threaded. I am currently working on something of my own.
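For reference, "ijk multiplication" here presumably means the plain triple loop. The sketch below is an illustration of that baseline on flat, row-major slices; it is not the code from the included benchmark, and the function name `matmul_ijk` is made up for this example.

```rust
/// Multiplies an `n x k` matrix `a` by a `k x m` matrix `b`, both stored
/// row-major in flat slices, using the plain triple loop in i, j, k order.
/// Illustrative baseline only; not taken from this repository.
fn matmul_ijk(a: &[f64], b: &[f64], n: usize, k: usize, m: usize) -> Vec<f64> {
    assert_eq!(a.len(), n * k, "a must hold n * k elements");
    assert_eq!(b.len(), k * m, "b must hold k * m elements");

    let mut c = vec![0.0; n * m];
    for i in 0..n {
        for j in 0..m {
            let mut sum = 0.0;
            for p in 0..k {
                // `b` is read with stride `m` here, which is what makes
                // this loop order unfriendly to the cache on large inputs.
                sum += a[i * k + p] * b[p * m + j];
            }
            c[i * m + j] = sum;
        }
    }
    c
}
```

Optimized GEMM routines typically get their speedup by reordering and blocking these loops so that data stays in cache, rather than by doing less arithmetic.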

The datasets are extracted into a separate library crate, which currently provides functions for loading 4 datasets and an iterator for batching and shuffling. I am planning to add more. It can be added using:

cargo add datasets --git https://github.com/suyash/datasets
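To make "batching and shuffling" concrete, here is a minimal standard-library sketch of the idea: shuffle the sample indices, then cut them into fixed-size batches. It does not reproduce the datasets crate's API. The names `Lcg`, `shuffled_indices`, and `batches` are hypothetical, and the tiny generator exists only so the example needs no external crates.

```rust
/// Tiny linear congruential generator so the sketch has no dependencies.
/// Hypothetical helper; use a real RNG crate for anything serious.
struct Lcg(u64);

impl Lcg {
    fn next_u64(&mut self) -> u64 {
        // Multiplier and increment from Knuth's MMIX generator.
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Drop the low bits, which have short periods in an LCG.
        self.0 >> 33
    }
}

/// Returns the indices `0..len` in shuffled order (Fisher-Yates).
/// The modulo introduces a slight bias, acceptable for an illustration.
fn shuffled_indices(len: usize, seed: u64) -> Vec<usize> {
    let mut rng = Lcg(seed);
    let mut idx: Vec<usize> = (0..len).collect();
    for i in (1..len).rev() {
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        idx.swap(i, j);
    }
    idx
}

/// Splits the shuffled indices into batches of at most `batch_size`;
/// the last batch may be shorter.
fn batches(len: usize, batch_size: usize, seed: u64) -> Vec<Vec<usize>> {
    assert!(batch_size > 0, "batch_size must be positive");
    shuffled_indices(len, seed)
        .chunks(batch_size)
        .map(|chunk| chunk.to_vec())
        .collect()
}
```

Each batch of indices can then be used to gather the matching rows of the inputs and targets for one training step.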

Because matmul is slower, certain examples from chapter 8 onwards are smaller in size than the corresponding Python examples.

The Chapter 13 core components were extracted into the core library, so they could be used in later chapters.

So, something like this is possible:

use rulinalg::matrix::Matrix;

use grokking_deep_learning_rs::activations::{Sigmoid, Tanh};
use grokking_deep_learning_rs::layers::{Layer, Linear, Sequential};
use grokking_deep_learning_rs::losses::{Loss, MSELoss};
use grokking_deep_learning_rs::optimizers::{Optimizer, SGDOptimizer};
use grokking_deep_learning_rs::tensor::Tensor;

let data = Tensor::new_const(Matrix::new(
    4,
    2,
    vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0],
));

let target = Tensor::new_const(Matrix::new(4, 1, vec![0.0, 1.0, 0.0, 1.0]));

let model = Sequential::new(vec![
    Box::new(Linear::new(2, 3)),
    Box::new(Tanh),
    Box::new(Linear::new(3, 1)),
    Box::new(Sigmoid),
]);

let criterion = MSELoss;
let optim = SGDOptimizer::new(model.parameters(), 0.5);

for _ in 0..10 {
    let pred = model.forward(&[&data]);

    // compare
    let loss = criterion.forward(&pred[0], &target);

    println!("Loss: {:?}", loss.0.borrow().data.data());

    // backpropagate the loss gradient
    loss.backward(Tensor::grad(Matrix::ones(1, 1)));

    // learn
    optim.step(true);
}
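For readers who want the arithmetic behind the loop above, the following is a minimal scalar sketch of the activations used in the model and of a plain SGD update. It uses the textbook definitions and is not taken from the crate's `activations` or `optimizers` modules, whose internals may differ; all function names here are illustrative.

```rust
/// Logistic sigmoid, the activation of the final layer in the model above.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Derivative of the sigmoid, written in terms of its output `s = sigmoid(x)`.
fn sigmoid_grad(s: f64) -> f64 {
    s * (1.0 - s)
}

/// Derivative of tanh, written in terms of its output `t = tanh(x)`.
fn tanh_grad(t: f64) -> f64 {
    1.0 - t * t
}

/// One plain SGD update, applied element by element: w <- w - alpha * grad.
fn sgd_step(weights: &mut [f64], grads: &[f64], alpha: f64) {
    assert_eq!(weights.len(), grads.len(), "one gradient per weight");
    for (w, g) in weights.iter_mut().zip(grads) {
        *w -= alpha * g;
    }
}
```

Expressing both derivatives through the layer's output is convenient for backpropagation, since the forward pass has already computed that output.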

In Chapter 14, the RNN and LSTM examples have vanishing gradients, and the loss keeps going to NaN. There seems to be a logic bug somewhere in the code, where something is not doing what I think it does; I am still investigating. I tried reproducing the problem in the chapter 13 final exercise and also implemented min-char-rnn.py in Rust, but no luck so far.
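One generic way to narrow down this kind of problem is to log the gradient norm and check for non-finite values after every step, so the first bad step can be identified. The helpers below are a diagnostic sketch under that assumption. They are not part of this crate, and they work on plain slices, so the values would have to be copied out of the tensors first.

```rust
/// Returns the index of the first NaN or infinite value, if there is one.
/// Hypothetical debugging helper, not part of the crate.
fn first_non_finite(values: &[f64]) -> Option<usize> {
    values.iter().position(|v| !v.is_finite())
}

/// Euclidean (L2) norm of a gradient. Logging this per step shows whether
/// gradients shrink toward zero or blow up before the loss turns into NaN.
fn l2_norm(values: &[f64]) -> f64 {
    values.iter().map(|v| v * v).sum::<f64>().sqrt()
}
```

Calling `first_non_finite` on the loss, the activations, and each parameter's gradient in turn shows which quantity goes bad first.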

For Chapter 15, the encrypted federated learning exercise is not implemented. A crate for Paillier homomorphic encryption does exist, but its current implementation only works with integers and BigInts, not floating-point numbers. I will try to see how to get it to work.
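A common workaround for integer-only schemes is fixed-point encoding: scale each float by a constant, round to an integer, run the integer arithmetic, and divide the scale back out at the end. The sketch below shows only that encoding on plain `i64` values, with no encryption involved. The `SCALE` constant and the function names are assumptions made for illustration; this is not how the project or any particular Paillier crate handles it.

```rust
/// Hypothetical fixed-point scale: six decimal digits of precision.
const SCALE: f64 = 1_000_000.0;

/// Encodes a float as a scaled integer, rounding to the nearest step.
fn encode(x: f64) -> i64 {
    (x * SCALE).round() as i64
}

/// Decodes a value carrying a single factor of `SCALE`, such as a
/// sum of encoded values.
fn decode(n: i64) -> f64 {
    n as f64 / SCALE
}

/// Decodes the product of two encoded values, which carries `SCALE` twice.
fn decode_product(n: i64) -> f64 {
    n as f64 / (SCALE * SCALE)
}
```

Sums of encoded values keep a single scale factor, which fits Paillier's additive homomorphism. A product of an encoded value with an encoded plaintext scalar carries the scale twice, so it needs `decode_product`. Negative numbers would also have to be mapped into the scheme's modular range, which this sketch leaves out.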

License

This project is licensed under either of

at your option.

Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in this work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.
