Comments (3)
Hello,
You have a typo there: it should be `out = self.relu1(self.fc1(out))`, not `out = self.relu1(self.fc1(x))`.
from brevitas.
OK, that was a stupid mistake, sorry about that.
But even after fixing it, training still isn't working.
I changed a few things, and now it seems to be training, but the loss is always 2.303.
I guess I am still making a stupid mistake, but I can't seem to find it.
An example notebook for training a simple quantised network would be great!
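For context, a loss stuck at 2.303 is a recognizable value: it is ln(10), the cross-entropy of a uniform prediction over CIFAR-10's 10 classes. That suggests the network is producing (nearly) identical logits for every class regardless of the input. Because `fc5` is built with `bias=False`, an all-zero output from `relu4` (for example, if every ReLU has gone dead) would give exactly this loss. The plain-Python sketch below computes the per-sample softmax cross-entropy that `nn.CrossEntropyLoss` is based on (log-sum-exp of the logits minus the target logit); the `cross_entropy` helper is hypothetical and only there to make the arithmetic explicit.

```python
import math


def cross_entropy(logits, target):
    """Hypothetical helper: per-sample softmax cross-entropy,
    i.e. logsumexp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Identical logits for all 10 classes (e.g. all zeros from a bias-free last
# layer fed with zeros) give a uniform softmax, so the loss is ln(10).
uniform_logits = [0.0] * 10
loss = cross_entropy(uniform_logits, target=3)
print(f"{loss:.3f}")  # 2.303
```

Any constant logit vector gives the same value, so the loss alone cannot tell dead activations apart from a network that has simply collapsed to the class prior.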
Here's the notebook I have at this point.
Versions: PyTorch 1.4.0, torchvision 0.5.0
```python
#!/usr/bin/env python
# coding: utf-8

# In[29]:

import torch
import torchvision
import brevitas

# In[30]:

print(f"pytorch version : {torch.__version__}")
print(f"torchvision version : {torchvision.__version__}")

# In[17]:

from torch import nn
from torch.nn import Module
import torch.nn.functional as F
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
import torchvision
import torchvision.transforms as transforms

# In[18]:

class QuantNet(Module):
    def __init__(self):
        super(QuantNet, self).__init__()
        self.fc1 = qnn.QuantLinear(3072, 2048, bias=False,
                                   weight_quant_type=QuantType.INT,
                                   weight_bit_width=8)
        self.relu1 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        self.fc2 = qnn.QuantLinear(2048, 1024, bias=False,
                                   weight_quant_type=QuantType.INT,
                                   weight_bit_width=2)
        self.relu2 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        self.fc3 = qnn.QuantLinear(1024, 512, bias=False,
                                   weight_quant_type=QuantType.INT,
                                   weight_bit_width=2)
        self.relu3 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        self.fc4 = qnn.QuantLinear(512, 256, bias=False,
                                   weight_quant_type=QuantType.INT,
                                   weight_bit_width=2)
        self.relu4 = qnn.QuantReLU(quant_type=QuantType.INT, bit_width=4, max_val=6)
        self.fc5 = qnn.QuantLinear(256, 10, bias=False,
                                   weight_quant_type=QuantType.INT,
                                   weight_bit_width=8)

    def forward(self, x):
        out = x.view(-1, 32*32*3)
        out = self.relu1(self.fc1(out))
        out = self.relu2(self.fc2(out))
        out = self.relu3(self.fc3(out))
        out = self.relu4(self.fc4(out))
        out = self.fc5(out)
        return out

# In[19]:

class NotQuantNet(nn.Module):
    def __init__(self):
        super(NotQuantNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# In[20]:

import numpy as np
import PIL

def transding(pic):
    return np.array(pic, dtype=np.uint8)

# In[21]:

transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transding)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=8)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transding)
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                         shuffle=False, num_workers=8)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# In[22]:

import matplotlib.pyplot as plt
import numpy as np

# In[23]:

def imshow(img):
    # img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.imshow(npimg)
    plt.show()

# In[24]:

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)  # builtin next() also works on recent PyTorch

# In[25]:

# show images
imshow(torchvision.utils.make_grid(images, nrow=len(images)))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(1)))

# In[26]:

net = QuantNet()
# net = NotQuantNet()

# In[27]:

import torch.optim as optim
from torch import nn

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# In[28]:

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.float()
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 10 == 9:  # print every 10 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0

print('Finished Training')
```
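One input-side detail in this notebook is worth spelling out. `transding` hands the network raw `uint8` pixels in HWC layout with values from 0 to 255 (cast to float inside the loop), whereas `transforms.ToTensor()` produces float CHW tensors scaled to [0, 1]. The PyTorch CIFAR-10 tutorial, which `NotQuantNet` appears to match, also applies `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`. Inputs in the hundreds combined with `lr=0.1` and momentum 0.9 could plausibly drive the ReLUs into a dead state; that is an assumption, not a confirmed diagnosis. The plain-Python sketch below shows what that transform pair computes per pixel. The helper name `to_tensor_and_normalize` is hypothetical.

```python
def to_tensor_and_normalize(pixels, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Hypothetical plain-Python equivalent of ToTensor() followed by
    Normalize(mean, std) for an RGB image.

    pixels: H x W nested list of (r, g, b) uint8 tuples (HWC layout, 0..255),
            the same layout that np.array(pic, dtype=np.uint8) yields.
    returns: 3 x H x W nested list of floats (CHW layout).
    """
    height, width = len(pixels), len(pixels[0])
    return [
        [
            # ToTensor scales 0..255 to 0..1; Normalize applies (x - mean) / std
            [(pixels[y][x][c] / 255.0 - mean[c]) / std[c] for x in range(width)]
            for y in range(height)
        ]
        for c in range(3)  # channel axis moves to the front (HWC -> CHW)
    ]


# A 1 x 2 image: one black pixel and one white pixel.
image = [[(0, 0, 0), (255, 255, 255)]]
chw = to_tensor_and_normalize(image)
print(chw[0])  # [[-1.0, 1.0]]
```

In the notebook this corresponds to passing `transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])` to both CIFAR-10 datasets instead of `transding`. The `inputs.float()` cast then becomes a no-op. `NotQuantNet` needs the CHW layout regardless, because its first `nn.Conv2d(3, 6, 5)` expects the 3 channels on dimension 1.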
I appreciate the suggestion about examples; it's something I'm working on for some scenarios. In general, though, it boils down to the choice of hyperparameters, and there isn't a single recipe for that. Brevitas simply provides you with a set of building blocks.
I would first try to get a floating-point version of the topology you are quantizing to train, and then start from there. For the 2-bit weights, you might also try lowering the learning rate a bit.
Alessandro
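To illustrate the 2-bit remark: a signed 2-bit integer has only four representable values, so once a scale is fixed each weight can take just four levels. The sketch below uses a generic signed uniform quantizer and a single SGD step on a float "shadow" weight. During quantization-aware training, a straight-through estimator passes the gradient of the quantized weight to this shadow weight unchanged. The `quantize_weight` helper, the fixed scale of 0.1, and the round-and-clamp choices are assumptions for illustration, not Brevitas's actual quantizer.

```python
def quantize_weight(w, scale, bit_width):
    """Hypothetical generic signed uniform quantizer (not Brevitas's exact
    implementation): round w / scale to the nearest integer, clamp it to the
    signed range of `bit_width` bits, then scale back to a float."""
    q_min = -(2 ** (bit_width - 1))   # -2 for 2 bits
    q_max = 2 ** (bit_width - 1) - 1  #  1 for 2 bits
    q = max(q_min, min(q_max, round(w / scale)))
    return q * scale


scale = 0.1  # assumed fixed per-tensor scale, chosen only for illustration

# Sweep float weights in [-0.5, 0.5]: only four distinct quantized values remain.
levels = sorted({quantize_weight(i / 100, scale, 2) for i in range(-50, 51)})
print("2-bit levels:", levels)

# One SGD step on the float shadow weight; the forward pass would see the
# re-quantized value.
w, grad = 0.04, 1.0
for lr in (0.1, 0.001):
    updated = w - lr * grad
    print(f"lr={lr}: quantized {quantize_weight(w, scale, 2)} -> "
          f"{quantize_weight(updated, scale, 2)}")
```

With `lr=0.1`, a single unit-gradient step moves the shadow weight from 0.04 to -0.06 and flips its quantized value from 0.0 to -0.1. With `lr=0.001`, it stays at 0.0. Add momentum 0.9 on top, and large steps can keep low-bit weights hopping between levels. This is one intuition behind lowering the learning rate for 2-bit weights.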