
INT8 Post-Training Quantization Based on KL Divergence

Published: 2020-11-08 12:58:32

We know that the smaller the relative entropy between two distributions P and Q, the closer Q is to P, and the less information is lost when Q is used to approximate P. NVIDIA's INT8 quantization is built on exactly this principle. The figure referenced here showed the pseudocode of NVIDIA's INT8 quantization algorithm.

[Figure: pseudocode of NVIDIA's INT8 calibration algorithm]
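For reference, the relative entropy (KL divergence) of a candidate distribution Q from a reference distribution P, summed over histogram bins, is defined as

D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}

It equals zero exactly when P = Q and grows as Q places too little probability where P has mass, which is why it serves as the information-loss measure in the threshold search below.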

Below is code that uses relative entropy to select the optimal clipping threshold.

import copy

import numpy as np


def compute_kl_divergence(P, Q):
    # KL divergence D(P || Q) over histogram bins; bins where Q is empty
    # but P is not receive a fixed penalty instead of an infinite term.
    kl = 0.0
    for i in range(len(P)):
        if P[i] != 0:
            if Q[i] == 0:
                kl += 1
            else:
                kl += P[i] * np.log(P[i] / Q[i])
    return kl


def threshold_distribution(distribution, target_bin):
    # Search every candidate threshold and keep the one whose quantized
    # distribution is closest (in KL divergence) to the reference.
    target_threshold = target_bin
    min_kl_divergence = float('inf')
    length = len(distribution)

    for threshold in range(target_bin, length):
        # Reference distribution P: the first `threshold` bins, with the
        # clipped tail folded into the last bin.
        t_distribution = copy.deepcopy(distribution[0:threshold])
        t_distribution[threshold - 1] += np.sum(distribution[threshold:])

        num_per_bin = threshold / target_bin

        # Quantize the first `threshold` bins down to `target_bin` bins,
        # splitting fractional boundary bins proportionally.
        quantize_distribution = np.zeros((target_bin,))
        for i in range(target_bin):
            start = i * num_per_bin
            end = start + num_per_bin

            left_upper = int(np.ceil(start))
            if left_upper > start:
                left_scale = left_upper - start
                quantize_distribution[i] += left_scale * distribution[left_upper - 1]
            right_lower = int(np.floor(end))
            if right_lower < end:
                right_scale = end - right_lower
                quantize_distribution[i] += right_scale * distribution[right_lower]
            for j in range(left_upper, right_lower):
                quantize_distribution[i] += distribution[j]

        # Candidate distribution Q: expand the quantized bins back to
        # `threshold` bins, spreading each bin's mass evenly over its
        # non-empty source bins.
        expand_distribution = np.zeros_like(t_distribution)
        for i in range(target_bin):
            start = i * num_per_bin
            end = start + num_per_bin

            count = 0
            left_upper = int(np.ceil(start))
            left_scale = 0
            if left_upper > start:
                left_scale = left_upper - start
                if t_distribution[left_upper - 1] != 0:
                    count += left_scale

            right_lower = int(np.floor(end))
            right_scale = 0
            if right_lower < end:
                right_scale = end - right_lower
                if t_distribution[right_lower] != 0:
                    count += right_scale
            for j in range(left_upper, right_lower):
                if t_distribution[j] != 0:
                    count += 1
            if count == 0:
                # All source bins are empty; nothing to spread. (Guard added
                # to avoid a division by zero on sparse histograms.)
                continue
            expand_value = quantize_distribution[i] / count
            if left_upper > start and t_distribution[left_upper - 1] != 0:
                expand_distribution[left_upper - 1] += expand_value * left_scale
            if right_lower < end and t_distribution[right_lower] != 0:
                expand_distribution[right_lower] += expand_value * right_scale
            for j in range(left_upper, right_lower):
                if t_distribution[j] != 0:
                    expand_distribution[j] += expand_value

        kl_divergence = compute_kl_divergence(t_distribution, expand_distribution)
        if kl_divergence < min_kl_divergence:
            min_kl_divergence = kl_divergence
            target_threshold = threshold
    return target_threshold


if __name__ == '__main__':
    # Toy example: a linearly increasing 2048-bin histogram, normalized.
    distribution = np.arange(2048, dtype=np.float64)
    distribution /= np.sum(distribution)
    target_threshold = threshold_distribution(distribution, 128)
    print(target_threshold)
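The function above returns the optimal threshold only as a bin index; to actually quantize a tensor, that index still has to be turned into an FP32 clipping value and an INT8 scale factor. Below is a minimal sketch of that last step, assuming the usual conventions from NVIDIA's calibration presentation: activations are histogrammed by absolute value over [0, max_abs], the threshold value is taken at the center of the chosen bin, and values are mapped into [-127, 127]. The helpers compute_int8_scale and quantize_to_int8 are illustrative names, not part of the original post; compute_int8_scale relies on the threshold_distribution function defined above.

import numpy as np

def compute_int8_scale(activations, num_bins=2048, target_bin=128):
    # Histogram of absolute activation values over [0, max_abs]
    # (assumes max_abs > 0).
    max_abs = float(np.max(np.abs(activations)))
    hist, _ = np.histogram(np.abs(activations), bins=num_bins, range=(0.0, max_abs))
    distribution = hist.astype(np.float64)
    distribution /= np.sum(distribution)

    # KL-based search for the clipping threshold, returned as a bin index.
    m = threshold_distribution(distribution, target_bin)

    # Convert the bin index to an FP32 threshold (bin center) and a scale.
    bin_width = max_abs / num_bins
    threshold_value = (m + 0.5) * bin_width
    scale = threshold_value / 127.0
    return scale, threshold_value

def quantize_to_int8(x, scale):
    # Saturating quantization: values beyond the threshold are clipped.
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8)

With the scale in hand, a tensor x is quantized as quantize_to_int8(x, scale) and approximately recovered as q.astype(np.float32) * scale.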