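1 Data
We use R's built-in iris dataset as an example and predict the species of each plant from the length and width of its petals and sepals.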
> data(iris)
> head(iris, 8)
  Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1          5.1         3.5          1.4         0.2  setosa
2          4.9         3.0          1.4         0.2  setosa
3          4.7         3.2          1.3         0.2  setosa
4          4.6         3.1          1.5         0.2  setosa
5          5.0         3.6          1.4         0.2  setosa
6          5.4         3.9          1.7         0.4  setosa
7          4.6         3.4          1.4         0.3  setosa
8          5.0         3.4          1.5         0.2  setosa
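Note: iris ships with R. Species is the iris species. Sepal.Length, Sepal.Width, Petal.Length and Petal.Width are the sepal length, sepal width, petal length and petal width, measured in centimetres.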
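Before splitting the data, you can check its shape and class balance. The sketch below uses only base R. It should show 150 rows, four numeric measurement columns plus the Species factor, and 50 flowers per species.

```r
data(iris)

# 150 flowers, 5 columns: four numeric measurements (cm) plus the Species factor
dim(iris)
sapply(iris, class)

# class balance: each of the three species should appear 50 times
class_counts <- table(iris$Species)
class_counts
```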
2 Split into training and test sets
> dim(iris)
[1] 150 5
> n=dim(iris)[1]
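> sp=sample(1:n,size=round(n*0.3),replace=FALSE) # randomly draw 30% of the row indices (45 of 150)
> iris_train=iris[-sp,] # the remaining 70% is the training set
> iris_test=iris[sp,] # the sampled 30% is the test set
Note: no random seed is set, so each run draws a different split and the accuracies reported below will vary somewhat from run to run.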
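The split above samples rows purely at random, so the test set's class proportions can drift from run to run. The sketch below is a suggested variant, not the procedure used for the results in this walkthrough. It fixes a seed (the value 2024 is arbitrary) and draws 30% of the rows within each species, so the test set always holds exactly 15 flowers of every class. It reuses the names sp, iris_train and iris_test, so it can replace the three lines above.

```r
data(iris)
set.seed(2024)  # arbitrary seed, used only to make the split reproducible

# stratified sampling: draw 30% of the row indices separately within each species
rows_by_species <- split(seq_len(nrow(iris)), iris$Species)
sp <- unlist(lapply(rows_by_species,
                    function(idx) sample(idx, size = round(length(idx) * 0.3))),
             use.names = FALSE)

iris_train <- iris[-sp, ]  # 105 rows, 35 per species
iris_test  <- iris[sp, ]   # 45 rows, 15 per species
table(iris_test$Species)
```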
3 Classification with an SVM model
> install.packages("e1071")
> library(e1071)
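3.1 Default parameters
With its defaults, svm() treats the factor response Species as a C-classification problem and uses a radial (RBF) kernel with cost = 1.
Train the model:
> fit_svm = svm(Species~.,data=iris_train)
Predict: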
> pdt = predict(fit_svm,iris_test)
> sum(as.vector(pdt)==iris_test$Species)/dim(iris_test)[1]
[1] 0.9111111
> table(as.vector(pdt),iris_test$Species)
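             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         13         4
  virginica       0          0        13
Note: rows are the predicted classes and columns are the true classes. All 4 errors are virginica flowers predicted as versicolor.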
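The accuracy printed by the sum(...)/dim(...) line can also be read directly off the confusion matrix, because correct predictions sit on its diagonal. The sketch below re-enters the counts printed above by hand, with rows as predicted classes and columns as true classes. It reproduces 0.9111111 and also computes per-class recall, which confirms that every error falls on virginica.

```r
# Confusion matrix copied from the default-parameter output
# (rows = predicted class, columns = true class)
lv <- c("setosa", "versicolor", "virginica")
cm_default <- matrix(c(15,  0,  0,
                        0, 13,  4,
                        0,  0, 13),
                     nrow = 3, byrow = TRUE,
                     dimnames = list(predicted = lv, actual = lv))

# accuracy = correctly classified (diagonal) / all test samples = 41 / 45
accuracy <- sum(diag(cm_default)) / sum(cm_default)
accuracy

# per-class recall: share of each true class that was predicted correctly
recall <- diag(cm_default) / colSums(cm_default)
round(recall, 3)
```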
> plot(fit_svm,data = iris_test,Petal.Width~Petal.Length,slice = list(Sepal.Width = 3, Sepal.Length = 5))
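Note: the model lives in a four-dimensional feature space, and this plot shows only the two-dimensional slice at Sepal.Width = 3 and Sepal.Length = 5. The decision regions drawn belong to that slice alone, so the plot cannot be used to judge how the test points, which have other sepal measurements, are actually classified.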
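One way to make the slice more representative is to hold the two hidden dimensions at typical values of the data. Sepal.Width = 3 already equals its median, but Sepal.Length = 5 lies well below the median of 5.8. The sketch below is an assumption, not part of the original walkthrough. It computes the medians over the whole iris data. The resulting list can be passed as the slice argument of the plot() call above.

```r
data(iris)

# Hypothetical choice: hold the two sepal measurements at their medians,
# so the plotted plane passes through the middle of the data
slice_median <- list(Sepal.Width  = median(iris$Sepal.Width),
                     Sepal.Length = median(iris$Sepal.Length))
str(slice_median)

# Usage with the model fitted above (requires e1071 and fit_svm):
# plot(fit_svm, data = iris_test, Petal.Width ~ Petal.Length, slice = slice_median)
```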
3.2 Setting the kernel function
Train the model (here with a linear kernel and cost = 2):
> fit_svm1 = svm(Species~.,data=iris_train,kernel = "linear",cost=2)
Predict:
> pdt1 = predict(fit_svm1,iris_test)
> sum(as.vector(pdt1)==iris_test$Species)/dim(iris_test)[1]
[1] 0.9333333
> table(as.vector(pdt1),iris_test$Species)
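             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         13         3
  virginica       0          0        14
Note: compared with the default model in 3.1, one more virginica flower is classified correctly (from 13 to 14). The remaining 3 errors are still virginica predicted as versicolor. Because this model changes both the kernel and cost, the gain cannot be attributed to the kernel alone.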
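The change described in the note can be expressed as per-class recall: correct predictions divided by the number of test flowers in that true class. The sketch below re-enters the two printed confusion matrices by hand (rows = predicted, columns = true). It shows that only virginica recall moves, from 13/17 to 14/17.

```r
lv <- c("setosa", "versicolor", "virginica")

# Confusion matrices copied from the printed outputs (rows = predicted, columns = true)
cm_radial <- matrix(c(15,  0,  0,
                       0, 13,  4,
                       0,  0, 13), nrow = 3, byrow = TRUE, dimnames = list(lv, lv))
cm_linear <- matrix(c(15,  0,  0,
                       0, 13,  3,
                       0,  0, 14), nrow = 3, byrow = TRUE, dimnames = list(lv, lv))

# per-class recall = correct predictions / number of test flowers in that true class
recall <- rbind(radial_default = diag(cm_radial) / colSums(cm_radial),
                linear_cost2   = diag(cm_linear) / colSums(cm_linear))
round(recall, 3)
```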
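3.3 Automatic parameter selection
Train the model:
> fit_svmAuto = tune(svm,Species~.,data = iris_train,ranges = list(epsilon = seq(0,1,0.1),cost = 2:100))
> plot(fit_svmAuto)
Note: tune() fits every combination in ranges and scores it by cross-validation (10-fold by default). The plot shows the cross-validated error for each (epsilon, cost) pair, and darker areas mean lower error, i.e. better parameter values. epsilon is only used by eps-regression, so for this classification task it has no effect and the error should change only along the cost axis.
> fit_svmAuto$best.model
Call:
best.tune(method = svm, train.x = Species ~ ., data = iris_train, ranges = list(epsilon = seq(0,
    1, 0.1), cost = 2:100))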
Parameters:
   SVM-Type:  C-classification
 SVM-Kernel:  radial
       cost:  2
Number of Support Vectors:  32
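Note: the best model uses cost = 2 and has 32 support vectors. The radial kernel was not selected by the search. kernel was not among the tuned ranges, so svm()'s default radial kernel, with its default gamma, was kept.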
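It is worth checking how much work the tune() call above does. The sketch below only counts candidate parameter combinations with base R's expand.grid(). The factor of 10 assumes tune()'s default 10-fold cross-validation. The log2-scaled gamma/cost grid at the end is an illustrative assumption, not a tested recommendation. With e1071 it would be passed as ranges = list(gamma = 2^(-3:1), cost = 2^(0:7)).

```r
# Parameter grid searched by the original tune() call
grid_original <- expand.grid(epsilon = seq(0, 1, 0.1), cost = 2:100)
nrow(grid_original)       # 11 epsilon values x 99 cost values = 1089 combinations
nrow(grid_original) * 10  # SVM fits under 10-fold cross-validation

# Hypothetical smaller grid for the radial kernel: drop epsilon (it only matters
# for eps-regression) and search gamma and cost on a log2 scale instead
grid_log2 <- expand.grid(gamma = 2^(-3:1), cost = 2^(0:7))
nrow(grid_log2)           # 5 x 8 = 40 combinations
```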
Predict:
> bestmodel=fit_svmAuto$best.model
> pdt2 = predict(bestmodel,iris_test)
> sum(as.vector(pdt2)==iris_test$Species)/dim(iris_test)[1]
[1] 0.9333333
> table(as.vector(pdt2),iris_test$Species)
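             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         13         3
  virginica       0          0        14
Note: on the test set the tuned model again reaches 93.33% accuracy (42 of 45 correct), the same as the linear-kernel model in 3.2.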
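Setting the three results side by side shows how small the differences are. The correct counts are read off the diagonals of the confusion matrices above. With only 45 test samples, each flower is worth about 2.2 percentage points, so the default model trails the other two by a single test sample. This ranking may well change with a different random split. The sketch below only does this arithmetic.

```r
# Correct test predictions per model, read off the diagonals of the confusion matrices
correct <- c(radial_default = 41, linear_cost2 = 42, radial_tuned = 42)
n_test <- 45

accuracy <- correct / n_test
round(accuracy, 4)

# one test flower moves the accuracy by 100/45, i.e. about 2.22 percentage points
round(100 / n_test, 2)
```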