import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.autograd import Function
# ********************* range_trackers (range statistics: track the value range before quantization) *********************
class RangeTracker(nn.Module):
    """Base class that measures the min/max of a tensor and forwards them.

    ``forward`` computes the minimum and maximum of its input at the
    granularity selected by ``q_level`` and hands both to ``update_range``,
    which subclasses must implement.

    Args:
        q_level: ``'L'`` reduces over the whole tensor (one min and one max),
            ``'C'`` reduces over dimensions 3, 2 and 1 with ``keepdim=True``,
            which for a 4-D input of leading size N gives shape (N, 1, 1, 1).
    """

    def __init__(self, q_level):
        super().__init__()
        self.q_level = q_level

    def update_range(self, min_val, max_val):
        """Consume a freshly measured range; must be overridden."""
        raise NotImplementedError

    @torch.no_grad()
    def forward(self, input):
        """Measure the range of ``input`` and pass it to ``update_range``.

        Runs with gradient tracking disabled and returns ``None``.

        Raises:
            ValueError: if ``q_level`` is neither ``'L'`` nor ``'C'``.
        """
        if self.q_level == 'L':
            # Layer level: a single min/max over every element.
            min_val = torch.min(input)
            max_val = torch.max(input)
        elif self.q_level == 'C':
            # Channel level: collapse dims 3, 2, 1 in turn, keeping dim 0.
            min_val, max_val = input, input
            for dim in (3, 2, 1):
                min_val = torch.min(min_val, dim, keepdim=True)[0]
                max_val = torch.max(max_val, dim, keepdim=True)[0]
        else:
            # Without this guard an unknown level falls through and crashes
            # later with an unrelated "referenced before assignment" error.
            raise ValueError(
                f"unsupported q_level {self.q_level!r}; expected 'L' or 'C'"
            )
        self.update_range(min_val, max_val)
class GlobalRangeTracker(RangeTracker):
    """Range tracker that keeps the smallest min and largest max seen so far.

    The running range is stored in the ``min_val`` / ``max_val`` buffers, both
    of shape (out_channels, 1, 1, 1). The ``first_w`` buffer is 0 until the
    first update has been recorded.

    Args:
        q_level: granularity flag passed to ``RangeTracker``.
        out_channels: size of the leading dimension of the range buffers.
    """

    def __init__(self, q_level, out_channels):
        super().__init__(q_level)
        self.register_buffer('min_val', torch.zeros(out_channels, 1, 1, 1))
        self.register_buffer('max_val', torch.zeros(out_channels, 1, 1, 1))
        self.register_buffer('first_w', torch.zeros(1))

    def update_range(self, min_val, max_val):
        """Fold a newly measured range into the running global range.

        The first call stores the measured values as-is; every later call
        keeps the element-wise minimum of the minima and maximum of the
        maxima. Buffers are updated in place.
        """
        if self.first_w == 0:
            self.first_w.add_(1)
            self.min_val.copy_(min_val)
            self.max_val.copy_(max_val)
        else:
            # Evaluate the comparison against the stored range *before*
            # writing to the buffer. Zeroing the buffer first (through an
            # alias) would compare the new values against 0 and lose the
            # history.
            self.min_val.copy_(torch.min(self.min_val, min_val))
            self.max_val.copy_(torch.max(self.max_val, max_val))
class AveragedRangeTracker(RangeTracker):
    """Range tracker that keeps a running (momentum-weighted) min and max.

    The running range lives in the one-element ``min_val`` / ``max_val``
    buffers. The ``first_a`` buffer is 0 until the first update has been
    recorded.

    Args:
        q_level: granularity flag passed to ``RangeTracker``.
        momentum: weight given to each new measurement after the first one.
    """

    def __init__(self, q_level, momentum=0.1):
        super().__init__(q_level)
        self.momentum = momentum
        # Same registration order as before: min_val, max_val, first_a.
        for buffer_name in ('min_val', 'max_val', 'first_a'):
            self.register_buffer(buffer_name, torch.zeros(1))

    def update_range(self, min_val, max_val):
        """Blend a newly measured range into the running range, in place.

        The first call adds the measurement onto the zero-initialised
        buffers; later calls compute
        ``running * (1 - momentum) + new * momentum``.
        """
        pairs = ((self.min_val, min_val), (self.max_val, max_val))

        if self.first_a != 0:
            keep = 1 - self.momentum
            for running, measured in pairs:
                running.mul_(keep).add_(measured * self.momentum)
            return

        # First observation: mark as initialised, then seed the buffers.
        self.first_a.add_(1)
        for running, measured in pairs:
            running.add_(measured)
# ********************* quantizers (perform quantization) *********************
class Round(Function):
    """Rounding op whose backward pass lets the gradient through unchanged.

    ``forward`` applies ``torch.round``; ``backward`` returns a copy of the
    incoming gradient, so the rounding step does not block gradient flow.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` rounded element-wise with ``torch.round``."""
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return a clone of ``grad_output`` as the gradient of the input."""
        return grad_output.clone()
class Quantizer(nn.Module):
    """Base quantizer: quantize -> round -> clamp -> dequantize.

    Subclasses must implement ``update_params`` (expected to set ``scale``
    and ``zero_point``) and must provide the ``min_val`` / ``max_val``
    attributes that ``clamp`` reads; this base class does not define them.

    Args:
        bits: bit width. 32 bypasses quantization entirely; 1 is rejected.
        range_tracker: callable invoked on every input before
            ``update_params`` is called.
    """

    def __init__(self, bits, range_tracker):
        super().__init__()
        self.bits = bits
        self.range_tracker = range_tracker
        self.register_buffer('scale', None)  # quantization scale factor
        self.register_buffer('zero_point', None)  # quantization zero point

    def update_params(self):
        """Refresh the quantization parameters; must be overridden."""
        raise NotImplementedError

    def quantize(self, input):
        """Map ``input`` to the quantized domain: ``input * scale - zero_point``."""
        return input * self.scale - self.zero_point

    def round(self, input):
        """Round via ``Round`` so that gradients still flow through."""
        return Round.apply(input)

    def clamp(self, input):
        """Clamp ``input`` to ``[self.min_val, self.max_val]``."""
        return torch.clamp(input, self.min_val, self.max_val)

    def dequantize(self, input):
        """Inverse of ``quantize``: ``(input + zero_point) / scale``."""
        return (input + self.zero_point) / self.scale

    def forward(self, input):
        """Return ``input`` after a quantize/dequantize round trip.

        With ``bits == 32`` the input is returned untouched. Otherwise the
        range tracker and ``update_params`` run on every call before the
        tensor is quantized, rounded, clamped and dequantized.

        Raises:
            AssertionError: if ``bits == 1`` (binary quantization is not
                supported).
        """
        if self.bits == 32:
            return input
        if self.bits == 1:
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped when Python runs with -O.
            raise AssertionError('!Binary quantization is not supported !')

        self.range_tracker(input)
        self.update_params()
        output = self.quantize(input)  # quantize
        output = self.round(output)
        output = self.clamp(output)  # clamp to the representable range
        return self.dequantize(output)  # dequantize
class SignedQuantizer(Quantizer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register_buffer('min_val', torch.tensor(-(1
1658642721
查看更多评论