在西瓜数据集 3.0α 上分别用线性核和高斯核训练一个 SVM,并比较其支持向量的差别。 数据集下载地址: https://amazecourses.obs.cn-north-4.myhuaweicloud.com/datasets/watermelon_3a.csv 任选数据集中的一种分布类型的数据,分别用软、硬间隔SVM和各类核函数训练,并分析他们分类的效果。 数据集下载地址:https://amazecourses.obs.cn-north-4.myhuaweicloud.com/datasets/SVM.zip
此博客为第二问,各类SVM的实现。上一篇博客为https://blog.csdn.net/qq_44459787/article/details/111409314
- 数据载入与处理
#!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2020/12/19 20:48 # @Author : Ryu # @Site : # @File : processing_data.py # @Software: PyCharm import numpy as np from sklearn.model_selection import train_test_split def npz_read(file_dir): npz = np.load(file_dir) data = npz['data'] label_list = npz['label'] npz.close() return data, label_list def split_train_test(data, label_list): xtrain, xtest, ytrain, ytest = train_test_split(data, label_list, test_size=0.3) return xtrain, xtest, ytrain, ytest if __name__ == '__main__': file_name = r'D:\Pythonwork\FisherLDA\SVM\2\分布1.npz' data, label_list = npz_read(file_name) xtrain, xtest, ytrain, ytest = split_train_test(data, label_list)
npz为数据集自带的解析算法。分类数据集依然使用的是sklearn的split方法。
2. 核函数实现
#!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2020/12/19 21:07 # @Author : Ryu # @Site : # @File : SVM.py # @Software: PyCharm from sklearn import svm from sklearn.metrics import accuracy_score from SVM2.processing_data import * if __name__ == '__main__': file_name = 'D:\Pythonwork\FisherLDA\SVM\SVM2\分布1.npz' data, label_list = npz_read(file_name) xtrain, xtest, ytrain, ytest = split_train_test(data, label_list) # 线性核处理 linear_svm = svm.LinearSVC(C=0.5, class_weight='balanced') linear_svm.fit(xtrain, ytrain) y_pred = linear_svm.predict(xtest) print('线性核的准确率为:{}'.format(accuracy_score(y_pred=y_pred, y_true=ytest))) # 高斯核处理 gauss_svm = svm.SVC(C=0.5, kernel='rbf', class_weight='balanced') gauss_svm.fit(xtrain, ytrain) y_pred2 = gauss_svm.predict(xtest) print('高斯核的准确率: %s' % (accuracy_score(y_pred=y_pred2, y_true=ytest))) # 多项式核 poly_svm = svm.SVC(C=0.5, kernel='poly', degree=3, gamma='auto', coef0=0, class_weight='balanced') poly_svm.fit(xtrain, ytrain) y_pred3 = poly_svm.predict(xtest) print('多项式核的准确率: %s' % (accuracy_score(y_pred=y_pred3, y_true=ytest))) # sigmoid核 sigmoid_svm = svm.SVC(C=0.5, kernel='sigmoid', degree=3, gamma='auto', coef0=0, class_weight='balanced') sigmoid_svm.fit(xtrain, ytrain) y_pred4 = sigmoid_svm.predict(xtest) print('sigmoid核的准确率: %s' % (accuracy_score(y_pred=y_pred4, y_true=ytest))) #sigmoid核硬间隔 sigmoid_hard_svm = svm.SVC(C=1000000, kernel='sigmoid', degree=3, gamma='auto', coef0=0, class_weight='balanced') sigmoid_hard_svm.fit(xtrain, ytrain) y_pred5 = sigmoid_hard_svm.predict(xtest) print('sigmoid核硬间隔的准确率: %s' % (accuracy_score(y_pred=y_pred5, y_true=ytest)))
实现部分主要是实现四类核函数和硬间隔支持向量机。
3. 结果分析
经过多次试验之后,这张结果记录是出现频率最多的。线性核的分类效果很不错,但是并没有高斯核和多项式核的稳定,sigmoid核的分类效果相当的差,但是软间隔的分类效果还是相较于硬间隔更好。
此图的样本数据分类效果本身就很明显,核函数的效果好也是在期望之内的。
