Sage

Sage is an experimental deep learning framework written in Rust. Sage is designed for building high-performance differentiable programs with complex runtime logic. Ideally, it aims to bring PyTorch-level flexibility and TVM-level performance together by leveraging lazy evaluation and JIT compilation.

Core features:

  • Lazy and incremental tensor evaluation
  • Optimized JIT compilation (OpenCL)
  • Efficient runtime memory management

Disclaimer: Sage is still in a very early stage of development. Numerical correctness of operations is not guaranteed, and there will be breaking API changes without prior notice.

Installation

The core framework of Sage is written in pure Rust, but it depends on OpenCL for GPU support. Make sure the system has an OpenCL driver installed. For Android builds, you must link against the OpenCL library (i.e., libOpenCL.so) extracted from the target platform.
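
As a quick sanity check, you can print the compute devices Sage detects, using the same Device::get_list() call that appears in the training example below; if the OpenCL driver is missing or misconfigured, no GPU devices will be listed.

// Print all compute devices visible to Sage.
println!("{:?}", Device::get_list());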

Documentation

Visit sage.rs for examples and documentation (work in progress).

Example

Basic usage

Tensors and Variables

// A Context specifies the processor (e.g., a GPU) that executes the program;
// here, the device with index 2 is selected.
let mut ctx = Context::with_device(2);

// Tensors are n-dimensional arrays
let x_data = Tensor::new([
    [0.5173, -0.9896, -0.7773],
    [0.1546, -0.7499, 0.2420],
    [-1.6632, 1.0712, -0.2654],
]).to_device(&mut ctx);

// Variables hold (un)evaluated tensors.
let x = Var::new(x_data);

let y = Var::new(Tensor::new([
    [0.5173, -0.9896, -0.7773],
    [0.1546, -0.7499, 0.2420],
    [-1.6632, 1.0712, -0.2654],
]).to_device(&mut ctx));

Lazy evaluation

// A new variable is created as the result of an operation.
// No actual computation happens at this point.
let z = &x * &y + (&x * 3.14);

// The tensor is evaluated only when eval() is called
let z_data = z.eval(&mut ctx);
println!("{:?}", z_data);

// Because z already holds an evaluated tensor,
// this only computes the addition of the two tensors
let u_data = (&z + &x).eval(&mut ctx);
println!("{:?}", u_data);

Basic operators

// Arithmetic operators
let y = (&x * &x - &x) / &x;

// Math functions
x.abs(); x.log(); x.exp(); x.sqrt(); x.erf(); ...

// Trigonometric functions
x.sin(); x.sinh(); x.asin(); x.asinh(); ...

// Rounding functions
x.round(); x.ceil(); x.floor(); ...

// Logical operators
and(&x, &y); or(&x, &y); gt(&x, &y); le(&x, &y); ...

// Conditional operator (ternary operator)
cond(gt(&x, 0.0), &x, &y);

// Datatype casting
x.int(); x.float(); ...

Tensor shaping

// Tensor extents (cf. shape in NumPy)
assert_eq!(x.extents(), &[3, 3]);

// Tensor rank (cf. ndim in NumPy)
assert_eq!(x.rank(), 2);

// For binary operations, tensor shapes are broadcast
// (cf. https://numpy.org/doc/stable/user/basics.broadcasting.html)
let y = &x + Tensor::new([[1.0], [2.0], [3.0]]); // [3, 3] + [3, 1] -> [3, 3]

// Shape manipulations
x.transpose(0, 1);
x.permute([1, 0]);
x.unsqueeze(0).squeeze(0);
x.expand([1, 3, 3]);
x.reshape([1, 9]);
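
These shaping operators are lazy and compose with arithmetic like any other operation; a small sketch, reusing the 3 x 3 tensor x and the ctx from the earlier examples:

// Add x to its transpose; both operands have extents [3, 3].
let xt = x.transpose(0, 1);
let sym = (&x + &xt).eval(&mut ctx);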

Indexing operators

// Slicing along an axis (here, indices 0..2 of axis 0)
x.slice(0, 0, 2);

// Concatenation
concat([&x, &y, &z]);

// Gather and scatter
let t = Tensor::new([
    [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
]);
let y = x.gather(t, 0); // gather elements along axis 0 at indices t
let x = y.scatter(t, [3, 3]); // scatter y into a tensor with extents [3, 3]

Reduction operators

// Summation over axes 0 and 1 (the boolean flag keeps the reduced dimensions)
x.sum([0, 1], true);

// Product along axis 0
x.prod(0, true);

// Minimum and maximum along axis 0
x.min(0, true);
x.max(0, true);

// Example: softmax cross entropy
fn log_sum_exp(x: &Var, axis: usize) -> Var {
    // Subtracting the maximum keeps exp() from overflowing.
    let c = x.max(axis, true);
    (x - &c).exp().sum(axis, true).log() + c
}

fn softmax_cross_entropy(x1: &Var, x2: &Var) -> Var {
    let log_z = x1 - log_sum_exp(x1, 1);
    let log_p = log_z.gather(x2, 1); // or log_z * x2 for one-hot labels

    -log_p.sum(1, false)
}
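
Subtracting the row-wise maximum c before exponentiating is the standard log-sum-exp trick: log(sum_i exp(x_i)) = c + log(sum_i exp(x_i - c)) for any constant c, so choosing c = max(x) leaves the result unchanged while preventing exp() from overflowing.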

Contraction operators

// Matrix multiplication ([n, m] x [m, k] -> [n, k])
x.matmul(&y);

// Batched matrix multiplication over a leading batch dimension
x.batch_matmul(&y);

Automatic differentiation

All operations defined for Variable are differentiable. The gradient of a variable can be obtained with the grad() function.

let x_data = Tensor::new([
    [0.5173, -0.9896, -0.7773],
    [0.1546, -0.7499, 0.2420],
    [-1.6632, 1.0712, -0.2654],
]).to_device(&mut ctx);

// Variables hold (un)evaluated tensors.
let x = Var::new(x_data);
let y = (&x + 3.0) * (&x + 5.5);

let gy = grad(&y, [&x]);

// Get gradient of x
let gygx = gy.get(&x).unwrap();

// Higher-order differentiation is also possible
let ggy = grad(gygx, [&x]);
let ggygx = ggy.get(&x).unwrap();

println!("{:?}", ggygx.eval(&mut ctx));
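
As a sanity check: y = (x + 3.0) * (x + 5.5) is applied elementwise, so dy/dx = 2x + 8.5 and the second derivative is the constant 2. Assuming the conventional all-ones gradient seed, the final println should print a tensor filled with 2s.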

Neural Networks

Sage provides a basic set of neural network operators required to implement common DNN models.

Defining a model

Visit src/model for more advanced examples, such as ResNet, DenseNet, MobileNet v2, and BERT.

let mut model = layers::Sequential::new();

model
    .add(layers::Conv2d::new(1, 64, [3, 3]))
    .add(layers::Relu)
    .add(layers::MaxPool2d::new([2, 2]))
    .add(layers::Conv2d::new(64, 128, [3, 3]))
    .add(layers::Relu)
    .add(layers::MaxPool2d::new([2, 2]))
    .add(layers::Conv2d::new(128, 128, [3, 3]))
    .add(layers::Relu)
    .add(layers::Flatten)
    .add(layers::Dense::new(3 * 3 * 128, 64))
    .add(layers::Relu)
    .add(layers::Dense::new(64, 10));

let logits = model.pass(&x);
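
The 3 * 3 * 128 input width of the first Dense layer follows from a 28 x 28 single-channel input (as in the MNIST example below), assuming stride-1 convolutions without padding: the spatial size shrinks as 28 -> 26 -> 13 -> 11 -> 5 -> 3 through the three 3 x 3 convolutions and two 2 x 2 poolings, leaving a 3 x 3 x 128 feature map to flatten.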

Training a model

Several momentum-based optimizers (e.g., Adam) are available.

// List the available compute devices
println!("{:?}", Device::get_list());

let mut ctx = Context::new();

let mut model = ResNet::new(ResNetConfig::d18(1, 10));

let batch_size = 128;
let num_epoch = 30;
let learning_rate = 1e-4;

let dataset = Mnist::from_source(
    "./dataset/mnist/train-images.idx3-ubyte",
    "./dataset/mnist/train-labels.idx1-ubyte",
).unwrap();

let mut optimizer = Adam::new(learning_rate);

model.init(&mut ctx, 0);
optimizer.init(&mut ctx);

let input = Var::empty([batch_size, 28, 28, 1], DataType::Float);
let label = Var::empty([batch_size, 1], DataType::Uint);

let logits = model.pass(&input);

let loss = softmax_cross_entropy(&logits, &label).mean(0, false);
let grads = grad_param(&loss, &model);
let acc = accuracy(&logits, &label);

// Compile gradients, loss, and accuracy into a single program
let p = Program::compile(&[], grads.values().chain([&loss, &acc]));

for i in 0..num_epoch {
    for (j, (images, labels)) in dataset.iter().batch(batch_size, Mnist::collate).enumerate() {
        let (images, labels) = (images.to_device(&mut ctx), labels.to_device(&mut ctx));
        
        input.set(images);
        label.set(labels);
        
        p.exec(&mut ctx);
        
        optimizer.update(&grads, &mut ctx);
        
        println!(
            "epoch {:?} / batch {:?} / acc: {:?} / loss: {:?}",
            i,
            j,
            acc.eval(&mut ctx).to_host().scalar::<f32>(),
            loss.eval(&mut ctx).to_host().scalar::<f32>(),
        );
        
        // Release cached tensors before the next batch
        ctx.data.clear();
    }
}

Runtime memory management

Sage has several built-in tensor memory management strategies to support large-scale model training and memory-constrained computing environments. Please read our paper on memory-efficient on-device training for more details.
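
The simplest of these strategies is already visible in the training loop above. A minimal sketch of the pattern, assuming the same Context API as in the training example (batches here is a hypothetical stand-in for any iterator of device-resident inputs):

for (images, labels) in batches {
    input.set(images);
    label.set(labels);
    p.exec(&mut ctx);
    optimizer.update(&grads, &mut ctx);

    // Dropping cached intermediate tensors between iterations keeps
    // peak memory bounded regardless of how long training runs.
    ctx.data.clear();
}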

License

Sage is licensed under either the Apache License, Version 2.0 or the MIT License, at your option.
