
Question about training · brevitas (3 comments, closed)

xilinx commented on May 23, 2024

Comments (3)

volcacius commented on May 23, 2024

Hello,
You have a typo there: it should be out = self.relu1(self.fc1(out)), not out = self.relu1(self.fc1(x)).
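For clarity, the first two lines of the forward pass should read:

    out = x.view(-1, 32*32*3)        # flatten the 32x32x3 image first
    out = self.relu1(self.fc1(out))  # feed the flattened out into fc1, not the raw x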


samsterckval commented on May 23, 2024

Ok, that was a stupid mistake, sorry for that.

But after that, training still isn't working.
I changed a few things, and now it seems to be training, but the loss is stuck at 2.303.
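For reference, 2.303 is essentially ln(10), the cross-entropy of a uniform prediction over CIFAR-10's ten classes, so the network is doing no better than random guessing:

    import math
    print(math.log(10))  # 2.302585..., chance-level cross-entropy for 10 classes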

I guess I am still making a stupid mistake, but can't seem to find it.

An example notebook for training a simple quantised network would be great!

Here's the notebook I have at this point:

Versions: PyTorch 1.4.0, torchvision 0.5.0.

#!/usr/bin/env python
# coding: utf-8

# In[29]:


import torch
import torchvision
import brevitas


# In[30]:


print(f"pytorch version : {torch.__version__}")
print(f"torchvision version : {torchvision.__version__}")


# In[17]:


from torch import nn
from torch.nn import Module
import torch.nn.functional as F
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
import torchvision
import torchvision.transforms as transforms


# In[18]:


class QuantNet(Module):
    def __init__(self):
        super(QuantNet, self).__init__()
        self.fc1   = qnn.QuantLinear(3072, 2048, bias=False, 
                                     weight_quant_type=QuantType.INT, 
                                     weight_bit_width=8)
        self.relu1 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        
        self.fc2   = qnn.QuantLinear(2048, 1024, bias=False, 
                                     weight_quant_type=QuantType.INT, 
                                     weight_bit_width=2)
        self.relu2 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        
        self.fc3   = qnn.QuantLinear(1024, 512, bias=False, 
                                     weight_quant_type=QuantType.INT, 
                                     weight_bit_width=2)
        self.relu3 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        
        self.fc4   = qnn.QuantLinear(512, 256, bias=False, 
                                     weight_quant_type=QuantType.INT, 
                                     weight_bit_width=2)
        self.relu4 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        
        self.fc5   = qnn.QuantLinear(256, 10, bias=False, 
                                     weight_quant_type=QuantType.INT, 
                                     weight_bit_width=8)

    def forward(self, x):
        out = x.view(-1, 32*32*3)
        out = self.relu1(self.fc1(out))
        out = self.relu2(self.fc2(out))
        out = self.relu3(self.fc3(out))
        out = self.relu4(self.fc4(out))
        out = self.fc5(out)
        return out


# In[19]:


class NotQuantNet(nn.Module):
    def __init__(self):
        super(NotQuantNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# In[20]:


import numpy as np
import PIL

def transding(pic):
    # Convert the PIL image straight to a uint8 HWC numpy array (no scaling to [0, 1], no normalization)
    return np.array(pic, dtype=np.uint8)


# In[21]:


transform = transforms.Compose([transforms.ToTensor()])  # note: defined here but not used; the datasets below pass transding instead

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transding)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transding)
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                         shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# In[22]:


import matplotlib.pyplot as plt
import numpy as np


# In[23]:


def imshow(img):
#     img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.imshow(npimg)
    plt.show()


# In[24]:


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)


# In[25]:


# show images
imshow(torchvision.utils.make_grid(images, nrow=len(images)))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(1)))


# In[26]:


net = QuantNet()
# net = NotQuantNet()


# In[27]:


import torch.optim as optim
from torch import nn

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)


# In[28]:


for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        inputs = inputs.float()
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = net(inputs)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 10 == 9:    # print every 10 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0

print('Finished Training')


volcacius commented on May 23, 2024

I appreciate the suggestion about examples; it's something I'm working on for some scenarios. In general, though, it comes down to the choice of hyperparameters, and there isn't a single recipe for that. What Brevitas does is provide you with a set of building blocks.
I would first try to get a floating-point version of the topology you are quantizing to train properly, and then start from there. For 2-bit weights, you might then try lowering the learning rate a bit.

Alessandro
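A minimal sketch of that advice, with illustrative (not tuned) values: first train an unquantized copy of the same MLP, then reuse the setup for QuantNet with a lower learning rate.

    from torch import nn
    import torch.optim as optim

    # Float baseline with the same layer sizes as QuantNet, using plain nn.Linear / nn.ReLU
    class FloatNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(3072, 2048), nn.ReLU(),
                nn.Linear(2048, 1024), nn.ReLU(),
                nn.Linear(1024, 512), nn.ReLU(),
                nn.Linear(512, 256), nn.ReLU(),
                nn.Linear(256, 10),
            )

        def forward(self, x):
            return self.layers(x.view(-1, 32 * 32 * 3))  # flatten the 32x32x3 image

    # Once the float baseline converges, switch back to QuantNet and try a smaller learning rate,
    # e.g. (illustrative value, not a recipe from the thread):
    # optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)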
