Owns property in the city | Marital history (divorced, single) | Annual income (10,000 CNY) | Meet (yes/no) |
Yes | Single | 12 | Yes |
No | Single | 15 | Yes |
Yes | Divorced | 10 | Yes |
No | Single | 18 | Yes |
Yes | Divorced | 25 | Yes |
Yes | Single | 50 | Yes |
No | Divorced | 35 | Yes |
Yes | Divorced | 40 | Yes |
No | Single | 60 | Yes |
No | Divorced | 17 | No |
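To make the mapping from a table like this to MLlib's input format concrete, here is a minimal sketch (not part of the original example, and assuming a spark-shell SparkContext named sc) that encodes a few of the rows above as LabeledPoint instances. The feature indices, the 0/1 coding of the two categorical columns, and the label coding (1.0 = meet) are illustrative assumptions.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Hypothetical encoding: feature 0 = owns property in the city (0 = no, 1 = yes),
// feature 1 = marital history (0 = single, 1 = divorced),
// feature 2 = annual income (continuous, in units of 10,000 CNY).
// Label: 1.0 = meet, 0.0 = do not meet.
val rows = Seq(
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 12.0)), // yes, single, 12 -> meet
  LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 15.0)), // no, single, 15 -> meet
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 17.0))  // no, divorced, 17 -> do not meet
)
val table = sc.parallelize(rows)
// Features 0 and 1 are categorical with 2 values each; feature 2 remains continuous.
val categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)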
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils

// Load the data
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Randomly split the data into two parts: one for training, one for testing
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Random forest training parameters
// Number of classes
val numClasses = 2
// An empty categoricalFeaturesInfo means all features are treated as continuous
val categoricalFeaturesInfo = Map[Int, Int]()
// Number of trees
val numTrees = 3
// Feature subset sampling strategy; "auto" lets the algorithm choose
val featureSubsetStrategy = "auto"
// Impurity measure
val impurity = "gini"
// Maximum tree depth
val maxDepth = 4
// Maximum number of bins per feature
val maxBins = 32

// Train the random forest classifier; trainClassifier returns a RandomForestModel
val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

// Evaluate the trained classifier on the test data and compute the error rate
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification forest model:\n" + model.toDebugString)

// Persist the trained random forest model
model.save(sc, "myModelPath")
// Load the random forest model back into memory
val sameModel = RandomForestModel.load(sc, "myModelPath")
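Beyond the raw error rate, the same (label, prediction) pairs can be handed to MLlib's MulticlassMetrics to obtain a confusion matrix and overall precision. An optional sketch, reusing labelAndPreds from the snippet above:

import org.apache.spark.mllib.evaluation.MulticlassMetrics

// labelAndPreds holds (label, prediction); MulticlassMetrics expects (prediction, label)
val metrics = new MulticlassMetrics(labelAndPreds.map(_.swap))
println("Confusion matrix:\n" + metrics.confusionMatrix)
println("Precision = " + metrics.precision)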
/**
 * Method to train a decision tree model for binary or multiclass classification.
 *
 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
 *              Labels should take values {0, 1, ..., numClasses-1}.
 * @param numClasses number of classes for classification.
 * @param categoricalFeaturesInfo Map storing arity of categorical features.
 *                                E.g., an entry (n -> k) indicates that feature n is categorical
 *                                with k categories indexed from 0: {0, 1, ..., k-1}.
 * @param numTrees Number of trees in the random forest.
 * @param featureSubsetStrategy Number of features to consider for splits at each node.
 *                              Supported: "auto", "all", "sqrt", "log2", "onethird".
 *                              If "auto" is set, this parameter is set based on numTrees:
 *                                if numTrees == 1, set to "all";
 *                                if numTrees > 1 (forest) set to "sqrt".
 * @param impurity Criterion used for information gain calculation.
 *                 Supported values: "gini" (recommended) or "entropy".
 * @param maxDepth Maximum depth of the tree.
 *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
 *                 (suggested value: 4)
 * @param maxBins maximum number of bins used for splitting features
 *                (suggested value: 100)
 * @param seed Random seed for bootstrapping and choosing feature subsets.
 * @return a random forest model that can be used for prediction
 */
def trainClassifier(
    input: RDD[LabeledPoint],
    numClasses: Int,
    categoricalFeaturesInfo: Map[Int, Int],
    numTrees: Int,
    featureSubsetStrategy: String,
    impurity: String,
    maxDepth: Int,
    maxBins: Int,
    seed: Int = Utils.random.nextInt()): RandomForestModel = {
  val impurityType = Impurities.fromString(impurity)
  val strategy = new Strategy(Classification, impurityType, maxDepth,
    numClasses, maxBins, Sort, categoricalFeaturesInfo)
  // Delegates to the overloaded trainClassifier that takes a Strategy
  trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed)
}
/**
 * Method to train a decision tree model for binary or multiclass classification.
 *
 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
 *              Labels should take values {0, 1, ..., numClasses-1}.
 * @param strategy Parameters for training each tree in the forest.
 * @param numTrees Number of trees in the random forest.
 * @param featureSubsetStrategy Number of features to consider for splits at each node.
 *                              Supported: "auto", "all", "sqrt", "log2", "onethird".
 *                              If "auto" is set, this parameter is set based on numTrees:
 *                                if numTrees == 1, set to "all";
 *                                if numTrees > 1 (forest) set to "sqrt".
 * @param seed Random seed for bootstrapping and choosing feature subsets.
 * @return a random forest model that can be used for prediction
 */
def trainClassifier(
    input: RDD[LabeledPoint],
    strategy: Strategy,
    numTrees: Int,
    featureSubsetStrategy: String,
    seed: Int): RandomForestModel = {
  require(strategy.algo == Classification,
    s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
  // A RandomForest object is created here
  val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
  // Its run method is then invoked with an RDD[LabeledPoint] and returns a RandomForestModel instance
  rf.run(input)
}
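The relationship between the two overloads can also be seen from the caller's side: instead of passing the individual scalar parameters, one can build the Strategy object directly and hand it to the second overload. A sketch under the same parameter values used in the example above (trainingData comes from that example; the seed value here is illustrative):

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort
import org.apache.spark.mllib.tree.impurity.Gini

// Mirrors what the first overload builds internally:
// (algo, impurity, maxDepth, numClasses, maxBins, quantile strategy, categoricalFeaturesInfo)
val strategy = new Strategy(Classification, Gini, 4, 2, 32, Sort, Map[Int, Int]())

val model = RandomForest.trainClassifier(
  trainingData, strategy, numTrees = 3, featureSubsetStrategy = "auto", seed = 12345)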
/**
 * Method to train a decision tree model over an RDD
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @return a random forest model that can be used for prediction
 */
def run(input: RDD[LabeledPoint]): RandomForestModel = {
  val timer = new TimeTracker()
  timer.start("total")
  timer.start("init")
  val retaggedInput = input.retag(classOf[LabeledPoint])
  // Build the decision-tree metadata (split positions, number of bins,
  // the feature values contained in each bin, and so on)
  val metadata =
    DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
  logDebug("algo = " + strategy.algo)
  logDebug("numTrees = " + numTrees)
  logDebug("seed = " + seed)
  logDebug("maxBins = " + metadata.maxBins)
  logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
  logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
  logDebug("subsamplingRate = " + strategy.subsamplingRate)

  // Find the splits and the corresponding bins (interval between the splits) using a sample
  // of the input data.
  timer.start("findSplitsBins")
  // Find the split points (splits) and the bin information (bins).
  // For continuous features, candidate splits are estimated from a sample to keep the cost down.
  // For categorical features: an unordered feature has at most 2^(numCategories - 1) - 1 splits,
  // while an ordered feature has at most numBins - 1 splits.
  val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
  timer.stop("findSplitsBins")
  logDebug("numBins: feature: number of bins")
  logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
    s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
  }.mkString("\n"))

  // Bin feature values (TreePoint representation).
  // Cache input RDD for speedup during multiple passes.
  // Convert to the TreePoint RDD representation; afterwards every sample has been
  // assigned to a bin according to the split conditions.
  val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
  val withReplacement = if (numTrees > 1) true else false
  // convertToBaggedRDD gives every tree its own (sub)sample of the data
  val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
    strategy.subsamplingRate, numTrees,
    withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)

  // depth of the decision tree
  val maxDepth = strategy.maxDepth
  require(maxDepth <= 30,
    s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

  // Max memory usage for aggregates
  // TODO: Calculate memory usage more precisely.
  val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
  logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
  val maxMemoryPerNode = {
    val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
      // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
      Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
        .take(metadata.numFeaturesPerNode).map(_._2))
    } else {
      None
    }
    // Memory required per node during aggregation
    RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
  }
  require(maxMemoryPerNode <= maxMemoryUsage,
    s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
      " which is too small for the given features." +
      s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
  timer.stop("init")

  /*
   * The main idea here is to perform group-wise training of the decision tree nodes thus
   * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
   * Each data sample is handled by a particular node (or it reaches a leaf and is not used
   * in lower levels).
   */

  // Create an RDD of node Id cache.
  // At first, all the rows belong to the root nodes (node Id == 1).
  // Optional node-id cache; node ids start at 1: 1 is the root of the tree,
  // its left child is 2, its right child is 3, and so on.
  val nodeIdCache = if (strategy.useNodeIdCache) {
    Some(NodeIdCache.init(
      data = baggedInput,
      numTrees = numTrees,
      checkpointInterval = strategy.checkpointInterval,
      initVal = 1))
  } else {
    None
  }

  // FIFO queue of nodes to train: (treeIndex, node)
  val nodeQueue = new mutable.Queue[(Int, Node)]()
  val rng = new scala.util.Random()
  rng.setSeed(seed)

  // Allocate and queue root nodes.
  // Create the root node of each tree
  val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
  // Enqueue (tree index, root node of that tree); tree indices start at 0, node ids at 1
  Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
  while (nodeQueue.nonEmpty) {
    // Collect some nodes to split, and choose features for each node (if subsampling).
    // Each group of nodes may come from one or multiple trees, and at multiple levels.
    // Collect, for each tree, the nodes that still need to be split
    val (nodesForGroup, treeToNodeToIndexInfo) =
      RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
    // Sanity check (should never occur):
    assert(nodesForGroup.size > 0,
      s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

    // Choose node splits, and enqueue new nodes as needed.
    timer.start("findBestSplits")
    // Find the best splits
    DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
      treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
    timer.stop("findBestSplits")
  }

  baggedInput.unpersist()
  timer.stop("total")
  logInfo("Internal timing for DecisionTree:")
  logInfo(s"$timer")

  // Delete any remaining checkpoints used for node Id cache.
  if (nodeIdCache.nonEmpty) {
    try {
      nodeIdCache.get.deleteAllCheckpoints()
    } catch {
      case e: IOException =>
        logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
    }
  }

  val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
  new RandomForestModel(strategy.algo, trees)
}
}
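The node-id convention mentioned in the comments (root = 1, children of node i at 2i and 2i + 1) is the usual binary-heap numbering. A tiny standalone sketch of the arithmetic, not library code:

// Heap-style node numbering: root = 1, left child of node i = 2 * i, right child = 2 * i + 1.
def leftChildIndex(i: Int): Int = i << 1
def rightChildIndex(i: Int): Int = (i << 1) + 1

assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3) // children of the root
assert(leftChildIndex(3) == 6 && rightChildIndex(3) == 7) // children of node 3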
/**
 * Returns splits and bins for decision tree calculation.
 * Continuous and categorical features are handled differently.
 *
 * Continuous features:
 *   For each feature, there are numBins - 1 possible splits representing the possible binary
 *   decisions at each node in the tree.
 *   This finds locations (feature values) for splits using a subsample of the data.
 *
 * Categorical features:
 *   For each feature, there is 1 bin per split.
 *   Splits and bins are handled in 2 ways:
 *   (a) "unordered features"
 *       For multiclass classification with a low-arity feature
 *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
 *       the feature is split based on subsets of categories.
 *   (b) "ordered features"
 *       For regression and binary classification,
 *       and for multiclass classification with a high-arity feature,
 *       there is one bin per category.
 *
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @param metadata Learning and dataset metadata
 * @return A tuple of (splits, bins).
 *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
 *           of size (numFeatures, numSplits).
 *         Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
 *           of size (numFeatures, numBins).
 */
protected[tree] def findSplitsBins(
    input: RDD[LabeledPoint],
    metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
  logDebug("isMulticlass = " + metadata.isMulticlass)
  val numFeatures = metadata.numFeatures

  // Sample the input only if there are continuous features.
  // Check whether any feature is continuous
  val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
  val sampledInput = if (hasContinuousFeatures) {
    // Calculate the number of samples for approximate quantile calculation.
    // Number of samples to draw; at least 10000
    val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
    // Sampling fraction
    val fraction = if (requiredSamples < metadata.numExamples) {
      requiredSamples.toDouble / metadata.numExamples
    } else {
      1.0
    }
    logDebug("fraction of data used for calculating quantiles = " + fraction)
    input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
  } else {
    // With only categorical features there is no need to sample: use an empty array
    new Array[LabeledPoint](0)
  }

  // Split-point strategy; Spark currently implements only one: Sort
  metadata.quantileStrategy match {
    case Sort =>
      // One group of split positions per feature
      val splits = new Array[Array[Split]](numFeatures)
      // Bin information corresponding to the split positions
      val bins = new Array[Array[Bin]](numFeatures)

      // Find all splits.
      // Iterate over all features.
      var featureIndex = 0
      while (featureIndex < numFeatures) {
        // Continuous feature
        if (metadata.isContinuous(featureIndex)) {
          val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
          // findSplitsForContinuousFeature returns all split positions for a continuous feature
          val featureSplits = findSplitsForContinuousFeature(featureSamples,
            metadata, featureIndex)
          val numSplits = featureSplits.length
          // A continuous feature has numSplits + 1 bins
          val numBins = numSplits + 1
          logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")

          // Arrays of split points and bins for this feature
          splits(featureIndex) = new Array[Split](numSplits)
          bins(featureIndex) = new Array[Bin](numBins)

          var splitIndex = 0
          // Iterate over the split points
          while (splitIndex < numSplits) {
            // The value at this split position; since the samples are sorted, it acts as a threshold
            val threshold = featureSplits(splitIndex)
            // Record every split position of this feature
            splits(featureIndex)(splitIndex) =
              new Split(featureIndex, threshold, Continuous, List())
            splitIndex += 1
          }
          // The leftmost bin starts at the minimum threshold Double.MinValue
          bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
            splits(featureIndex)(0), Continuous, Double.MinValue)

          splitIndex = 1
          // All bins except the last: each bin covers the values between two consecutive thresholds
          while (splitIndex < numSplits) {
            bins(featureIndex)(splitIndex) =
              new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
                Continuous, Double.MinValue)
            splitIndex += 1
          }
          // The last bin uses the maximum threshold Double.MaxValue as its rightmost boundary
          bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
            new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
        } else {
          // Categorical feature
          val numSplits = metadata.numSplits(featureIndex)
          val numBins = metadata.numBins(featureIndex)
          // Number of categories of this feature
          val featureArity = metadata.featureArity(featureIndex)
          // Unordered feature
          if (metadata.isUnordered(featureIndex)) {
            // Unordered features
            // 2^(maxFeatureValue - 1) - 1 combinations
            splits(featureIndex) = new Array[Split](numSplits)
            var splitIndex = 0
            while (splitIndex < numSplits) {
              // Extract the subset of categories (one or more) that go to one side of this split
              val categories: List[Double] =
                extractMultiClassCategories(splitIndex + 1, featureArity)
              splits(featureIndex)(splitIndex) =
                new Split(featureIndex, Double.MinValue, Categorical, categories)
              splitIndex += 1
            }
          } else {
            // Ordered features need no precomputation; bins correspond to feature values
            // Bins correspond to feature values, so we do not need to compute splits or bins
            // beforehand. Splits are constructed as needed during training.
            splits(featureIndex) = new Array[Split](0)
          }
          // For ordered features, bins correspond to feature values.
          // For unordered categorical features, there is no need to construct the bins
          // since there is a one-to-one correspondence between the splits and the bins.
          bins(featureIndex) = new Array[Bin](0)
        }
        featureIndex += 1
      }
      (splits, bins)

    case MinMax =>
      throw new UnsupportedOperationException("minmax not supported yet.")

    case ApproxHist =>
      throw new UnsupportedOperationException("approximate histogram not supported yet.")
  }
}
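For the unordered case, the 2^(M - 1) - 1 split count comes from enumerating the non-empty, non-redundant subsets of categories that can be sent to one child. A small sketch of that enumeration for a 3-category feature (the bit-mask scheme mirrors the idea behind extractMultiClassCategories, but is written independently here):

// For an unordered categorical feature with featureArity = M categories,
// each split index 1 .. 2^(M - 1) - 1, read as a bit mask, selects the categories
// that are sent to the left child.
val featureArity = 3
val numUnorderedSplits = (1 << (featureArity - 1)) - 1  // = 3
val leftSubsets = (1 to numUnorderedSplits).map { mask =>
  (0 until featureArity).filter(c => ((mask >> c) & 1) == 1).map(_.toDouble).toList
}
// leftSubsets == Vector(List(0.0), List(1.0), List(0.0, 1.0))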
/**
 * Find the best split for a node.
 * @param binAggregates Bin statistics.
 * @return tuple for best split: (Split, information gain, prediction at node)
 */
private def binsToBestSplit(
    binAggregates: DTStatsAggregator,
    // DTStatsAggregator references an ImpurityAggregator, which provides the impurity logic
    splits: Array[Array[Split]],
    featuresForNode: Option[Array[Int]],
    node: Node): (Split, InformationGainStats, Predict) = {

  // calculate predict and impurity if current node is top node
  val level = Node.indexToLevel(node.id)
  var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
    None
  } else {
    Some((node.predict, node.impurity))
  }

  // For each (feature, split), calculate the gain, and select the best (feature, split).
  val (bestSplit, bestSplitStats) =
    Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
      val featureIndex = if (featuresForNode.nonEmpty) {
        featuresForNode.get.apply(featureIndexIdx)
      } else {
        featureIndexIdx
      }
      val numSplits = binAggregates.metadata.numSplits(featureIndex)
      // Continuous feature
      if (binAggregates.metadata.isContinuous(featureIndex)) {
        // Cumulative sum (scanLeft) of bin statistics.
        // Afterwards, binAggregates for a bin is the sum of aggregates for
        // that bin + all preceding bins.
        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        var splitIndex = 0
        while (splitIndex < numSplits) {
          binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
          splitIndex += 1
        }
        // Find best split.
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { case splitIdx =>
            // Compute the impurity of the leftChild and rightChild nodes
            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
            rightChildStats.subtract(leftChildStats)
            // Prediction and impurity at this node, computed from the aggregated statistics
            predictWithImpurity = Some(predictWithImpurity.getOrElse(
              calculatePredictImpurity(leftChildStats, rightChildStats)))
            // Information gain, used to judge whether this split is the best one
            val gainStats = calculateGainForSplit(leftChildStats,
              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
            (splitIdx, gainStats)
          }.maxBy(_._2.gain)
        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
        // Unordered categorical feature
        val (leftChildOffset, rightChildOffset) =
          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { splitIndex =>
            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
            val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
            predictWithImpurity = Some(predictWithImpurity.getOrElse(
              calculatePredictImpurity(leftChildStats, rightChildStats)))
            val gainStats = calculateGainForSplit(leftChildStats,
              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
            (splitIndex, gainStats)
          }.maxBy(_._2.gain)
        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      } else {
        // Ordered categorical feature
        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        val numBins = binAggregates.metadata.numBins(featureIndex)

        /* Each bin is one category (feature value).
         * The bins are ordered based on centroidForCategories, and this ordering determines which
         * splits are considered. (With K categories, we consider K - 1 possible splits.)
         *
         * centroidForCategories is a list: (category, centroid)
         */
        // Multiclass classification
        val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
          // For categorical variables in multiclass classification,
          // the bins are ordered by the impurity of their corresponding labels.
          Range(0, numBins).map { case featureValue =>
            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
            val centroid = if (categoryStats.count != 0) {
              // The impurity of this category's label distribution
              categoryStats.calculate()
            } else {
              Double.MaxValue
            }
            (featureValue, centroid)
          }
        } else {
          // Regression or binary classification
          // For categorical variables in regression and binary classification,
          // the bins are ordered by the centroid of their corresponding labels.
          Range(0, numBins).map { case featureValue =>
            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
            val centroid = if (categoryStats.count != 0) {
              // The mean label value is used as the centroid
              categoryStats.predict
            } else {
              Double.MaxValue
            }
            (featureValue, centroid)
          }
        }

        logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

        // bins sorted by centroids
        val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

        logDebug("Sorted centroids for categorical variable = " +
          categoriesSortedByCentroid.mkString(","))

        // Cumulative sum (scanLeft) of bin statistics.
        // Afterwards, binAggregates for a bin is the sum of aggregates for
        // that bin + all preceding bins.
        var splitIndex = 0
        while (splitIndex < numSplits) {
          val currentCategory = categoriesSortedByCentroid(splitIndex)._1
          val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
          // Merge the statistics of the two bins
          binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
          splitIndex += 1
        }
        // lastCategory = index of bin with total aggregates for this (node, feature)
        val lastCategory = categoriesSortedByCentroid.last._1
        // Find best split.
        // Choose the best split by information gain
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { splitIndex =>
            val featureValue = categoriesSortedByCentroid(splitIndex)._1
            val leftChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
            val rightChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
            rightChildStats.subtract(leftChildStats)
            predictWithImpurity = Some(predictWithImpurity.getOrElse(
              calculatePredictImpurity(leftChildStats, rightChildStats)))
            val gainStats = calculateGainForSplit(leftChildStats,
              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
            (splitIndex, gainStats)
          }.maxBy(_._2.gain)
        val categoriesForSplit =
          categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
        val bestFeatureSplit =
          new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
        (bestFeatureSplit, bestFeatureGainStats)
      }
    }.maxBy(_._2.gain)

  (bestSplit, bestSplitStats, predictWithImpurity.get._1)
}
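calculateGainForSplit ranks candidate splits by information gain. With impurity = "gini", the quantities being compared look roughly like the following standalone sketch (not the library implementation): the Gini impurity of a label-count distribution, and the drop in weighted impurity a split achieves.

// A sketch of Gini impurity and information gain for label counts.
def gini(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
}

def informationGain(left: Array[Double], right: Array[Double]): Double = {
  val parent = left.zip(right).map { case (l, r) => l + r }
  val (nL, nR) = (left.sum, right.sum)
  val n = nL + nR
  gini(parent) - (nL / n) * gini(left) - (nR / n) * gini(right)
}

// A split that produces pure children attains the full parent impurity as its gain:
// informationGain(Array(4.0, 0.0), Array(0.0, 6.0)) == gini(Array(4.0, 6.0)) == 0.48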
/**
 * :: Experimental ::
 * Represents a random forest model.
 *
 * @param algo algorithm for the ensemble model, either Classification or Regression
 * @param trees tree ensembles
 */
// RandomForestModel extends TreeEnsembleModel
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
  extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
    combiningStrategy = if (algo == Classification) Vote else Average)
  with Saveable {

  require(trees.forall(_.algo == algo))

  // Persist the trained model
  override def save(sc: SparkContext, path: String): Unit = {
    TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
      RandomForestModel.SaveLoadV1_0.thisClassName)
  }

  override protected def formatVersion: String = RandomForestModel.formatVersion
}

object RandomForestModel extends Loader[RandomForestModel] {

  private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion

  // Load a saved model back into memory
  override def load(sc: SparkContext, path: String): RandomForestModel = {
    val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
        assert(metadata.treeWeights.forall(_ == 1.0))
        val trees =
          TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
        new RandomForestModel(Algo.fromString(metadata.algo), trees)
      case _ => throw new Exception(s"RandomForestModel.load did not recognize model" +
        s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
        s"  ($classNameV1_0, 1.0)")
    }
  }

  private object SaveLoadV1_0 {
    // Hard-code class name string in case it changes in the future
    def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel"
  }
}
/**
 * Represents a tree ensemble model.
 *
 * @param algo algorithm for the ensemble model, either Classification or Regression
 * @param trees tree ensembles
 * @param treeWeights tree ensemble weights
 * @param combiningStrategy strategy for combining the predictions, not used for regression.
 */
private[tree] sealed class TreeEnsembleModel(
    protected val algo: Algo,
    protected val trees: Array[DecisionTreeModel],
    protected val treeWeights: Array[Double],
    protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {

  require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")

  // Other code omitted

  // The final class is decided by (weighted) majority voting
  /**
   * Classifies a single data point based on (weighted) majority votes.
   */
  private def predictByVoting(features: Vector): Double = {
    val votes = mutable.Map.empty[Int, Double]
    trees.view.zip(treeWeights).foreach { case (tree, weight) =>
      val prediction = tree.predict(features).toInt
      votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
    }
    votes.maxBy(_._2)._1
  }

  /**
   * Predict values for a single data point using the model trained.
   *
   * @param features array representing a single data point
   * @return predicted category from the trained model
   */
  // Different combining strategies lead to different prediction methods
  def predict(features: Vector): Double = {
    (algo, combiningStrategy) match {
      case (Regression, Sum) =>
        predictBySumming(features)
      case (Regression, Average) =>
        predictBySumming(features) / sumWeights
      case (Classification, Sum) => // binary classification
        val prediction = predictBySumming(features)
        // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
        if (prediction > 0.0) 1.0 else 0.0
      // Random forests use the predictByVoting method
      case (Classification, Vote) =>
        predictByVoting(features)
      case _ =>
        throw new IllegalArgumentException(
          "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " +
            s"($algo, $combiningStrategy).")
    }
  }

  // The RDD version of predict delegates to the single-point predict above
  /**
   * Predict values for the given data set.
   *
   * @param features RDD representing data points to be predicted
   * @return RDD[Double] where each entry contains the corresponding prediction
   */
  def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))

  // Other code omitted
}
Record ID | Owns property (yes/no) | Marital status (single, married, divorced) | Annual income (10,000 CNY) | Able to repay (yes/no) |
10001 | No | Married | 10 | Yes |
10002 | No | Single | 8 | Yes |
10003 | Yes | Single | 13 | Yes |
… | … | … | … | … |
11000 | Yes | Single | 8 | No |
Owns property (yes/no) | Marital status (single, married, divorced) | Annual income (10,000 CNY) |
No | Married | 12 |
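To score an applicant like the one above, the row has to be encoded the same way as the training data. A hypothetical encoding and prediction call is sketched below, assuming a RandomForestModel named model has been trained as in the example that follows; the column-to-index mapping and the 0/1 label coding are assumptions, not something fixed by MLlib.

import org.apache.spark.mllib.linalg.Vectors

// Hypothetical encoding consistent with the training table:
// feature 0 = owns property (0 = no, 1 = yes),
// feature 1 = marital status (0 = single, 1 = married, 2 = divorced),
// feature 2 = annual income in units of 10,000 CNY.
// categoricalFeaturesInfo for training would then be Map(0 -> 2, 1 -> 3).
val applicant = Vectors.dense(0.0, 1.0, 12.0)  // no property, married, income 12
val canRepay = model.predict(applicant)        // 1.0 = able to repay, 0.0 = not (assumed coding)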
package cn.ml

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.linalg.Vectors

object RandomForestExample {
  def main(args: Array[String]) {
    val sparkConf = new SparkConf().setAppName("RandomForestExample")
    val sc = new SparkContext(sparkConf)

    val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "/data/sample_data.txt")

    val numClasses = 2
    val featureSubsetStrategy = "auto"
    val numTrees = 3

    val model: RandomForestModel = RandomForest.trainClassifier(
      data, Strategy.defaultStrategy("classification"), numTrees,
      featureSubsetStrategy, new java.util.Random().nextInt())

    val input: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "/data/input.txt")

    val predictResult = input.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    // Print the results (useful when running in spark-shell)
    predictResult.collect()
    // Save the results to HDFS
    // predictResult.saveAsTextFile("/data/predictResult")

    sc.stop()
  }
}