Neural Network From Scratch In Julia

Without using any deep learning frameworks, we construct and train a neural network architecture in Julia from the ground up.

Architecture Design

The neural network (NN) is built from scratch and trained on some data. Here is a possible representation of the NN architecture:

model = [ # MLP
    Layer(num_features, num_neurons_1, relu; distribution='n'),
    Layer(num_neurons_1, num_neurons_2, relu; distribution='n'),
    Layer(num_neurons_2, num_targets, softmax; distribution='n')
]

where num_features, num_targets, num_neurons_1, and num_neurons_2 denote, respectively, the number of input features, the number of output targets, and the numbers of neurons in the two hidden layers. The weights can be initialized from either a normal distribution ('n') or a uniform distribution ('u'); both Xavier and He initialization methods are implemented.
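
The actual Layer type lives in the src folder; purely for orientation, a minimal sketch of such a layer, assuming He initialization for the normal case and Xavier for the uniform case, might read:

# Hypothetical sketch; the struct in src/ may differ in field names
# and in how the initialization scheme is selected.
relu(z) = max.(0, z)                        # elementwise ReLU

function softmax(z)
    e = exp.(z .- maximum(z, dims=1))       # subtract the max for numerical stability
    e ./ sum(e, dims=1)                     # normalize each column
end

struct Layer
    W::Matrix{Float64}   # weights, size (n_out, n_in)
    b::Vector{Float64}   # biases, length n_out
    act::Function        # activation function, e.g. relu or softmax
end

function Layer(n_in::Int, n_out::Int, act::Function; distribution::Char='n')
    if distribution == 'n'                  # He initialization, normal distribution
        W = randn(n_out, n_in) .* sqrt(2 / n_in)
    else                                    # Xavier initialization, uniform distribution
        limit = sqrt(6 / (n_in + n_out))
        W = (2 .* rand(n_out, n_in) .- 1) .* limit
    end
    Layer(W, zeros(n_out), act)
end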

Some of the hyperparameters are configured as follows:

Settings(epochs, batch_size)
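
For instance, with illustrative values rather than the repository's defaults:

settings = Settings(100, 32)   # train for 100 epochs with mini-batches of 32 samples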

Model Training

We can define a struct for the regularization as follows:

reg = Regularization(method, λ, r, dropout)

method can be a symbol or a string naming one of the following: l1, l2, elasticnet, or none. λ is the regularization strength. r sets the mix between the L1 and L2 penalties when the method is elasticnet. dropout is the dropout rate. The loss and optimizer are configured through:

Solver(loss, optimizer, learning_rate, reg)

loss can be :mae, :mse, :rmse, :binarycrossentropy, or :crossentropy. The default optimizer is :sgd. The model is trained using the following method:

TrainNN(model, data_in, data_out, x_val, y_val; solver)
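
For example, with illustrative hyperparameter values and hypothetical data arrays x_train, y_train, x_val, and y_val (none of these names come from the repository):

reg    = Regularization(:l2, 1e-4, 0.5, 0.2)     # L2 penalty, λ = 1e-4, r unused for :l2, 20% dropout
solver = Solver(:crossentropy, :sgd, 1e-2, reg)  # cross-entropy loss, plain SGD, learning rate 0.01
TrainNN(model, x_train, y_train, x_val, y_val; solver=solver)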

Under the hood, the TrainNN method calls the FeedForward and BackProp functions. FeedForward returns the pre-activations z and the post-activations a, bundled into data_cache. Its signature is:

data_cache = FeedForward(model, data_in; solver::Solver)
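
This is not the repository's actual code, but a bare-bones forward pass, reusing the hypothetical Layer sketch above, could look like:

# Illustrative forward pass: caches every pre-activation z and
# post-activation a, mirroring what data_cache bundles together.
function feedforward(model, x)
    zs, as = Any[], Any[x]                   # as[1] holds the network input
    for layer in model
        z = layer.W * last(as) .+ layer.b    # pre-activation of this layer
        push!(zs, z)
        push!(as, layer.act(z))              # post-activation of this layer
    end
    return (z = zs, a = as)                  # analogous to data_cache
end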

The BackProp method returns the loss and the gradients of the weights and biases, ∇W and ∇b, as follows:

loss, ∇W, ∇b = BackProp(model, data_cache, data_out; solver::Solver)
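
Again as a sketch rather than the repository's implementation: for a softmax output trained with cross-entropy, the output-layer error reduces to a − y, and the error then propagates backwards layer by layer. Regularization and dropout are omitted here, and all names reuse the hypothetical sketches above:

# Illustrative backward pass for a ReLU network with a softmax /
# cross-entropy output; the real BackProp also applies the chosen
# regularization method and supports the other losses.
function backprop(model, cache, y)
    L = length(model)
    ∇W, ∇b = Vector{Any}(undef, L), Vector{Any}(undef, L)
    δ = cache.a[end] .- y                        # output error (softmax + cross-entropy)
    for l in L:-1:1
        ∇W[l] = δ * cache.a[l]'                  # ∂loss/∂W at layer l
        ∇b[l] = vec(sum(δ, dims=2))              # ∂loss/∂b at layer l
        l > 1 && (δ = (model[l].W' * δ) .* (cache.z[l-1] .> 0))  # ReLU derivative
    end
    loss = -sum(y .* log.(cache.a[end] .+ eps())) # batch cross-entropy
    return loss, ∇W, ∇b
end

With plain SGD, the returned gradients would then be applied in place, e.g. model[l].W .-= η .* ∇W[l] for each layer l.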

Detailed steps of the backpropagation algorithm are shown in the images below (source).

[Image: detailed steps of the backpropagation algorithm]

The code is written in Julia. The main.jl file contains the primary code to set up the simulation. The required modules are located in the src folder.

Simulation Outcomes

The Julia version and the status of the packages used are shown below:

[Image: Julia version info and package status]

The figure below displays the model's loss for both the training and test sets at the end of each epoch.

[Image: training and test loss per epoch]

The following provides specifics about the confusion matrix, accuracy, precision, recall, and F1-score metrics.

[Image: confusion matrix and classification metrics]

Note: The code is not optimized for performance. It is written for educational purposes. There is always room for improvement.

TODO: Implement the following features:

  • parallelization of backprop on the batch of data instead of using a for loop;
  • optimizers: SGD+Momentum, SGD+Nesterov, Adam, RMSprop, Adagrad, Adadelta (a sketch of the momentum update follows this list).
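
As an illustration of the first optimizer on that list, a minimal SGD+Momentum update (an assumption about a feature the repository does not yet implement) could look like:

# Hypothetical momentum update: v ← βv − η∇W, then W ← W + v.
function sgd_momentum!(W, v, ∇W; η=1e-2, β=0.9)
    @. v = β * v - η * ∇W    # accumulate the velocity
    @. W = W + v             # apply the update in place
end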
