迁移学习,简单的理解就是使用一些已经训练好的模型迁移到类似的新的问题进行使用,而不必对新问题重新建模,从头训练和优化参数。这些训练好的模型同时包含了优化好的参数,在使用的时候只需要做一些简单的调整就可以应用到新问题中了,可以说,迁移学习在某种程度上是站在了巨人的肩膀上。
本文使用已经训练完成的VGG16模型,固定特征提取层的参数,对分类层稍作修改,然后进行训练,训练过程仅更新分类层参数,采用的数据集则是有5个类别的花卉数据集,下图是花卉数据集的目录:
数据集的目录格式如下,每一个文件夹,代表一种不同的花卉,这个文件夹内的所有图片都属于该花卉类型,针对这种形式的数据集,pytorch提供了一个API,能很方便的对它进行读取,那就是ImageFolder,具体的读取方式如下
BATCH_SIZE=32
path='F:\\data\\flower_photos\\flower_photos'
flower_class=['daisy','dandelion','roses','sunflowers','tulips']
transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
"val": transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}
image_path = path
trainset = datasets.ImageFolder(root=image_path,
transform=transform["train"])
trainloader = data.DataLoader(trainset, BATCH_SIZE, shuffle=True)
print(trainset.classes) #根据分的文件夹的名字来确定的类别
print(trainset.class_to_idx) #按顺序为这些类别定义索引为0,1...
# print(trainset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
这里,为了适应VGG16的输入尺寸,我们把所有图片都resize到224x224的大小,通过ImageFolder读取图片数据之后,我们可以通过.class属性,来查看有哪些类别,通过class_to_idx属性,来查看这些不同类别所对应的数字编码,还可以通过.img,查看每一张图片及其它所属的类别。
模型的迁移pytorch中有已经训练完成的VGG16模型,因此,我们可以很方便的通过代码加载,模型加载完毕后,我们让特征提取层的所有参数都不进行梯度的计算和权重的更新,然后,我们修改分类层的最后一层,将imagenet中的1000个类别修改为这里的5,最后,我们选取Adam优化器,并只优化分类层的参数,代码如下所示:
model = models.vgg16(pretrained=True)
# 查看迁移模型细节
print("迁移VGG16:\n", model)
# 对迁移模型进行调整
for parma in model.parameters():
parma.requires_grad = False
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 5))
# 定义代价函数和优化函数
loss_f = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.00001)
完整代码
import os
import torchvision.transforms as transforms
from torchvision import datasets
import torch.utils.data as data
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
BATCH_SIZE=32
path='F:\\data\\flower_photos\\flower_photos'
flower_class=['daisy','dandelion','roses','sunflowers','tulips']
transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
"val": transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}
image_path = path
trainset = datasets.ImageFolder(root=image_path,
transform=transform["train"])
trainloader = data.DataLoader(trainset, BATCH_SIZE, shuffle=True)
print(trainset.classes) #根据分的文件夹的名字来确定的类别
print(trainset.class_to_idx) #按顺序为这些类别定义索引为0,1...
# print(trainset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
def imshow(image):
for i in range(image.size(0)):
img = image[i] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
img = img.numpy() # FloatTensor转为ndarray
img = np.transpose(img, (1, 2, 0)) # 把channel那一维放到最后
# 显示图片
plt.imshow(img)
plt.show()
model = models.vgg16(pretrained=True)
# 查看迁移模型细节
print("迁移VGG16:\n", model)
# 对迁移模型进行调整
for parma in model.parameters():
parma.requires_grad = False
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 5))
# 定义代价函数和优化函数
loss_f = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.00001)
# 模型训练和参数优化
epoch_n = 5
torch.cuda.empty_cache()
for epoch in range(epoch_n):
print("Epoch {}/{}".format(epoch + 1, epoch_n))
print("-" * 10)
# 设置为True,会进行Dropout并使用batch mean和batch var
model.train(True)
running_loss = 0.0
running_corrects = 0
# enuerate(),返回的是索引和元素
for batch, data in enumerate(trainloader):
X, y = data
y_pred = model(X)
# pred,概率较大值对应的索引值,可看做预测结果
_, pred = torch.max(y_pred.data, 1)
# 梯度归零
optimizer.zero_grad()
# 计算损失
loss = loss_f(y_pred, y)
loss.backward()
optimizer.step()
# 计算损失和
running_loss += float(loss)
# 统计预测正确的图片数
running_corrects += torch.sum(pred == y.data)
print("loss=",running_loss/BATCH_SIZE)
print("acc is {}%".format(running_corrects.item()/BATCH_SIZE*100.0))
running_loss=0
running_corrects=0
torch.save(model.state_dict(),'model.pkl')
训练结果
训练5个EPOCH之后