目录
介绍
训练VGG16
在新图像上评估VGG16
下一步
- 下载源 - 120.7 MB
DeepFashion等数据集的可用性为时尚行业开辟了新的可能性。在本系列文章中,我们将展示一个人工智能驱动的深度学习系统,它可以帮助我们更好地了解客户的需求,从而彻底改变时装设计行业。
在这个项目中,我们将使用:
- Jupyter Notebook作为IDE
- 库:
- TensorFlow 2.0
- NumPy
- MatplotLib
- DeepFashion数据集的自定义子集——相对较小以减少计算和内存开销
我们假设您熟悉深度学习的概念,以及Jupyter Notebooks和TensorFlow。如果您是Jupyter Notebooks的新手,请从本教程开始。欢迎您下载项目代码。
在上一篇文章中,我们向您展示了如何加载DeepFashion数据集,以及如何重构VGG16模型以适应我们的服装分类任务。在本文中,我们将训练VGG16对15种不同的服装类别进行分类并评估模型性能。
训练VGG16VGG16的迁移学习首先冻结通过在ImageNet等大型数据集上训练模型而获得的模型权重。这些学习到的权重和过滤器为网络提供了强大的特征提取能力,这将有助于我们在训练对服装类别进行分类时提高其性能。因此,只训练全连接 (FC) 层,同时保持模型的特征提取部分几乎冻结(通过设置非常低的学习率,如0.001)。让我们通过将特征提取层设置为False来冻结特征提取层:
for layer in conv_model.layers:
layer.trainable = False
现在,我们可以在选择学习率 (0.001) 和优化器 (Adamax) 的同时编译我们的模型:
full_model.compile(loss='categorical_crossentropy',
optimizer=keras.optimizers.Adamax(lr=0.001),
metrics=['acc'])
编译后,我们可以使用fit_generator函数开始模型训练,因为我们曾经使用ImageDataGenerator加载我们的数据。我们将分别使用train_dataset和val_dataset表示的数据训练和验证我们的网络。我们将训练三个时期,但这个数字可以根据网络性能增加。
history = full_model.fit_generator(
train_dataset,
validation_data = val_dataset,
workers=0,
epochs=3,
)
运行上面的代码将产生以下输出:
现在,为了绘制网络的学习和损失曲线,让我们添加plot_history函数:
def plot_history(history, yrange):
'''Plot loss and accuracy as a function of the epoch,
for the training and validation datasets.
'''
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Get number of epochs
epochs = range(len(acc))
# Plot training and validation accuracy per epoch
plt.plot(epochs, acc)
plt.plot(epochs, val_acc)
plt.title('Training and validation accuracy')
plt.ylim(yrange)
# Plot training and validation loss per epoch
plt.figure()
plt.plot(epochs, loss)
plt.plot(epochs, val_loss)
plt.title('Training and validation loss')
plt.show()
plot_history(history, yrange=(0.9,1))
此函数将生成以下两个图:
我们的网络在训练期间表现良好。因此,在它以前从未见过的衣服图像上进行测试时,它也应该表现良好,对吧?我们将在我们的测试图像集上对其进行测试。
首先,让我们加载测试集,然后使用该model.evaluate函数将测试图像传递给模型以测量网络精度。
from tensorflow.keras.preprocessing.image import ImageDataGenerator
test_dir=r'C:\Users\abdul\Desktop\ContentLab\P2\DeepFashion\Test'
test_datagen = ImageDataGenerator()
test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224), batch_size=3, class_mode='categorical')
# X_test, y_test = next(test_generator)
Testresults = full_model.evaluate(test_generator)
print("test loss, test acc:", Testresults)
好吧,很明显我们的网络训练有素。没有过拟合:它在测试集上达到了92%的准确率。
下一步在接下来的文章中,我们将使用VGG19由手机相机拍摄的实际图像评估。敬请关注!
https://www.codeproject.com/Articles/5297327/Fine-tuning-VGG16-to-Classify-Clothing