赞
踩
相关文章
spark mllib源码分析之逻辑回归弹性网络ElasticNet(一)
spark源码分析之L-BFGS
spark mllib源码分析之OWLQN
spark中的online均值/方差统计
spark源码分析之二分类逻辑回归evaluation
spark正则化
设置用于控制训练的参数
将封装成DataFrame的输入数据再转成简单结构的instance,包括label,weight,特征,默认每个样本的weight为1
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
统计样本每个特征的方差,均值,label的分布情况,用到了MultivariateOnlineSummarizer和MultiClassSummarizer,前面有介绍
val (summarizer, labelSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
(c1._1.merge(c2._1), c1._2.merge(c2._2))
instances.treeAggregate(
new MultivariateOnlineSummarizer, new MultiClassSummarizer
)(seqOp, combOp, $(aggregationDepth))
}
//各维特征的weightSum
val histogram = labelSummarizer.histogram
//label非法,主要是label非整数和小于0的情况
val numInvalid = labelSummarizer.countInvalid
val numFeatures = summarizer.mean.size
//如果有截距,相当于增加一维值全为1的特征
val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) =>
require(n >= histogram.length, s"Specified number of classes $n was " +
s"less than the number of unique labels ${histogram.length}.")
n
//最好是labelSummarizer.numClasses
case None => histogram.length
}
val isMultinomial = $(family) match {
case "binomial" =>
require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
s"outcome classes but found $numClasses.")
false
case "multinomial" => true
case "auto" => numClasses > 2
case other => throw new IllegalArgumentException(s"Unsupported family: $other")
}
val numCoefficientSets = if (isMultinomial) numClasses else 1
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${
$(thresholds).length}")
}
根据二/多分类,是否拟合截距,L1/L2等确定训练使用的优化方法,损失函数
判断条件为
$(fitIntercept) && isConstantLabel
label是否唯一的判断
val isConstantLabel = histogram.count(_ != 0.0) == 1
histogram是Array,里面放着样本中各label的数量,也就是说样本里只有一种label。
这种情况返回的系数矩阵为全0的SparseMatrix,对于截距,如果是多分类,返回稀疏向量,向量长度为numClasses,只有index为label的位置有值Double.PositiveInfinity;如果是二分类,返回dense vector,值为Double.PositiveInfinity
判断条件为
!$(fitIntercept) && isConstantLabel
此种情况下,算法可能不会收敛,给出了警告信息,但是会继续尝试优化。
判断条件
!$(fitIntercept) && (0 unti
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。