当前位置:   article > 正文

3.1、随机森林之随机森林实例_随机森林error in yhat[keep] <- ans$ypred : NAs are not

随机森林error in yhat[keep] <- ans$ypred : NAs are not allowed in subscripted assignments

随机森林实例

Markdown脚本及数据集:http://pan.baidu.com/s/1bnY6ar9

实例一、用随机森林对鸢尾花数据进行分类

  1. #1、加载数据并查看
  2. data("iris")
  3. summary(iris)
  1. ## Sepal.Length Sepal.Width Petal.Length Petal.Width
  2. ## Min. :4.300 Min. :2.000 Min. :1.000 Min. :0.100
  3. ## 1st Qu.:5.100 1st Qu.:2.800 1st Qu.:1.600 1st Qu.:0.300
  4. ## Median :5.800 Median :3.000 Median :4.350 Median :1.300
  5. ## Mean :5.843 Mean :3.057 Mean :3.758 Mean :1.199
  6. ## 3rd Qu.:6.400 3rd Qu.:3.300 3rd Qu.:5.100 3rd Qu.:1.800
  7. ## Max. :7.900 Max. :4.400 Max. :6.900 Max. :2.500
  8. ## Species
  9. ## setosa :50
  10. ## versicolor:50
  11. ## virginica :50
  12. ##
  13. ##
  14. ##
str(iris)
  1. ## 'data.frame': 150 obs. of 5 variables:
  2. ## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
  3. ## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
  4. ## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
  5. ## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
  6. ## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
  1. #2、创建训练集和测试集数据
  2. set.seed(2001)
  3. library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 3.2.3
  1. index <- createDataPartition(iris$Species, p=0.7, list=F)
  2. train_iris <- iris[index, ]
  3. test_iris <- iris[-index, ]
  4. #3、建模
  5. library(randomForest)
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
  1. ##
  2. ## Attaching package: 'randomForest'
  1. ## The following object is masked from 'package:ggplot2':
  2. ##
  3. ## margin
  1. model_iris <- randomForest(Species~., data=train_iris, ntree=50, nPerm=10, mtry=3, proximity=T, importance=T)
  2. #4、模型评估
  3. model_iris
  1. ##
  2. ## Call:
  3. ## randomForest(formula = Species ~ ., data = train_iris, ntree = 50, nPerm = 10, mtry = 3, proximity = T, importance = T)
  4. ## Type of random forest: classification
  5. ## Number of trees: 50
  6. ## No. of variables tried at each split: 3
  7. ##
  8. ## OOB estimate of error rate: 4.76%
  9. ## Confusion matrix:
  10. ## setosa versicolor virginica class.error
  11. ## setosa 35 0 0 0.00000000
  12. ## versicolor 0 32 3 0.08571429
  13. ## virginica 0 2 33 0.05714286
str(model_iris)
  1. ## List of 19
  2. ## $ call : language randomForest(formula = Species ~ ., data = train_iris, ntree = 50, nPerm = 10, mtry = 3, proximity = T, importance = T)
  3. ## $ type : chr "classification"
  4. ## $ predicted : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
  5. ## ..- attr(*, "names")= chr [1:105] "5" "7" "8" "11" ...
  6. ## $ err.rate : num [1:50, 1:4] 0.0513 0.0758 0.0741 0.0435 0.0505 ...
  7. ## ..- attr(*, "dimnames")=List of 2
  8. ## .. ..$ : NULL
  9. ## .. ..$ : chr [1:4] "OOB" "setosa" "versicolor" "virginica"
  10. ## $ confusion : num [1:3, 1:4] 35 0 0 0 32 2 0 3 33 0 ...
  11. ## ..- attr(*, "dimnames")=List of 2
  12. ## .. ..$ : chr [1:3] "setosa" "versicolor" "virginica"
  13. ## .. ..$ : chr [1:4] "setosa" "versicolor" "virginica" "class.error"
  14. ## $ votes : matrix [1:105, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
  15. ## ..- attr(*, "dimnames")=List of 2
  16. ## .. ..$ : chr [1:105] "5" "7" "8" "11" ...
  17. ## .. ..$ : chr [1:3] "setosa" "versicolor" "virginica"
  18. ## ..- attr(*, "class")= chr [1:2] "matrix" "votes"
  19. ## $ oob.times : num [1:105] 15 23 22 16 17 11 20 20 17 19 ...
  20. ## $ classes : chr [1:3] "setosa" "versicolor" "virginica"
  21. ## $ importance : num [1:4, 1:5] 0 0 0.3417 0.34918 -0.00518 ...
  22. ## ..- attr(*, "dimnames")=List of 2
  23. ## .. ..$ : chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
  24. ## .. ..$ : chr [1:5] "setosa" "versicolor" "virginica" "MeanDecreaseAccuracy" ...
  25. ## $ importanceSD : num [1:4, 1:4] 0 0 0.04564 0.04711 0.00395 ...
  26. ## ..- attr(*, "dimnames")=List of 2
  27. ## .. ..$ : chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
  28. ## .. ..$ : chr [1:4] "setosa" "versicolor" "virginica" "MeanDecreaseAccuracy"
  29. ## $ localImportance: NULL
  30. ## $ proximity : num [1:105, 1:105] 1 1 1 1 1 1 1 1 1 1 ...
  31. ## ..- attr(*, "dimnames")=List of 2
  32. ## .. ..$ : chr [1:105] "5" "7" "8" "11" ...
  33. ## .. ..$ : chr [1:105] "5" "7" "8" "11" ...
  34. ## $ ntree : num 50
  35. ## $ mtry : num 3
  36. ## $ forest :List of 14
  37. ## ..$ ndbigtree : int [1:50] 11 5 9 9 9 9 9 11 11 9 ...
  38. ## ..$ nodestatus: int [1:17, 1:50] 1 -1 1 1 1 -1 -1 1 -1 -1 ...
  39. ## ..$ bestvar : int [1:17, 1:50] 4 0 4 3 3 0 0 1 0 0 ...
  40. ## ..$ treemap : int [1:17, 1:2, 1:50] 2 0 4 6 8 0 0 10 0 0 ...
  41. ## ..$ nodepred : int [1:17, 1:50] 0 1 0 0 0 2 3 0 3 2 ...
  42. ## ..$ xbestsplit: num [1:17, 1:50] 0.8 0 1.65 5.25 4.85 0 0 6.05 0 0 ...
  43. ## ..$ pid : num [1:3] 1 1 1
  44. ## ..$ cutoff : num [1:3] 0.333 0.333 0.333
  45. ## ..$ ncat : Named int [1:4] 1 1 1 1
  46. ## .. ..- attr(*, "names")= chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
  47. ## ..$ maxcat : int 1
  48. ## ..$ nrnodes : int 17
  49. ## ..$ ntree : num 50
  50. ## ..$ nclass : int 3
  51. ## ..$ xlevels :List of 4
  52. ## .. ..$ Sepal.Length: num 0
  53. ## .. ..$ Sepal.Width : num 0
  54. ## .. ..$ Petal.Length: num 0
  55. ## .. ..$ Petal.Width : num 0
  56. ## $ y : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
  57. ## ..- attr(*, "names")= chr [1:105] "5" "7" "8" "11" ...
  58. ## $ test : NULL
  59. ## $ inbag : NULL
  60. ## $ terms :Classes 'terms', 'formula' length 3 Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
  61. ## .. ..- attr(*, "variables")= language list(Species, Sepal.Length, Sepal.Width, Petal.Length, Petal.Width)
  62. ## .. ..- attr(*, "factors")= int [1:5, 1:4] 0 1 0 0 0 0 0 1 0 0 ...
  63. ## .. .. ..- attr(*, "dimnames")=List of 2
  64. ## .. .. .. ..$ : chr [1:5] "Species" "Sepal.Length" "Sepal.Width" "Petal.Length" ...
  65. ## .. .. .. ..$ : chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
  66. ## .. ..- attr(*, "term.labels")= chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
  67. ## .. ..- attr(*, "order")= int [1:4] 1 1 1 1
  68. ## .. ..- attr(*, "intercept")= num 0
  69. ## .. ..- attr(*, "response")= int 1
  70. ## .. ..- attr(*, ".Environment")=<environment: R_GlobalEnv>
  71. ## .. ..- attr(*, "predvars")= language list(Species, Sepal.Length, Sepal.Width, Petal.Length, Petal.Width)
  72. ## .. ..- attr(*, "dataClasses")= Named chr [1:5] "factor" "numeric" "numeric" "numeric" ...
  73. ## .. .. ..- attr(*, "names")= chr [1:5] "Species" "Sepal.Length" "Sepal.Width" "Petal.Length" ...
  74. ## - attr(*, "class")= chr [1:2] "randomForest.formula" "randomForest"
  1. pred <- predict(model_iris, train_iris)
  2. mean(pred==train_iris[, 5])
## [1] 1
  1. #5、预测
  2. pred_iris <- predict(model_iris, test_iris)
  3. table(pred_iris, test_iris[, 5])
  1. ##
  2. ## pred_iris setosa versicolor virginica
  3. ## setosa 15 0 0
  4. ## versicolor 0 13 2
  5. ## virginica 0 2 13
mean(pred_iris==test_iris[, 5])
## [1] 0.9111111
  1. library(gmodels)
  2. CrossTable(pred_iris, test_iris[, 5])
  1. ##
  2. ##
  3. ## Cell Contents
  4. ## |-------------------------|
  5. ## | N |
  6. ## | Chi-square contribution |
  7. ## | N / Row Total |
  8. ## | N / Col Total |
  9. ## | N / Table Total |
  10. ## |-------------------------|
  11. ##
  12. ##
  13. ## Total Observations in Table: 45
  14. ##
  15. ##
  16. ## | test_iris[, 5]
  17. ## pred_iris | setosa | versicolor | virginica | Row Total |
  18. ## -------------|------------|------------|------------|------------|
  19. ## setosa | 15 | 0 | 0 | 15 |
  20. ## | 20.000 | 5.000 | 5.000 | |
  21. ## | 1.000 | 0.000 | 0.000 | 0.333 |
  22. ## | 1.000 | 0.000 | 0.000 | |
  23. ## | 0.333 | 0.000 | 0.000 | |
  24. ## -------------|------------|------------|------------|------------|
  25. ## versicolor | 0 | 13 | 2 | 15 |
  26. ## | 5.000 | 12.800 | 1.800 | |
  27. ## | 0.000 | 0.867 | 0.133 | 0.333 |
  28. ## | 0.000 | 0.867 | 0.133 | |
  29. ## | 0.000 | 0.289 | 0.044 | |
  30. ## -------------|------------|------------|------------|------------|
  31. ## virginica | 0 | 2 | 13 | 15 |
  32. ## | 5.000 | 1.800 | 12.800 | |
  33. ## | 0.000 | 0.133 | 0.867 | 0.333 |
  34. ## | 0.000 | 0.133 | 0.867 | |
  35. ## | 0.000 | 0.044 | 0.289 | |
  36. ## -------------|------------|------------|------------|------------|
  37. ## Column Total | 15 | 15 | 15 | 45 |
  38. ## | 0.333 | 0.333 | 0.333 | |
  39. ## -------------|------------|------------|------------|------------|
  40. ##
  41. ##

实例二、用泰坦尼克号乘客是否存活数据应用到随机森林算法中

在随机森林算法的函数randomForest()中有两个非常重要的参数,而这两个参数又将影响模型的准确性,它们分别是mtry和ntree。一般对mtry的选择是逐一尝试,直到找到比较理想的值,ntree的选择可通过图形大致判断模型内误差稳定时的值。 randomForest包中的randomForest(formula, data, ntree, nPerm, mtry, proximity, importance)函数:随机森林分类与回归。ntree表示生成决策树的数目(不应设置太小,默认为 500);nPerm表示计算importance时的重复次数,数量大于1给出了比较稳定的估计,但不是很有效(目前只实现了回归);mtry表示选择的分裂属性的个数;proximity表示是否生成邻近矩阵,为T表示生成邻近矩阵;importance表示输出分裂属性的重要性。

下面使用泰坦尼克号乘客是否存活数据应用到随机森林算法中,看看模型的准确性如何。

  1. #1、加载数据并查看:同时读取训练样本和测试样本集
  2. train <- read.table("F:\\R\\Rworkspace\\RandomForest/train.csv", header=T, sep=",")
  3. test <- read.table("F:\\R\\Rworkspace\\RandomForest/test.csv", header=T, sep=",")
  4. #注意:训练集和测试集数据来自不同的数据集,一定要注意测试集和训练集的factor的levels相同,否则,在利用训练集训练的模型对测试集进行预测时,会报错!!!
  5. str(train)
  1. ## 'data.frame': 891 obs. of 8 variables:
  2. ## $ Survived: int 0 1 1 1 0 0 0 0 1 1 ...
  3. ## $ Pclass : int 3 1 3 1 3 3 1 3 3 2 ...
  4. ## $ Sex : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1 ...
  5. ## $ Age : num 22 38 26 35 35 NA 54 2 27 14 ...
  6. ## $ SibSp : int 1 1 0 1 0 0 0 3 0 1 ...
  7. ## $ Parch : int 0 0 0 0 0 0 0 1 2 0 ...
  8. ## $ Fare : num 7.25 71.28 7.92 53.1 8.05 ...
  9. ## $ Embarked: Factor w/ 4 levels "","C","Q","S": 4 2 4 4 4 3 4 4 4 2 ...
str(test)
  1. ## 'data.frame': 418 obs. of 7 variables:
  2. ## $ Pclass : int 3 3 2 3 3 3 3 2 3 3 ...
  3. ## $ Sex : Factor w/ 2 levels "female","male": 2 1 2 2 1 2 1 2 1 2 ...
  4. ## $ Age : num 34.5 47 62 27 22 14 30 26 18 21 ...
  5. ## $ SibSp : int 0 1 0 0 1 0 0 1 0 2 ...
  6. ## $ Parch : int 0 0 0 0 1 0 0 1 0 0 ...
  7. ## $ Fare : num 7.83 7 9.69 8.66 12.29 ...
  8. ## $ Embarked: Factor w/ 3 levels "C","Q","S": 2 3 2 3 3 3 2 3 1 3 ...
  1. #从上可知:训练集数据共891条记录,8个变量,Embarked因子水平为4;测试集数据共418条记录,7个变量,Embarked因子水平为3;训练集中存在缺失数据;Survived因变量为数字类型,测试集数据无因变量
  2. #2、数据清洗
  3. #1)调整测试集与训练集的因子水平
  4. levels(train$Embarked)
## [1] ""  "C" "Q" "S"
levels(test$Embarked)
## [1] "C" "Q" "S"
  1. levels(test$Embarked) <- levels(train$Embarked)
  #注意(勘误):直接替换levels会把已有取值错位重编("C"变成""、"Q"变成"C"、"S"变成"Q",下文summary(test)中Embarked出现S: 0、Q:270即由此而来)。正确做法应为:test$Embarked <- factor(test$Embarked, levels=levels(train$Embarked))
  2. #2)把因变量转化为因子类型
  3. train$Survived <- as.factor(train$Survived)
  4. #3)使用rfImpute()函数补齐训练集的缺失值NA
  5. library(randomForest)
  6. train_impute <- rfImpute(Survived~., data=train)
  1. ## ntree OOB 1 2
  2. ## 300: 16.39% 7.83% 30.12%
  3. ## ntree OOB 1 2
  4. ## 300: 16.50% 8.93% 28.65%
  5. ## ntree OOB 1 2
  6. ## 300: 16.72% 8.74% 29.53%
  7. ## ntree OOB 1 2
  8. ## 300: 16.50% 8.56% 29.24%
  9. ## ntree OOB 1 2
  10. ## 300: 17.28% 9.47% 29.82%
  1. #4)补齐测试集的缺失值:对待测样本进行预测,发现待测样本中存在缺失值,这里使用多重插补法将缺失值补齐
  2. summary(test)
  1. ## Pclass Sex Age SibSp
  2. ## Min. :1.000 female:152 Min. : 0.17 Min. :0.0000
  3. ## 1st Qu.:1.000 male :266 1st Qu.:21.00 1st Qu.:0.0000
  4. ## Median :3.000 Median :27.00 Median :0.0000
  5. ## Mean :2.266 Mean :30.27 Mean :0.4474
  6. ## 3rd Qu.:3.000 3rd Qu.:39.00 3rd Qu.:1.0000
  7. ## Max. :3.000 Max. :76.00 Max. :8.0000
  8. ## NA's :86
  9. ## Parch Fare Embarked
  10. ## Min. :0.0000 Min. : 0.000 :102
  11. ## 1st Qu.:0.0000 1st Qu.: 7.896 C: 46
  12. ## Median :0.0000 Median : 14.454 Q:270
  13. ## Mean :0.3923 Mean : 35.627 S: 0
  14. ## 3rd Qu.:0.0000 3rd Qu.: 31.500
  15. ## Max. :9.0000 Max. :512.329
  16. ## NA's :1
  1. #可以看出测试集数据存在缺失值NA,Age和Fare的数据有NA
  2. #多重插补法填充缺失值:
  3. library(mice)
## Loading required package: Rcpp
## mice 2.25 2015-11-09
imput <- mice(data=test, m=10)
  1. ##
  2. ## iter imp variable
  3. ## 1 1 Age Fare
  4. ## 1 2 Age Fare
  5. ## 1 3 Age Fare
  6. ## 1 4 Age Fare
  7. ## 1 5 Age Fare
  8. ## 1 6 Age Fare
  9. ## 1 7 Age Fare
  10. ## 1 8 Age Fare
  11. ## 1 9 Age Fare
  12. ## 1 10 Age Fare
  13. ## 2 1 Age Fare
  14. ## 2 2 Age Fare
  15. ## 2 3 Age Fare
  16. ## 2 4 Age Fare
  17. ## 2 5 Age Fare
  18. ## 2 6 Age Fare
  19. ## 2 7 Age Fare
  20. ## 2 8 Age Fare
  21. ## 2 9 Age Fare
  22. ## 2 10 Age Fare
  23. ## 3 1 Age Fare
  24. ## 3 2 Age Fare
  25. ## 3 3 Age Fare
  26. ## 3 4 Age Fare
  27. ## 3 5 Age Fare
  28. ## 3 6 Age Fare
  29. ## 3 7 Age Fare
  30. ## 3 8 Age Fare
  31. ## 3 9 Age Fare
  32. ## 3 10 Age Fare
  33. ## 4 1 Age Fare
  34. ## 4 2 Age Fare
  35. ## 4 3 Age Fare
  36. ## 4 4 Age Fare
  37. ## 4 5 Age Fare
  38. ## 4 6 Age Fare
  39. ## 4 7 Age Fare
  40. ## 4 8 Age Fare
  41. ## 4 9 Age Fare
  42. ## 4 10 Age Fare
  43. ## 5 1 Age Fare
  44. ## 5 2 Age Fare
  45. ## 5 3 Age Fare
  46. ## 5 4 Age Fare
  47. ## 5 5 Age Fare
  48. ## 5 6 Age Fare
  49. ## 5 7 Age Fare
  50. ## 5 8 Age Fare
  51. ## 5 9 Age Fare
  52. ## 5 10 Age Fare
  1. Age <- data.frame(Age=apply(imput$imp$Age, 1, mean))
  2. Fare <- data.frame(Fare=apply(imput$imp$Fare, 1, mean))
  3. #添加行标号:
  4. test$Id <- row.names(test)
  5. Age$Id <- row.names(Age)
  6. Fare$Id <- row.names(Fare)
  7. #替换缺失值:
  8. test[test$Id %in% Age$Id, 'Age'] <- Age$Age
  9. test[test$Id %in% Fare$Id, 'Fare'] <- Fare$Fare
  10. summary(test)
  1. ## Pclass Sex Age SibSp
  2. ## Min. :1.000 female:152 Min. : 0.17 Min. :0.0000
  3. ## 1st Qu.:1.000 male :266 1st Qu.:22.00 1st Qu.:0.0000
  4. ## Median :3.000 Median :26.19 Median :0.0000
  5. ## Mean :2.266 Mean :29.41 Mean :0.4474
  6. ## 3rd Qu.:3.000 3rd Qu.:36.65 3rd Qu.:1.0000
  7. ## Max. :3.000 Max. :76.00 Max. :8.0000
  8. ## Parch Fare Embarked Id
  9. ## Min. :0.0000 Min. : 0.000 :102 Length:418
  10. ## 1st Qu.:0.0000 1st Qu.: 7.896 C: 46 Class :character
  11. ## Median :0.0000 Median : 14.454 Q:270 Mode :character
  12. ## Mean :0.3923 Mean : 35.583 S: 0
  13. ## 3rd Qu.:0.0000 3rd Qu.: 31.472
  14. ## Max. :9.0000 Max. :512.329
  1. #从上可知:测试数据集中已经没有了NA值。
  2. #2、选择随机森林的mtry和ntree值
  3. #1)选择mtry
  4. (n <- length(names(train)))
## [1] 8
  1. library(randomForest)
  2. for(i in 1:n) {
  3. model <- randomForest(Survived~., data=train_impute, mtry=i)
  4. err <- mean(model$err.rate)
  5. print(err)
  6. }
  1. ## [1] 0.2100028
  2. ## [1] 0.1889116
  3. ## [1] 0.1776607
  4. ## [1] 0.1902606
  5. ## [1] 0.1960938
  6. ## [1] 0.1953451
  7. ## [1] 0.1951303
  8. ## [1] 0.2018745
  1. #从上可知:mtry=2或者mtry=3时,模型内评价误差最小,故确定参数mtry=2或者mtry=3
  2. #2)选择ntree
  3. set.seed(2002)
  4. model <- randomForest(Survived~., data=train_impute, mtry=2, ntree=1000)
  5. plot(model)

  1. #从上图可知:ntree在400左右时,模型内误差基本稳定,故取ntree=400
  2. #4、建模
  3. model_fit <- randomForest(Survived~., data=train_impute, mtry=2, ntree=400, importance=T)
  4. #5、模型评估
  5. model_fit
  1. ##
  2. ## Call:
  3. ## randomForest(formula = Survived ~ ., data = train_impute, mtry = 2, ntree = 400, importance = T)
  4. ## Type of random forest: classification
  5. ## Number of trees: 400
  6. ## No. of variables tried at each split: 2
  7. ##
  8. ## OOB estimate of error rate: 16.61%
  9. ## Confusion matrix:
  10. ## 0 1 class.error
  11. ## 0 500 49 0.08925319
  12. ## 1 99 243 0.28947368
  1. #查看变量的重要性
  2. (importance <- importance(x=model_fit))
  1. ## 0 1 MeanDecreaseAccuracy MeanDecreaseGini
  2. ## Pclass 16.766454 28.241508 32.16125 33.15984
  3. ## Sex 46.578191 76.145306 72.42624 100.74843
  4. ## Age 19.882605 24.586274 30.52032 60.85186
  5. ## SibSp 19.070707 2.834303 18.95690 16.11720
  6. ## Parch 10.366140 8.380559 13.18282 12.28725
  7. ## Fare 18.649672 20.967558 29.43262 66.31489
  8. ## Embarked 7.904436 11.479919 14.18780 12.68924
  1. #绘制变量的重要性图
  2. varImpPlot(model_fit)

  1. #从上图可知:模型中乘客的性别最为重要;按MeanDecreaseAccuracy衡量,其后依次为Pclass、Age、Fare;按MeanDecreaseGini衡量,其后依次为Fare、Age、Pclass。
  2. #6、预测
  3. #1)对训练集数据预测:
  4. train_pred <- predict(model_fit, train_impute)
  5. mean(train_pred==train_impute$Survived)
## [1] 0.9135802
table(train_pred, train_impute$Survived)
  1. ##
  2. ## train_pred 0 1
  3. ## 0 535 63
  4. ## 1 14 279
  1. #模型的预测精度在90%以上
  2. #2)对测试集数据预测:
  3. test_pred <- predict(model_fit, test[, 1:7])
  4. head(test_pred)
  1. ## 1 2 3 4 5 6
  2. ## 0 0 0 0 1 0
  3. ## Levels: 0 1
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号