Reproducing the LEARNED STEP SIZE QUANTIZATION Paper
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class Round(Function):
    """Rounding with a straight-through estimator (STE)."""
    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pass the incoming gradient through the rounding unchanged.
        return grad_output.clone()

def quant(x, scale):
    # Scale, clamp to the signed 8-bit range [-128, 127], then round.
    return Round.apply(torch.clamp(x / scale, -128, 127))

def dequant(x, scale):
    return x * scale
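As a quick sanity check, here is a minimal sketch (the tensors and values are illustrative, not from the original code) showing that the STE lets gradients flow through the fake-quantization round trip:

# Minimal sketch: fake-quantize a tensor and check that gradients flow through Round.
x = torch.randn(4, requires_grad=True)
scale = torch.tensor(0.1, requires_grad=True)
y = dequant(quant(x, scale), scale)
y.sum().backward()
print(x.grad)      # non-zero wherever x / scale landed inside [-128, 127]
print(scale.grad)  # the step size also receives a gradient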
# ********************* Quantized convolution (quantizes both A and W, then convolves) *********************
class Conv2d_Q(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        first_layer=0,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Learned step sizes; overwritten with the LSQ initialization on the first batch.
        self.weight_scale = nn.Parameter(torch.ones(1))
        self.activation_scale = nn.Parameter(torch.ones(1))
        self.first_batch = 0
        self.first_layer = first_layer

    def forward(self, input):
        if self.first_batch == 0:
            # LSQ initialization: s = 2 * mean(|v|) / sqrt(Qp), with Qp = 127 for 8 bits.
            # The weight scale is initialized from the weights (not the input), and both
            # parameters are updated in place so the optimizer keeps its references.
            with torch.no_grad():
                self.activation_scale.copy_(2 * input.abs().mean() / torch.sqrt(torch.tensor(127.0)))
                self.weight_scale.copy_(2 * self.weight.abs().mean() / torch.sqrt(torch.tensor(127.0)))
            self.first_batch = 1
        # Fake-quantize A and W (quantize, then dequantize)
        if not self.first_layer:
            input = dequant(quant(input, self.activation_scale), self.activation_scale)
        q_input = input
        q_weight = dequant(quant(self.weight, self.weight_scale), self.weight_scale)
        # Quantized convolution
        output = F.conv2d(
            input=q_input,
            weight=q_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return output
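A small smoke test (hypothetical; the sizes match the first layer of the network below) confirms the layer runs and preserves spatial dimensions:

# Hypothetical smoke test: one forward pass through the quantized conv layer.
layer = Conv2d_Q(1, 8, kernel_size=3, stride=1, padding=1, first_layer=1)
out = layer(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 8, 28, 28])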
class QuanConv2d(nn.Module):
    def __init__(self, input_channels, output_channels,
                 kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, first_layer=0):
        super(QuanConv2d, self).__init__()
        self.last_relu = last_relu
        self.first_layer = first_layer
        self.q_conv = Conv2d_Q(input_channels, output_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
                               first_layer=first_layer)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # ReLU precedes the quantized conv so its input is non-negative;
        # the first layer receives the raw image instead.
        if not self.first_layer:
            x = self.relu(x)
        x = self.q_conv(x)
        x = self.bn(x)
        if self.last_relu:
            x = self.relu(x)
        return x
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # MNIST shapes: 1x28x28 -> 8x14x14 -> 16x14x14 -> 32x14x14 -> 32x7x7 -> 10x7x7 -> 10x1x1
        self.quan_model = nn.Sequential(
            QuanConv2d(1, 8, kernel_size=3, stride=1, padding=1, first_layer=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            QuanConv2d(8, 16, kernel_size=3, stride=1, padding=1),
            QuanConv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            QuanConv2d(32, 10, kernel_size=3, stride=1, padding=1, last_relu=1),
            nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.quan_model(x)
        x = x.view(x.size(0), -1)
        return x
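A minimal, illustrative shape check confirms the network maps an MNIST-sized batch to 10 logits per image:

# Illustrative shape check on a dummy MNIST-sized batch.
net = Net()
logits = net(torch.randn(2, 1, 28, 28))
print(logits.shape)  # torch.Size([2, 10])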
import math
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

device = torch.device('cuda:0')

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        #data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR: {}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(),
                optimizer.param_groups[0]['lr']))
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            #data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100. * float(correct) / len(test_loader.dataset)
    print('acc is {}'.format(acc))
if __name__ == '__main__':
    setup_seed(int(time.time()))
    print('==> Preparing data..')
    train_dataset = torchvision.datasets.MNIST(root='../../data', train=True,
                                               transform=transforms.ToTensor(), download=True)
    test_dataset = torchvision.datasets.MNIST(root='../../data', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)
    print('******Initializing model******')
    model = Net()
    #model.to(device)
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
    criterion = nn.CrossEntropyLoss()
    base_lr = 0.01
    weight_decay = 0.00005
    # LSQ gradient scaling g = 1 / sqrt(N * Qp), where N is the number of weights
    # for a weight scale or the number of activation features for an activation
    # scale, and Qp = 127. It is applied here via each scale's learning rate.
    scale_g = {
        'quan_model.0.q_conv.weight_scale':     1 / math.sqrt(127.0 * 8 * 1 * 3 * 3),
        'quan_model.0.q_conv.activation_scale': 1 / math.sqrt(127.0 * 128 * 1 * 28 * 28),
        'quan_model.2.q_conv.weight_scale':     1 / math.sqrt(127.0 * 16 * 8 * 3 * 3),
        'quan_model.2.q_conv.activation_scale': 1 / math.sqrt(127.0 * 8 * 14 * 14),
        'quan_model.3.q_conv.weight_scale':     1 / math.sqrt(127.0 * 32 * 16 * 3 * 3),
        'quan_model.3.q_conv.activation_scale': 1 / math.sqrt(127.0 * 16 * 14 * 14),
        'quan_model.5.q_conv.weight_scale':     1 / math.sqrt(127.0 * 10 * 32 * 3 * 3),
        'quan_model.5.q_conv.activation_scale': 1 / math.sqrt(127.0 * 32 * 7 * 7),
    }
    params = []
    for key, value in model.named_parameters():
        g = scale_g.get(key, 1.0)
        params.append({'params': [value], 'lr': base_lr * g, 'weight_decay': weight_decay})
    optimizer = optim.Adam(params, lr=base_lr, weight_decay=weight_decay)
    for epoch in range(1, 10):
        train(epoch)
        test()
    for name, p in model.named_parameters():
        print(name)
The scale is initialized following the LSQ paper: s = 2 * mean(|v|) / sqrt(Qp), where v is the tensor being quantized and Qp = 127 for 8-bit quantization.
Each scale's gradient is then multiplied by g = 1 / sqrt(N * Qp), with N the number of elements of the quantized tensor; this is implemented above by scaling the per-parameter learning rate, and it makes the model converge faster and reach higher accuracy.
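As a worked example (the numbers follow from the layer shapes above), the multiplier for the first conv layer's weight scale is:

# Worked example: g for the first conv layer's weight scale.
# N = 8 * 1 * 3 * 3 = 72 weights, Qp = 127.
import math
g = 1 / math.sqrt(127.0 * 72)
print(g)  # ~0.0105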