
Training the MNIST dataset with a BNN

FPGA硅农 · Published 2020-10-13 16:19:31

Train a BNN on the MNIST dataset; after training, extract the model parameters and replay the inference pass by hand. In this first version the binarized weights W are NOT multiplied by the scaling factor. As the four print statements show, BWc1, BWc2, BWc3 and BWc4 are the binarized weight matrices: the activations are passed through the sign function, convolved with the binarized W, and the floating-point bias is then added to give the output of each binarized convolution. A small standalone sketch of that step appears below, followed by the full code:
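Before the full script, here is a minimal standalone sketch (toy tensor shapes, not taken from the model below) of the binarized-convolution step that the manual inference loop repeats:

import torch
import torch.nn.functional as F

# toy shapes chosen only for illustration
x = torch.randn(1, 16, 14, 14)   # pre-activation feature map
w = torch.randn(32, 16, 5, 5)    # floating-point conv weights
b = torch.randn(32)              # floating-point bias (kept in float)

x_bin = torch.sign(x)            # binarize activations to +1/-1
w_bin = torch.sign(w)            # binarize weights to +1/-1
y = F.conv2d(x_bin, w_bin, bias=b, stride=1, padding=2)  # binary conv + float bias
print(y.shape)                   # torch.Size([1, 32, 14, 14])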

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Function

# ********************* Binarization (+1/-1) ***********************
# A: activation binarization
class Binary_a(Function):

    @staticmethod
    def forward(self, input):
        self.save_for_backward(input)
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        input, = self.saved_tensors
        #*******************ste*********************
        grad_input = grad_output.clone()
        #****************saturate_ste***************
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input
# W: weight binarization
class Binary_w(Function):

    @staticmethod
    def forward(self, input):
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input
# ********************* Ternarization (+1/-1/0) ***********************
class Ternary(Function):

    @staticmethod
    def forward(self, input):
        # **************** per-channel E(|W|) ****************
        E = torch.mean(torch.abs(input), (3, 2, 1), keepdim=True)
        # **************** threshold ****************
        threshold = E * 0.7
        # ************** W —— +1, -1, 0 **************
        output = torch.sign(torch.add(torch.sign(torch.add(input, threshold)),torch.sign(torch.add(input, -threshold))))
        return output, threshold

    @staticmethod
    def backward(self, grad_output, grad_threshold):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input

# ********************* A (feature/activation) quantization (binary) ***********************
class activation_bin(nn.Module):
  def __init__(self, A):
    super().__init__()
    self.A = A
    self.relu = nn.ReLU(inplace=True)

  def binary(self, input):
    output = Binary_a.apply(input)
    return output

  def forward(self, input):
    if self.A == 2:
      output = self.binary(input)
      # ******************** A —— 1, 0 (optional clamp, left disabled) *********************
      #a = torch.clamp(a, min=0)
    else:
      output = self.relu(input)
    return output
# ********************* W (weight) quantization (ternary/binary) ***********************
def meancenter_clampConvParams(w):
    mean = w.data.mean(1, keepdim=True)
    w.data.sub_(mean)        # center W along the channel dimension (in-place)
    w.data.clamp_(-1.0, 1.0) # clamp W to [-1, 1] (in-place)
    return w
class weight_tnn_bin(nn.Module):
  def __init__(self, W):
    super().__init__()
    self.W = W

  def binary(self, input):
    output = Binary_w.apply(input)
    return output

  def ternary(self, input):
    output = Ternary.apply(input)
    return output

  def forward(self, input):
    if self.W == 2 or self.W == 3:
        # **************************************** binary W *****************************************
        if self.W == 2:
            output = meancenter_clampConvParams(input) # center + clamp W
            # **************** per-channel E(|W|) ****************
            E = torch.mean(torch.abs(output), (3, 2, 1), keepdim=True)
            # **************** α (scaling factor) ****************
            alpha = E
            # ************** W —— +1, -1 **************
            output = self.binary(output)
            # ************** W * α **************
            #output = output * alpha # uncomment to apply the scaling factor α
            # **************************************** ternary W *****************************************
        elif self.W == 3:
            output_fp = input.clone()
            # ************** W —— +1, -1, 0 **************
            output, threshold = self.ternary(input)
            # **************** α (scaling factor) ****************
            output_abs = torch.abs(output_fp)
            mask_le = output_abs.le(threshold)
            mask_gt = output_abs.gt(threshold)
            output_abs[mask_le] = 0
            output_abs_th = output_abs.clone()
            output_abs_th_sum = torch.sum(output_abs_th, (3, 2, 1), keepdim=True)
            mask_gt_sum = torch.sum(mask_gt, (3, 2, 1), keepdim=True).float()
            alpha = output_abs_th_sum / mask_gt_sum # α (scaling factor)
            # *************** W * α ****************
            output = output * alpha # comment out if the scaling factor α is not needed
    else:
      output = input
    return output

# ********************* Quantized convolution (quantize A and W, then convolve) ***********************
class Conv2d_Q(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        A=2,
        W=2
      ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # instantiate the A and W quantizers
        self.activation_quantizer = activation_bin(A=A)
        self.weight_quantizer = weight_tnn_bin(W=W)
          
    def forward(self, input):
        # quantize A and W
        bin_input = self.activation_quantizer(input)
        tnn_bin_weight = self.weight_quantizer(self.weight)    
        #print(bin_input)
        #print(tnn_bin_weight)
        # convolve using the quantized A and W
        output = F.conv2d(
            input=bin_input, 
            weight=tnn_bin_weight, 
            bias=self.bias, 
            stride=self.stride, 
            padding=self.padding, 
            dilation=self.dilation, 
            groups=self.groups)
        return output

# ********************* Quantized (ternary/binary) conv block *********************
class Tnn_Bin_Conv2d(nn.Module):
    # last_relu: apply ReLU after the final conv block
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, A=2, W=2):
        super(Tnn_Bin_Conv2d, self).__init__()
        self.A = A
        self.W = W
        self.last_relu = last_relu

        # ********************* quantized (ternary/binary) convolution *********************
        self.tnn_bin_conv = Conv2d_Q(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, A=A, W=W)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.tnn_bin_conv(x)
        x = self.bn(x)
        if self.last_relu:
            x = self.relu(x)
        return x

class Net(nn.Module):
    def __init__(self, cfg = None, A=2, W=2):
        super(Net, self).__init__()
        # model architecture
        if cfg is None:
            cfg = [16,32,64,10]
        self.tnn_bin = nn.Sequential(
                nn.Conv2d(1, cfg[0], kernel_size=5, stride=1,padding=2),
                nn.BatchNorm2d(cfg[0]),
                nn.MaxPool2d(kernel_size=2, stride=2),

                Tnn_Bin_Conv2d(cfg[0], cfg[1], kernel_size=5, stride=1,padding=2, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[1], cfg[1], kernel_size=5, stride=1,padding=2, A=A, W=W),
                nn.MaxPool2d(kernel_size=2, stride=2),

                Tnn_Bin_Conv2d(cfg[1], cfg[2], kernel_size=5, stride=1,padding=2, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[2], cfg[3], kernel_size=5, stride=1,padding=2, last_relu=1, A=A, W=W),
                nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
                )

    def forward(self, x):
        x = self.tnn_bin(x)
        x = x.view(x.size(0), -1)
        return x

import numpy as np
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms


device = torch.device('cuda:0')

# random seed, so training results are reproducible
def setup_seed(seed):
    torch.manual_seed(seed)                                 
    torch.cuda.manual_seed_all(seed)           
    np.random.seed(seed)                       
    torch.backends.cudnn.deterministic = True

# learning-rate schedule during training
def adjust_learning_rate(optimizer, epoch):
    update_list = [10,20,30,40,50]
    if epoch in update_list:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.2
    return

# model training
def train(epoch):
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        # forward pass
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        loss = criterion(output, target)

        # backward pass
        optimizer.zero_grad()
        loss.backward() # compute gradients
        optimizer.step() # update parameters

        # print the training loss every 100 batches
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR: {}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item(),
                optimizer.param_groups[0]['lr']))
    return

# model evaluation on the test set
def test():
    global best_acc
    model.eval()
    test_loss = 0
    average_test_loss = 0
    correct = 0

    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        # forward pass
        output = model(data)
        test_loss += criterion(output, target).data.item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    # test accuracy
    acc = 100. * float(correct) / len(test_loader.dataset)

    print(acc)


if __name__=='__main__':
    setup_seed(1) # random seed, so training results are reproducible

    train_dataset = torchvision.datasets.MNIST(root='../../data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

    test_dataset = torchvision.datasets.MNIST(root='../../data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=128, 
                                           shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=128, 
                                          shuffle=False)

    

    print('******Initializing model******')
    # ******************** quantize both A (features) and W (weights) inside the model's quantized conv layers ************************
    model = Net(A=2, W=2)
    best_acc = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
    

    # move the model to the selected device
    model.to(device)
    # print the model structure
    print(model)

    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)

    # train the model
    for epoch in range(1, 20):
        adjust_learning_rate(optimizer, epoch)
        train(epoch)
        test()

    param=model.state_dict()

    WeightBin=weight_tnn_bin(2)
    #print(WeightBin.forward(torch.from_numpy(BWc1)))
    
    # floating-point conv layer
    Wc1=param['tnn_bin.0.weight']
    bc1=param['tnn_bin.0.bias']
    # BN layer
    bn1_mean=param['tnn_bin.1.running_mean']
    bn1_var=param['tnn_bin.1.running_var']
    bn1_gamma=param['tnn_bin.1.weight']
    bn1_beta=param['tnn_bin.1.bias']
    # binary conv layers 1 and 2
    BWc1=WeightBin.forward(param['tnn_bin.3.tnn_bin_conv.weight'])
    Bbc1=param['tnn_bin.3.tnn_bin_conv.bias']
    bn2_mean=param['tnn_bin.3.bn.running_mean']
    bn2_var=param['tnn_bin.3.bn.running_var']
    bn2_gamma=param['tnn_bin.3.bn.weight']
    bn2_beta=param['tnn_bin.3.bn.bias']

    BWc2=WeightBin.forward(param['tnn_bin.4.tnn_bin_conv.weight'])
    Bbc2=param['tnn_bin.4.tnn_bin_conv.bias']
    bn3_mean=param['tnn_bin.4.bn.running_mean']
    bn3_var=param['tnn_bin.4.bn.running_var']
    bn3_gamma=param['tnn_bin.4.bn.weight']
    bn3_beta=param['tnn_bin.4.bn.bias']
    # binary conv layers 3 and 4
    BWc3=WeightBin.forward(param['tnn_bin.6.tnn_bin_conv.weight'])
    Bbc3=param['tnn_bin.6.tnn_bin_conv.bias']
    bn4_mean=param['tnn_bin.6.bn.running_mean']
    bn4_var=param['tnn_bin.6.bn.running_var']
    bn4_gamma=param['tnn_bin.6.bn.weight']
    bn4_beta=param['tnn_bin.6.bn.bias']

    BWc4=WeightBin.forward(param['tnn_bin.7.tnn_bin_conv.weight'])
    Bbc4=param['tnn_bin.7.tnn_bin_conv.bias']
    bn5_mean=param['tnn_bin.7.bn.running_mean']
    bn5_var=param['tnn_bin.7.bn.running_var']
    bn5_gamma=param['tnn_bin.7.bn.weight']
    bn5_beta=param['tnn_bin.7.bn.bias']

    print("BWc1")
    print(BWc1)
    print("BWc2")
    print(BWc2)
    print("BWc3")
    print(BWc3)
    print("BWc4")
    print(BWc4)
    
    correct=0
    for batch_idx, (data, target) in enumerate(train_loader):
      data,target=data.to(device),target.to(device)
      x=torch.nn.functional.conv2d(data, Wc1, bias=bc1, stride=1, padding=2)
      x=torch.nn.functional.batch_norm(x, running_mean=bn1_mean,running_var=bn1_var,weight=bn1_gamma,bias=bn1_beta)
      x=torch.nn.functional.max_pool2d(x,kernel_size=2,stride=2)
      
      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,BWc1,bias=Bbc1,stride=1,padding=2)
      x=torch.nn.functional.batch_norm(x, running_mean=bn2_mean,running_var=bn2_var,weight=bn2_gamma,bias=bn2_beta)
      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,BWc2,bias=Bbc2,stride=1,padding=2)
      x=torch.nn.functional.batch_norm(x, running_mean=bn3_mean,running_var=bn3_var,weight=bn3_gamma,bias=bn3_beta)
      x=torch.nn.functional.max_pool2d(x,kernel_size=2,stride=2)

      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,BWc3,bias=Bbc3,stride=1,padding=2)
      x=torch.nn.functional.batch_norm(x, running_mean=bn4_mean,running_var=bn4_var,weight=bn4_gamma,bias=bn4_beta)
      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,BWc4,bias=Bbc4,stride=1,padding=2)
      x=torch.nn.functional.batch_norm(x, running_mean=bn5_mean,running_var=bn5_var,weight=bn5_gamma,bias=bn5_beta)
      x=torch.nn.functional.avg_pool2d(x,kernel_size=7)

      output=torch.argmax(x,axis=1)
      for i in range(data.size(0)):
        if target[i]==output[i]:
          correct+=1
    print("Test accuracy is {}".format(correct/60000))

    



Run results: (screenshot omitted)

Next, consider the case with the scaling factor. Here the binarized weights are additionally multiplied by $\alpha$, so during convolution the activations are binary but the weights are floating point. To avoid floating-point multiplications inside the convolution, first convolve with the binarized weights and only then multiply the result by $\alpha$, i.e. with $W = W_b \cdot \alpha$:

$y = X * W + b = X * (\alpha W_b) + b = \alpha\,(X * W_b) + b$

A small standalone check of this identity is sketched below, followed by the full code:
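As a quick numeric check of that identity (toy shapes again, not taken from the model; α here is the per-output-channel mean of |W|, as in the script below):

import torch
import torch.nn.functional as F

x = torch.sign(torch.randn(1, 16, 14, 14))                       # binarized activations
w = torch.randn(32, 16, 5, 5)                                     # floating-point weights
b = torch.randn(32)

w_bin = torch.sign(w)                                             # W_b in {+1, -1}
alpha = torch.mean(torch.abs(w), dim=(1, 2, 3), keepdim=True)     # per-output-channel E(|W|)

y_scaled_w = F.conv2d(x, w_bin * alpha, bias=b, padding=2)        # conv with float (scaled) weights
y_factored = F.conv2d(x, w_bin, padding=2) * alpha.view(1, -1, 1, 1) + b.view(1, -1, 1, 1)

print(torch.allclose(y_scaled_w, y_factored, atol=1e-4))          # True: α can be pulled outside the conv

Pulling α outside the convolution keeps the multiply-accumulate loop purely binary, which is exactly what the manual inference pass at the end of the script below does.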

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Function

# ********************* Binarization (+1/-1) ***********************
# A: activation binarization
class Binary_a(Function):

    @staticmethod
    def forward(self, input):
        self.save_for_backward(input)
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        input, = self.saved_tensors
        #*******************ste*********************
        grad_input = grad_output.clone()
        #****************saturate_ste***************
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input
# W: weight binarization
class Binary_w(Function):

    @staticmethod
    def forward(self, input):
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input
# ********************* Ternarization (+1/-1/0) ***********************
class Ternary(Function):

    @staticmethod
    def forward(self, input):
        # **************** per-channel E(|W|) ****************
        E = torch.mean(torch.abs(input), (3, 2, 1), keepdim=True)
        # **************** threshold ****************
        threshold = E * 0.7
        # ************** W —— +1, -1, 0 **************
        output = torch.sign(torch.add(torch.sign(torch.add(input, threshold)),torch.sign(torch.add(input, -threshold))))
        return output, threshold

    @staticmethod
    def backward(self, grad_output, grad_threshold):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input

# ********************* A (feature/activation) quantization (binary) ***********************
class activation_bin(nn.Module):
  def __init__(self, A):
    super().__init__()
    self.A = A
    self.relu = nn.ReLU(inplace=True)

  def binary(self, input):
    output = Binary_a.apply(input)
    return output

  def forward(self, input):
    if self.A == 2:
      output = self.binary(input)
      # ******************** A —— 1, 0 (optional clamp, left disabled) *********************
      #a = torch.clamp(a, min=0)
    else:
      output = self.relu(input)
    return output
# ********************* W (weight) quantization (ternary/binary) ***********************
def meancenter_clampConvParams(w):
    mean = w.data.mean(1, keepdim=True)
    w.data.sub_(mean)        # center W along the channel dimension (in-place)
    w.data.clamp_(-1.0, 1.0) # clamp W to [-1, 1] (in-place)
    return w
class weight_tnn_bin(nn.Module):
  def __init__(self, W):
    super().__init__()
    self.W = W

  def binary(self, input):
    output = Binary_w.apply(input)
    return output

  def ternary(self, input):
    output = Ternary.apply(input)
    return output

  def forward(self, input):
    if self.W == 2 or self.W == 3:
        # **************************************** binary W *****************************************
        if self.W == 2:
            output = meancenter_clampConvParams(input) # center + clamp W
            # **************** per-channel E(|W|) ****************
            E = torch.mean(torch.abs(output), (3, 2, 1), keepdim=True)
            # **************** α (scaling factor) ****************
            alpha = E
            # ************** W —— +1, -1 **************
            output = self.binary(output)
            # ************** W * α **************
            output = output * alpha # comment out if the scaling factor α is not needed
            # **************************************** ternary W *****************************************
        elif self.W == 3:
            output_fp = input.clone()
            # ************** W —— +1, -1, 0 **************
            output, threshold = self.ternary(input)
            # **************** α (scaling factor) ****************
            output_abs = torch.abs(output_fp)
            mask_le = output_abs.le(threshold)
            mask_gt = output_abs.gt(threshold)
            output_abs[mask_le] = 0
            output_abs_th = output_abs.clone()
            output_abs_th_sum = torch.sum(output_abs_th, (3, 2, 1), keepdim=True)
            mask_gt_sum = torch.sum(mask_gt, (3, 2, 1), keepdim=True).float()
            alpha = output_abs_th_sum / mask_gt_sum # α (scaling factor)
            # *************** W * α ****************
            output = output * alpha # comment out if the scaling factor α is not needed
    else:
      output = input
    return output

# ********************* Quantized convolution (quantize A and W, then convolve) ***********************
class Conv2d_Q(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        A=2,
        W=2
      ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # instantiate the A and W quantizers
        self.activation_quantizer = activation_bin(A=A)
        self.weight_quantizer = weight_tnn_bin(W=W)
          
    def forward(self, input):
        # quantize A and W
        bin_input = self.activation_quantizer(input)
        tnn_bin_weight = self.weight_quantizer(self.weight)    
        #print(bin_input)
        #print(tnn_bin_weight)
        # convolve using the quantized A and W
        output = F.conv2d(
            input=bin_input, 
            weight=tnn_bin_weight, 
            bias=self.bias, 
            stride=self.stride, 
            padding=self.padding, 
            dilation=self.dilation, 
            groups=self.groups)
        return output

# ********************* Quantized (ternary/binary) conv block *********************
class Tnn_Bin_Conv2d(nn.Module):
    # last_relu: apply ReLU after the final conv block
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, A=2, W=2):
        super(Tnn_Bin_Conv2d, self).__init__()
        self.A = A
        self.W = W
        self.last_relu = last_relu

        # ********************* quantized (ternary/binary) convolution *********************
        self.tnn_bin_conv = Conv2d_Q(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, A=A, W=W)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.tnn_bin_conv(x)
        x = self.bn(x)
        if self.last_relu:
            x = self.relu(x)
        return x

class Net(nn.Module):
    def __init__(self, cfg = None, A=2, W=2):
        super(Net, self).__init__()
        # model architecture
        if cfg is None:
            cfg = [16,32,64,10]
        self.tnn_bin = nn.Sequential(
                nn.Conv2d(1, cfg[0], kernel_size=5, stride=1,padding=2),
                nn.BatchNorm2d(cfg[0]),
                nn.MaxPool2d(kernel_size=2, stride=2),

                Tnn_Bin_Conv2d(cfg[0], cfg[1], kernel_size=5, stride=1,padding=2, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[1], cfg[1], kernel_size=5, stride=1,padding=2, A=A, W=W),
                nn.MaxPool2d(kernel_size=2, stride=2),

                Tnn_Bin_Conv2d(cfg[1], cfg[2], kernel_size=5, stride=1,padding=2, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[2], cfg[3], kernel_size=5, stride=1,padding=2, last_relu=1, A=A, W=W),
                nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
                )

    def forward(self, x):
        x = self.tnn_bin(x)
        x = x.view(x.size(0), -1)
        return x

import numpy as np
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms


device = torch.device('cuda:0')

# random seed, so training results are reproducible
def setup_seed(seed):
    torch.manual_seed(seed)                                 
    torch.cuda.manual_seed_all(seed)           
    np.random.seed(seed)                       
    torch.backends.cudnn.deterministic = True

# learning-rate schedule during training
def adjust_learning_rate(optimizer, epoch):
    update_list = [10,20,30,40,50]
    if epoch in update_list:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.2
    return

# model training
def train(epoch):
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        # forward pass
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        loss = criterion(output, target)

        # backward pass
        optimizer.zero_grad()
        loss.backward() # compute gradients
        optimizer.step() # update parameters

        # print the training loss every 100 batches
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR: {}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item(),
                optimizer.param_groups[0]['lr']))
    return

# model evaluation on the test set
def test():
    global best_acc
    model.eval()
    test_loss = 0
    average_test_loss = 0
    correct = 0

    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        # forward pass
        output = model(data)
        test_loss += criterion(output, target).data.item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    # test accuracy
    acc = 100. * float(correct) / len(test_loader.dataset)

    print(acc)


if __name__=='__main__':
    setup_seed(1) # random seed, so training results are reproducible

    train_dataset = torchvision.datasets.MNIST(root='../../data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

    test_dataset = torchvision.datasets.MNIST(root='../../data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=128, 
                                           shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=128, 
                                          shuffle=False)

    

    print('******Initializing model******')
    # ******************** quantize both A (features) and W (weights) inside the model's quantized conv layers ************************
    model = Net(A=2, W=2)
    best_acc = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
    

    # move the model to the selected device
    model.to(device)
    # print the model structure
    print(model)

    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)

    # train the model
    for epoch in range(1, 20):
        adjust_learning_rate(optimizer, epoch)
        train(epoch)
        test()

    param=model.state_dict()

    WeightBin=weight_tnn_bin(2)
    #print(WeightBin.forward(torch.from_numpy(BWc1)))
    
    # floating-point conv layer
    Wc1=param['tnn_bin.0.weight']
    bc1=param['tnn_bin.0.bias']
    # BN layer
    bn1_mean=param['tnn_bin.1.running_mean']
    bn1_var=param['tnn_bin.1.running_var']
    bn1_gamma=param['tnn_bin.1.weight']
    bn1_beta=param['tnn_bin.1.bias']
    # binary conv layers 1 and 2
    BWc1=WeightBin.forward(param['tnn_bin.3.tnn_bin_conv.weight'])
    Bbc1=param['tnn_bin.3.tnn_bin_conv.bias']
    bn2_mean=param['tnn_bin.3.bn.running_mean']
    bn2_var=param['tnn_bin.3.bn.running_var']
    bn2_gamma=param['tnn_bin.3.bn.weight']
    bn2_beta=param['tnn_bin.3.bn.bias']

    BWc2=WeightBin.forward(param['tnn_bin.4.tnn_bin_conv.weight'])
    Bbc2=param['tnn_bin.4.tnn_bin_conv.bias']
    bn3_mean=param['tnn_bin.4.bn.running_mean']
    bn3_var=param['tnn_bin.4.bn.running_var']
    bn3_gamma=param['tnn_bin.4.bn.weight']
    bn3_beta=param['tnn_bin.4.bn.bias']
    # binary conv layers 3 and 4
    BWc3=WeightBin.forward(param['tnn_bin.6.tnn_bin_conv.weight'])
    Bbc3=param['tnn_bin.6.tnn_bin_conv.bias']
    bn4_mean=param['tnn_bin.6.bn.running_mean']
    bn4_var=param['tnn_bin.6.bn.running_var']
    bn4_gamma=param['tnn_bin.6.bn.weight']
    bn4_beta=param['tnn_bin.6.bn.bias']

    BWc4=WeightBin.forward(param['tnn_bin.7.tnn_bin_conv.weight'])
    Bbc4=param['tnn_bin.7.tnn_bin_conv.bias']
    bn5_mean=param['tnn_bin.7.bn.running_mean']
    bn5_var=param['tnn_bin.7.bn.running_var']
    bn5_gamma=param['tnn_bin.7.bn.weight']
    bn5_beta=param['tnn_bin.7.bn.bias']

    print("BWc1")
    print(BWc1)
    print("BWc2")
    print(BWc2)
    print("BWc3")
    print(BWc3)
    print("BWc4")
    print(BWc4)

    alpha1=torch.mean(torch.abs(BWc1),dim=(1,2,3),keepdim=True)
    alpha2=torch.mean(torch.abs(BWc2),dim=(1,2,3),keepdim=True)
    alpha3=torch.mean(torch.abs(BWc3),dim=(1,2,3),keepdim=True)
    alpha4=torch.mean(torch.abs(BWc4),dim=(1,2,3),keepdim=True)

    
    correct=0
    for batch_idx, (data, target) in enumerate(train_loader):
      data,target=data.to(device),target.to(device)
      x=torch.nn.functional.conv2d(data, Wc1, bias=bc1, stride=1, padding=2)
      x=torch.nn.functional.batch_norm(x, running_mean=bn1_mean,running_var=bn1_var,weight=bn1_gamma,bias=bn1_beta)
      x=torch.nn.functional.max_pool2d(x,kernel_size=2,stride=2)
      
      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,torch.sign(BWc1),stride=1,padding=2)
      x=x*alpha1.view(1,-1,1,1)+Bbc1.view(1,-1,1,1)
      x=torch.nn.functional.batch_norm(x, running_mean=bn2_mean,running_var=bn2_var,weight=bn2_gamma,bias=bn2_beta)
      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,torch.sign(BWc2),stride=1,padding=2)
      x=x*alpha2.view(1,-1,1,1)+Bbc2.view(1,-1,1,1)
      x=torch.nn.functional.batch_norm(x, running_mean=bn3_mean,running_var=bn3_var,weight=bn3_gamma,bias=bn3_beta)
      x=torch.nn.functional.max_pool2d(x,kernel_size=2,stride=2)

      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,torch.sign(BWc3),stride=1,padding=2)
      x=x*alpha3.view(1,-1,1,1)+Bbc3.view(1,-1,1,1)
      x=torch.nn.functional.batch_norm(x, running_mean=bn4_mean,running_var=bn4_var,weight=bn4_gamma,bias=bn4_beta)
      x=torch.sign(x)
      x=torch.nn.functional.conv2d(x,torch.sign(BWc4),stride=1,padding=2)
      x=x*alpha4.view(1,-1,1,1)+Bbc4.view(1,-1,1,1)
      x=torch.nn.functional.batch_norm(x, running_mean=bn5_mean,running_var=bn5_var,weight=bn5_gamma,bias=bn5_beta)
      x=torch.nn.functional.avg_pool2d(x,kernel_size=7)

      output=torch.argmax(x,axis=1)
      for i in range(data.size(0)):
        if target[i]==output[i]:
          correct+=1
    print("Test accuracy is {}".format(correct/60000))

    



Run results: (screenshot omitted)
