rrmina / eureka

Deep Learning Framework from scratch --- translating my Aha! moments into code --- :zap: :bulb: :high_brightness:

Python 100.00%
neural-network deep-learning deep-learning-framework mnist cifar

eureka's Introduction

Eureka: A simple Neural Network Framework written in NumPy ⚡ 💡 🔆

Clean Interface!!!

Loading Datasets in-house

import datasets.mnist

train_x, train_y = datasets.mnist.load_dataset(download=True, train=True)
test_x, test_y = datasets.mnist.load_dataset(download=True, train=False)

Dataloader and Minibatch Maker

from eureka.utils import dataloader

trainloader = dataloader(x, y, batch_size=64, shuffle=True)
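To illustrate what a minibatch maker does, here is a minimal NumPy sketch. The function name `minibatches` and its `seed` parameter are hypothetical and are not part of eureka; the sketch only assumes that a loader shuffles sample indices and yields `(inputs, labels)` slices of at most `batch_size` rows.

```python
import numpy as np

def minibatches(x, y, batch_size=64, shuffle=True, seed=None):
    """Yield (inputs, labels) minibatches. Hypothetical stand-in, not eureka's dataloader."""
    num_samples = x.shape[0]
    indices = np.arange(num_samples)
    if shuffle:
        # Shuffle the sample order so batches differ from pass to pass
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        batch_idx = indices[start:start + batch_size]
        # The last batch may hold fewer than batch_size samples
        yield x[batch_idx], y[batch_idx]

# Toy data: 150 samples with 4 features each, and integer labels
x = np.arange(600, dtype=np.float64).reshape(150, 4)
y = np.arange(150)
batches = list(minibatches(x, y, batch_size=64, shuffle=True, seed=0))
batch_sizes = [inputs.shape[0] for inputs, _ in batches]
print(batch_sizes)  # [64, 64, 22]
```

Unlike the `trainloader` object above, which is iterated once per epoch in the training example, a generator like this is exhausted after one pass, so it would need to be called again at the start of each epoch.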

Defining Model Architecture, Optimizer, and Criterion/Loss Function

import eureka.nn as nn
import eureka.optim as optim
import eureka.losses as losses

# MNIST dense network with one hidden layer of 256 neurons,
# a BatchNorm with learnable parameters after the activation,
# and a Dropout layer with 0.2 probability of dropping neurons
model = nn.Sequential([
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.BatchNorm1d(256, affine=True),
    nn.Dropout(0.2),
    nn.Linear(256, 10),
    nn.Softmax()
])

# Adam Optimizer
optimizer = optim.Adam(model, lr=0.0002)

# Define the criterion/loss function
criterion = losses.CrossEntropyLoss()
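The model above ends in a Softmax layer and is trained with a cross-entropy criterion. As a rough illustration of that pairing, the following NumPy sketch computes the softmax probabilities, the loss, and the gradient with respect to the logits. The function names are hypothetical, and summing the loss over the batch is an assumption; this is not eureka's implementation of `nn.Softmax` or `losses.CrossEntropyLoss`.

```python
import numpy as np

def softmax(logits):
    # Subtract the row-wise max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def cross_entropy(probs, onehot, eps=1e-12):
    # Negative log-probability of the correct class, summed over the batch (assumption)
    return -np.sum(onehot * np.log(probs + eps))

def cross_entropy_grad_wrt_logits(probs, onehot):
    # For softmax followed by cross-entropy, dL/dlogits simplifies to probs - onehot
    return probs - onehot

logits = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
onehot = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
probs = softmax(logits)
loss = cross_entropy(probs, onehot)
grad = cross_entropy_grad_wrt_logits(probs, onehot)
print(loss, grad.shape)
```

The `probs - onehot` shortcut is the reason many frameworks fuse the two operations; whether eureka does so is not stated in this README.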

Forward and Backpropagation

batch_loss = 0
for inputs, labels in trainloader:
    # Forward propagation and loss computation
    out = model.forward(inputs)
    batch_loss += criterion(out, labels)

    # Gradient of the loss w.r.t. the model output, then the model gradients
    dloss_over_dout = criterion.backward()
    model.backward(dloss_over_dout)

    # Update the parameters with an optimizer step
    optimizer.step()
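For readers unfamiliar with what an Adam `step()` typically does to a single parameter array, here is a minimal NumPy sketch of the standard Adam update rule. The function `adam_step` and its default hyperparameters (`beta1`, `beta2`, `eps`) are illustrative assumptions, not eureka's `optim.Adam` code.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.9, beta2=0.999, eps=1e-8):
    """One standard Adam update. Returns the new parameter and moment estimates."""
    m = beta1 * m + (1 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2    # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction, t starts at 1
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Toy problem: minimize f(w) = 0.5 * ||w||^2, whose gradient is w itself
w = np.array([1.0, -2.0])
m = np.zeros_like(w)
v = np.zeros_like(w)
first_w, _, _ = adam_step(w, w, m, v, t=1, lr=0.01)
for t in range(1, 2001):
    w, m, v = adam_step(w, w, m, v, t, lr=0.01)
print(first_w, w)
```

On the very first step the bias-corrected ratio is close to the sign of the gradient, so each coordinate moves by roughly `lr`, regardless of the gradient's magnitude.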

Example: MNIST Classification

import numpy as np
from eureka.utils import one_hot_encoder, dataloader
import eureka.losses as losses
import eureka.optim as optim
import eureka.nn as nn
import datasets.mnist

# Load dataset and Preprocess
train_x, train_y = datasets.mnist.load_dataset(download=True, train=True)
x = train_x.reshape(train_x.shape[0], -1)
y = one_hot_encoder(train_y)
num_samples = x.shape[0]

# Prepare the dataloader
trainloader = dataloader(x, y, batch_size=64, shuffle=True)

# Define model architecture, Optimizer, and Criterion/Loss Function
model = nn.Sequential([
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.BatchNorm1d(256, affine=False),
    nn.Dropout(0.2),
    nn.Linear(256, 10),
    nn.Softmax()
])
optimizer = optim.Adam(model, lr=0.0002)
criterion = losses.CrossEntropyLoss()

# Train loop
num_epochs = 20
for epoch in range(1, num_epochs+1):
    print("Epoch: {}/{}\n==========".format(epoch, num_epochs))
    acc = 0
    batch_loss = 0
    for inputs, labels in trainloader:
        # Forward Propagation and Compute loss
        out = model.forward(inputs)
        m = inputs.shape[0]
        batch_loss += criterion(out, labels)

        # Compute Accuracy
        pred = np.argmax(out, axis=1).reshape(-1, 1)
        acc += np.sum(pred == labels.argmax(axis=1).reshape(-1,1))
        
        # Gradient of the loss w.r.t. the model output, then the model gradients
        dloss_over_dout = criterion.backward()
        model.backward(dloss_over_dout)

        # Update the parameters with an optimizer step
        optimizer.step()
    
    # Print Loss and Accuracy
    print("Loss: {:.6f}".format(batch_loss/num_samples)) 
    print("Accuracy: {:.2f}%\n".format(acc/num_samples*100)) 
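The example above encodes the labels with `one_hot_encoder` and later recovers class indices with `argmax` to compute accuracy. A minimal sketch of that round trip, using a hypothetical `one_hot` helper rather than eureka's own function, could look like this:

```python
import numpy as np

def one_hot(labels, num_classes=None):
    """Hypothetical stand-in for a one-hot encoder. Not eureka's one_hot_encoder."""
    labels = np.asarray(labels).reshape(-1).astype(int)
    if num_classes is None:
        # Assume classes are numbered 0..max(labels)
        num_classes = int(labels.max()) + 1
    encoded = np.zeros((labels.shape[0], num_classes))
    # Set a single 1 per row, at the column given by the label
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

labels = np.array([3, 0, 2])
y = one_hot(labels, num_classes=4)
recovered = y.argmax(axis=1)
print(recovered)  # [3 0 2]
```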
