上一篇博客介绍了如何进行量化感知训练,并进行uint8前向推理,但是没有将BN层进行融合,这使得模型推理时的计算复杂度仍然有改进的空间,本篇博客讲述了如何进行BN融合的量化感知训练,并在训练完成后进行了uint8推理的模拟,以方便今后在FPGA上的部署。 代码:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.autograd import Function
# ********************* range_trackers (track the value range observed before quantization) *********************
class RangeTracker(nn.Module):
    """Base class that measures the min/max of a tensor before quantization.

    Subclasses implement :meth:`update_range` to decide how each freshly
    measured range is merged into the statistics they store.

    Args:
        q_level: Granularity of the measurement. ``'L'`` reduces over the
            whole tensor and yields a single min/max pair. ``'C'`` reduces
            over dims 3, 2 and 1 (with ``keepdim=True``) and keeps dim 0, so
            a 4-D input of shape ``(N, C, H, W)`` yields statistics of shape
            ``(N, 1, 1, 1)``.
    """

    def __init__(self, q_level):
        super().__init__()
        self.q_level = q_level

    def update_range(self, min_val, max_val):
        """Merge a new ``(min_val, max_val)`` observation; must be overridden."""
        raise NotImplementedError

    @torch.no_grad()
    def forward(self, input):
        """Measure the range of ``input`` and pass it to :meth:`update_range`.

        Runs under ``torch.no_grad()``; returns ``None``.

        Raises:
            ValueError: If ``q_level`` is neither ``'L'`` nor ``'C'``.
        """
        if self.q_level == 'L':
            # Layer level: one scalar min and max for the whole tensor.
            min_val = torch.min(input)
            max_val = torch.max(input)
        elif self.q_level == 'C':
            # Channel level: reduce the three trailing dims one at a time,
            # keeping them as size-1 dims so the result broadcasts against
            # the input.
            min_val = input
            max_val = input
            for dim in (3, 2, 1):
                min_val = torch.min(min_val, dim, keepdim=True)[0]
                max_val = torch.max(max_val, dim, keepdim=True)[0]
        else:
            # Without this branch min_val/max_val would be unbound and the
            # call below would fail with an opaque UnboundLocalError.
            raise ValueError(
                "q_level must be 'L' (layer) or 'C' (channel), got %r"
                % (self.q_level,)
            )
        self.update_range(min_val, max_val)
class GlobalRangeTracker(RangeTracker):
    """Range tracker that keeps the extreme min/max seen over all updates.

    The ``min_val`` and ``max_val`` buffers have shape
    ``(out_channels, 1, 1, 1)``. The ``first_w`` buffer is 0 until the first
    observation has been stored.

    Args:
        q_level: Passed through to ``RangeTracker``.
        out_channels: Size of the leading dimension of the range buffers.
    """

    def __init__(self, q_level, out_channels):
        super().__init__(q_level)
        self.register_buffer('min_val', torch.zeros(out_channels, 1, 1, 1))
        self.register_buffer('max_val', torch.zeros(out_channels, 1, 1, 1))
        self.register_buffer('first_w', torch.zeros(1))

    def update_range(self, min_val, max_val):
        """Fold a new observation into the stored global range, in place.

        The first call stores the observation as-is. Later calls keep the
        elementwise minimum of the stored and new ``min_val`` and the
        elementwise maximum of the stored and new ``max_val``.
        """
        if self.first_w == 0:
            # Buffers start at zero, so adding the observation stores it.
            self.first_w.add_(1)
            self.min_val.add_(min_val)
            self.max_val.add_(max_val)
        else:
            # Compute the merged range BEFORE touching the buffers. Zeroing
            # a buffer through an alias and then comparing against that same
            # alias would compare against 0 instead of the previous value.
            self.min_val.copy_(torch.min(self.min_val, min_val))
            self.max_val.copy_(torch.max(self.max_val, max_val))
class AveragedRangeTracker(RangeTracker):
    """Range tracker that keeps an exponential moving average of min/max.

    The ``min_val`` and ``max_val`` buffers have shape ``(1,)``. The
    ``first_a`` buffer is 0 until the first observation has been stored.

    Args:
        q_level: Passed through to ``RangeTracker``.
        momentum: Weight given to each new observation after the first one.
    """

    def __init__(self, q_level, momentum=0.1):
        super().__init__(q_level)
        self.momentum = momentum
        for buffer_name in ('min_val', 'max_val', 'first_a'):
            self.register_buffer(buffer_name, torch.zeros(1))

    def update_range(self, min_val, max_val):
        """Blend a new observation into the running range, in place.

        The first call stores the observation as-is. Later calls compute
        ``running * (1 - momentum) + observed * momentum``.
        """
        pairs = ((self.min_val, min_val), (self.max_val, max_val))
        if self.first_a == 0:
            # Buffers start at zero, so adding the observation stores it.
            self.first_a.add_(1)
            for running, observed in pairs:
                running.add_(observed)
            return
        for running, observed in pairs:
            running.mul_(1 - self.momentum).add_(observed * self.momentum)
# ********************* quantizers (perform the quantization) *********************
class Round(Function):
    """Rounding op whose backward pass is the identity.

    ``torch.round`` is applied in the forward pass. The backward pass hands
    the incoming gradient through unchanged (a straight-through estimator),
    so rounding does not block gradient flow.
    """

    @staticmethod
    def forward(ctx, input):
        # Nothing is saved on ctx; backward does not need the input.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Return a copy so the caller's gradient tensor is not aliased.
        return grad_output.clone()
class Quantizer(nn.Module):
    """Base fake-quantizer: quantize, round, clamp, then dequantize.

    Subclasses implement :meth:`update_params` to fill in ``scale`` and
    ``zero_point``. :meth:`clamp` reads ``self.min_val`` and ``self.max_val``,
    which this class does not define.
    NOTE(review): presumably subclasses register them as the integer bounds;
    confirm against the concrete quantizers.

    Args:
        bits: Bit width. 32 disables quantization; 1 is rejected.
        range_tracker: Callable module invoked on the input before
            ``update_params`` on every quantized forward pass.
    """

    def __init__(self, bits, range_tracker):
        super().__init__()
        self.bits = bits
        self.range_tracker = range_tracker
        self.register_buffer('scale', None)  # quantization scale factor
        self.register_buffer('zero_point', None)  # quantization zero point

    def update_params(self):
        """Recompute ``scale`` and ``zero_point``; must be overridden."""
        raise NotImplementedError

    def quantize(self, input):
        """Map to the quantized domain: ``input * scale - zero_point``."""
        return input * self.scale - self.zero_point

    def round(self, input):
        """Round via ``Round.apply``."""
        return Round.apply(input)

    def clamp(self, input):
        """Clamp ``input`` to ``[self.min_val, self.max_val]``."""
        return torch.clamp(input, self.min_val, self.max_val)

    def dequantize(self, input):
        """Inverse of :meth:`quantize`: ``(input + zero_point) / scale``."""
        return (input + self.zero_point) / self.scale

    def forward(self, input):
        """Return ``input`` after a quantize/dequantize round trip.

        With ``bits == 32`` the input is returned untouched.

        Raises:
            AssertionError: If ``bits == 1`` (binary quantization).
        """
        if self.bits == 32:
            return input
        if self.bits == 1:
            print('!Binary quantization is not supported !')
            # Raise explicitly: an ``assert`` statement is removed under
            # ``python -O``, which would let execution continue with no
            # output defined.
            raise AssertionError('binary (1-bit) quantization is not supported')
        # Refresh the observed range, then derive scale/zero_point from it.
        self.range_tracker(input)
        self.update_params()
        output = self.quantize(input)
        output = self.round(output)
        output = self.clamp(output)
        return self.dequantize(output)
class SignedQuantizer(Quantizer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register_buffer('min_val', torch.tensor(-(1
关注
打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?