【机器学习】支持向量机 SVM 从原理到实战(Python 全流程实现)
目录
注:其中iris.csv数据集和SVM详细API文档都在我的主页资源中
一、前言
支持向量机(Support Vector Machine,SVM)是机器学习领域经典的有监督分类算法,自诞生以来凭借扎实的数学理论、优秀的小样本学习能力、强大的非线性拟合能力,在分类、回归等任务中得到了广泛应用。本文将从通俗的原理讲解入手,深入拆解 SVM 的核心逻辑,再基于 Python+sklearn 实现完整的 SVM 分类任务,包含可视化、模型训练、评估全流程,帮助读者从入门到实战彻底掌握 SVM。
二、SVM 核心原理(从通俗到深入)
2.1 什么是 SVM?一个通俗的小故事
我们用一个经典的故事理解 SVM 的核心思想:很久以前,公主被魔鬼绑架,王子需要完成魔鬼的挑战:用一根棍子分开桌子上两种颜色的球,并且要求后续加入更多球时,这根棍子依然能有效分类。

1.第一次王子随便放了棍子,结果新增的球直接越界,分类失效;

2.后来王子把棍子放在了两类球的中间,让棍子两边到最近的球的距离尽可能大,此时哪怕新增更多球,棍子依然能稳定分类;

3.魔鬼又把球摆成了非线性的布局,二维平面里根本没法用一根直线分开,王子一拍桌子让球飞到空中,用一张纸完美隔开了两类球。

对应到 SVM 的核心概念里:
两种颜色的球 = 我们的训练数据棍子 / 纸 = 分类决策边界(超平面)让棍子两边间隙最大的操作 = 最大间隔最优化拍桌子让球飞起来 = 核函数(低维映射到高维)离棍子最近、决定棍子位置的球 = 支持向量
2.2 核心目标:最优超平面与最大间隔
SVM 的核心目标,就是找到一个最优超平面,让不同类别的样本被完美分开,且两类样本到超平面的最小距离(间隔)最大化。
2.2.1 超平面方程
超平面是分类的决策边界,在不同维度空间有不同的表达形式:
二维平面:一条直线,方程为
三维空间:一个平面,方程为
更高维空间:超平面,通用方程为 。
其中 ω为超平面的法向量(决定超平面方向),b为偏置项(决定超平面的位置)。
最终的分类决策函数为:。
其中sign为符号函数,输入大于 0 输出 1(正例),小于 0 输出 - 1(负例)。
2.2.2 点到超平面的距离
样本点到超平面的距离,是衡量分类置信度的核心指标,公式为:
。
结合分类的正确性,我们可以得到几何间隔:当样本分类正确时,
,因此样本到超平面的几何间隔可写为:
2.2.3 最大间隔的优化目标
我们的目标是:让离超平面最近的样本点(支持向量)到超平面的距离最大化。通过数学放缩,我们可以约束支持向量满足
,此时最大间隔的优化目标可转化为:
。
约束条件:
。这个优化问题可以通过拉格朗日乘子法转化为对偶问题求解,最终得到最优的ω和b,也就是超平面的参数。
2.2.4 什么是支持向量?
在求解过程中,只有满足
的样本点,对应的拉格朗日乘子
,这些样本就是支持向量。
SVM 的核心特性之一就是:最终的决策超平面只由少数支持向量决定,哪怕移除其他所有样本,超平面的位置也不会改变。这也是 SVM 在小样本场景下表现优异的核心原因。
2.3 软间隔:解决噪声与线性不可分
现实场景中,很多数据存在噪声点,无法实现完美的线性可分,如果强行追求 100% 分类正确,会导致模型泛化能力极差。因此 SVM 引入了软间隔的概念:允许少数样本点违反约束、出现在间隔带内,甚至被误分类,以此提升模型的泛化能力。
我们引入松弛因子
,将约束条件放宽为:
。 同时优化目标更新为:
其中C为惩罚因子,是 SVM 的核心超参数:
C越大:对误分类的惩罚越重,模型越不允许出现误分类,容易过拟合,泛化能力弱;
C越小:对误分类的惩罚越轻,允许更多样本违反约束,模型泛化能力强,容易欠拟合。
2.4 核函数:低维解决高维非线性问题
对于完全线性不可分的数据,SVM 通过核函数解决问题:将低维空间的线性不可分数据,映射到高维特征空间,使其在高维空间中线性可分,再在高维空间中学习最优超平面。
直接在高维空间计算会带来巨大的计算量,而核函数的核心优势是:在低维空间完成高维空间的内积运算,结果完全一致,大幅降低计算复杂度。
常用的核函数有以下几种:
1.线性核:,适用于线性可分的数据,计算速度快,可解释性强;
2.多项式核:,适用于中等规模的非线性数据,可通过 degree 调整多项式维度;
3.高斯核(RBF,径向基函数):,默认核函数,适用于绝大多数非线性场景,通过调整映射范围:
越小:正态分布越 “胖”,辐射范围越大,过拟合风险越低;
越大:正态分布越 “瘦”,辐射范围越小,过拟合风险越高。
2.5 SVM 的优缺点
优点
有严格的数学理论支撑,可解释性强,不同于黑盒模型;小样本场景下表现优异,最终决策仅由少数支持向量决定;软间隔机制可有效提升模型泛化能力,适配带噪声的现实数据;核函数可完美解决非线性分类问题,避免 “维数灾难”;泛化能力强,在分类任务中不易过拟合。
缺点
对大规模训练样本适配性差,样本量超过 10 万时,核矩阵的存储和计算会耗费大量内存和时间;对核函数和超参数的选择非常敏感,不同参数对模型效果影响极大;预测速度与支持向量的数量成正比,支持向量过多时,预测效率较低。
三、SVM 实战:基于 Python+sklearn 实现
本次实战使用经典的鸢尾花数据集,分为两个部分:
- 二维特征线性 SVM 可视化,直观展示超平面、间隔、支持向量;
- 全特征 RBF 核 SVM 多分类,完成模型训练、混淆矩阵可视化、分类报告输出全流程。
3.1 环境准备
需要提前安装相关依赖库:
pip install pandas numpy matplotlib scikit-learn3.2 实战一:二维特征线性 SVM 可视化
本部分选取鸢尾花数据集的 2 个特征,训练线性核 SVM,并可视化超平面、间隔边界和支持向量,直观理解 SVM 的核心逻辑。
3.2.1 完整代码实现
import pandas as pd data=pd.read_csv("iris.csv",header=None) import matplotlib.pyplot as plt data1=data.iloc[:50,:] data2=data.iloc[50:100,:] data3=data.iloc[100:,:] plt.scatter(data1[1],data1[3],marker="^") plt.scatter(data2[1],data2[3],marker="o") #使用svm训练 from sklearn.svm import SVC x = data.iloc[:,[1,3]] y = data.iloc[:,-1] svm = SVC(kernel="linear",C=100,random_state=0) #c无穷大float('inf'),则软间隔为0,不容有其他点进入软间隔内 svm.fit(x,y) #可视化svm结果 #参数w【原始数据为二维数组】 w = svm.coef_[0] b = svm.intercept_[0] #超平面方程w1*1+w2*2+b=0 import numpy as np x1 = np.linspace(0,5,700) #超平面方程 x2 = -(w[0]*x1+b)/w[1] #上超平面方程 x3 = (1-(w[0]*x1+b))/w[1] #下超平面方程 x4 = (-1-(w[0]*x1+b))/w[1] #可视化超平面 plt.plot(x1,x2,linewidth=2,color='r') plt.plot(x1,x3,linewidth=1,color='r',linestyle='--') plt.plot(x1,x4,linewidth=1,color='r',linestyle='--') # #对坐标限制 # plt.xlim([2,6]) # plt.ylim([0,3]) #找到支持向量【二维数组】可视化向量 vet = svm.support_vectors_ plt.scatter(vet[:,0],vet[:,1],c="b",marker="+") plt.show() 3.2.2 结果可视化与解读

从可视化结果中我们可以清晰看到:
红色实线为 SVM 学习到的最优分类超平面,完美分隔了两类样本;两条红色虚线为间隔边界,两类样本的间隔被最大化;蓝色 + 标记的点就是支持向量,这些点落在间隔边界上,是决定超平面位置的核心样本,其他样本的移除不会改变超平面的位置。
3.3 实战二:鸢尾花数据集全特征 SVM 多分类
本部分使用鸢尾花数据集的全部 4 个特征,基于 RBF 核 SVM 完成三分类任务,包含数据划分、模型训练、混淆矩阵可视化、分类报告输出全流程。
3.3.1 数据集与预处理
鸢尾花数据集包含 3 类鸢尾花,每类 50 个样本,共 150 条数据,4 个特征分别为花萼长度、花萼宽度、花瓣长度、花瓣宽度,我们按照 8:2 划分训练集和测试集。
3.3.2 完整代码实现
'''四个特征全训练''' import pandas as pd datas = pd.read_csv("iris.csv",header=None) data = datas.iloc[:,:-1].values target = datas.iloc[:,-1].values #数据切分 from sklearn.model_selection import train_test_split x_train,x_test,y_train,y_test = \ train_test_split(data,target,test_size=0.2,random_state=0) #可视化混淆矩阵 def cm_plot(ah,yp): from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt cm = confusion_matrix(ah,yp) plt.matshow(cm, cmap=plt.cm.Blues) plt.colorbar() for x in range(len(cm)): for y in range(len(cm[x])): plt.annotate(cm[x][y], xy=(y,x),horizontalalignment="center" ,color="white",verticalalignment="center") plt.ylabel('True label') plt.xlabel('Predicted label') plt.show() return plt #模型训练 from sklearn.svm import SVC svm = SVC(kernel='rbf',C=10) svm.fit(x_train,y_train) #模型自测试 y_pred = svm.predict(x_train) cm_plot(y_train,y_pred) #测试集测试 y_test_pred = svm.predict(x_test) cm_plot(y_test,y_test_pred) """"#测试集测试获得分类结果报告""" from sklearn import metrics test_predicted_big = svm.predict(x_test) print(metrics.classification_report(y_test,test_predicted_big))3.3.3 模型评估结果解读
训练集混淆矩阵
可视化视图

训练集共 120 个样本,仅出现 2 个误分类样本,整体分类准确率超过 98%,模型在训练数据上拟合效果良好,没有出现欠拟合。
测试集混淆矩阵
可视化视图

测试集共 30 个样本,所有样本均被正确分类,无任何误分类情况,模型在未见过的测试数据上表现完美,泛化能力优异。
测试集分类报告

从分类报告可以看到,3 个类别的精确率(precision)、召回率(recall)、F1-score 均为 1.00,整体准确率(accuracy)达到 100%,进一步验证了 SVM 在该分类任务上的优秀表现。
四、SVM 核心 API 参数详解
本文使用 sklearn 的SVC类实现 SVM 分类,核心参数如下:
| 参数名 | 作用 | 核心说明 |
|---|---|---|
| C | 惩罚因子 | 浮点数,默认 1.0。C 越大,对误分类惩罚越重,易过拟合;C 越小,容错率越高,易欠拟合 |
| kernel | 核函数 | 默认rbf,可选linear(线性核)、poly(多项式核)、sigmoid |
| degree | 多项式维度 | 整数,默认 3,仅对poly核生效,其他核函数会忽略该参数 |
| gamma | 核函数系数 | 仅对rbf、poly、sigmoid生效。gamma 越大,过拟合风险越高;gamma 越小,泛化能力越强 |
| random_state | 随机种子 | 固定随机种子,保证实验结果可复现 |
其中,C、kernel、gamma是对模型效果影响最大的三个超参数,实际使用中建议通过网格搜索 + 交叉验证的方式选择最优参数组合。
五、总结
本文从通俗的原理入手,深入讲解了 SVM 的最优超平面、最大间隔、支持向量、软间隔、核函数等核心概念,同时基于 Python+sklearn 实现了完整的 SVM 分类任务,包含可视化、模型训练、评估全流程。SVM 作为经典的机器学习算法,在小样本、中等样本量的分类任务中有着不可替代的优势,掌握其核心原理和实战技巧,是机器学习入门的必备技能。读者可以基于本文的代码,更换自己的数据集,调整超参数,进一步深入理解 SVM 的特性。