神经网络模型

Reads: 2030 Edit

1 数据

以R中自带的鸢尾花数据集为例,根据花瓣、萼片的长宽来预测植物类别!

> data(iris)
> iris
    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
1            5.1         3.5          1.4         0.2     setosa
2            4.9         3.0          1.4         0.2     setosa
3            4.7         3.2          1.3         0.2     setosa
4            4.6         3.1          1.5         0.2     setosa
5            5.0         3.6          1.4         0.2     setosa
6            5.4         3.9          1.7         0.4     setosa
7            4.6         3.4          1.4         0.3     setosa
8            5.0         3.4          1.5         0.2     setosa

说明:iris是R自带的数据,Species是鸢尾花的种类,Sepal.Length Sepal.Width Petal.Length Petal.Width分别是萼片、花瓣的长和宽。

2 划分训练集和测试集

> dim(iris)
[1] 150   5
> n=dim(iris)[1]
> sp=sample(1:n,size=round(n*0.3),replace=FALSE)     # 随机抽取30%的数据
> iris_train=iris[-sp,]                      # 70%作为训练集
> iris_test=iris[sp,]                       # 30%作为测试集

3 采用神经网络模型进行分类

3.1 一个隐藏单元

训练模型:

> install.packages("neuralnet")
> library(neuralnet)
> fit_network = neuralnet(Species~.,iris_train,hidden = 1)

> plot(fit_network)

r-94

预测:

> fit_network$model.list
$response
[1] "setosa"     "versicolor" "virginica" 

说明:neuralnet的预测值不是一个值,而是预测变量类别的个数;fit_network$model.list主要是为了获得预测变量类别的次序!

> pdt = predict(fit_network,iris_test)
> pdt=c("setosa","versicolor","virginica")[apply(pdt,1,which.max)]
> sum(as.vector(pdt)==iris_test$Species)/dim(iris_test)[1]
[1] 0.6222222
> table(as.vector(pdt),iris_test$Species)

             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         13        17

说明:一个隐藏单元的预测精度不高

3.2 三个隐藏单元

训练模型:

> fit_network1 = neuralnet(Species~.,iris_train,hidden = 3)
> plot(fit_network1)

r-95

预测:

> pdt1 = predict(fit_network1,iris_test)
> pdt1=fit_network1$model.list$response[apply(pdt1,1,which.max)]
> sum(as.vector(pdt1)==iris_test$Species)/dim(iris_test)[1]
[1] 0.9111111
> table(as.vector(pdt1),iris_test$Species)

             setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         13         4
  virginica       0          0        13

说明:采用3个隐藏单元后,分类精度大福提高。继续设置4个5个隐藏单元后,精度无法再提高。


Comments

Make a comment