赞
踩
wget http://www-stat.stanford.edu/~tibs/ElemStatLearn/datasets/spam.data
val inFile = sc.textFile("./spam.data")
//将数据在每个机器上都有备份
//import org.apache.spark.SparkFiles
//val file = sc.addFile("spam.data")
inFile.first()
es0: String = 0 0.64 0.64 0 0.32 0 0 0 0 0 0 0.64 0 0 0 0.32 0 1.29 1.93 0 0.96 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.778 0 0 3.756 61 278 1
对于逻辑回归我们要读取特定格式的数据问不是每行的数据记录
将数据转换成double型
val nums = inFile.map(x => x.split(' ').map(_.toDouble))
nums.first()
res2: Array[Double] = Array(0.0, 0.64, 0.64, 0.0, 0.32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.64, 0.0, 0.0, 0.0, 0.32, 0.0, 1.29, 1.93, 0.0, 0.96, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.778, 0.0, 0.0, 3.756, 61.0, 278.0, 1.0)
构造DataPoints类
import spark.util.Vector
case class DataPoint(x: Vector, y: Double)
def parsePoint(x: Array[Double]): DataPoint = {
DataPoint(new Vector(x.slice(0,x.size-2)) , x(x.size-1))
}
val points = nums.map(parsePoint(_))
随机产生权重矩阵
import java.util.Random
val rand = new Random(53)
var w = Vector(nums.first.size-2, _=> rand.nextDouble)
res13: org.apache.spark.util.Vector = (0.7290865701603526, 0.8009687428076777, 0.6136632797111822, 0.9783178194773176, 0.3719683631485643, 0.46409291255379836, 0.5340172959927323, 0.04034252433669905, 0.3074428389716637, 0.8537414030626244, 0.8415816118493813, 0.719935849109521, 0.2431646830671812, 0.17139348575456848, 0.5005137792223062, 0.8915164469396641, 0.7679331873447098, 0.7887571495335223, 0.7263187438977023, 0.40877063468941244, 0.7794519914671199, 0.1651264689613885, 0.1807006937030201, 0.3227972103818231, 0.2777324549716147, 0.20466985600105037, 0.5823059390134582, 0.4489508737465665, 0.44030858771499415, 0.6419366305419459, 0.5191533842209496, 0.43170678028084863, 0.9237523536173182, 0.5175019655845213, 0.47999523211827544, 0.25862648071479444, 0.020548000101787922, 0.185553...
val iterations = 100
import scala.math._
for (i <- 1 to iterations) {
val gradient = points.map(p =>
(1 / (1 + exp(-p.y*(w dot p.x))) - 1) * p.y * p.x
).reduce(_ + _)
w -= gradient
}
w
res17: org.apache.spark.util.Vector = (0.7291462605526978, 0.8011493694345105, 0.6632462451894483, 0.9783179057774432, 0.5894806547559924, 0.46413037169154797, 0.5352673058138914, 0.04151002242309652, 0.3074579788453562, 0.8554814465008911, 0.8421319858358445, 0.723306806645535, 0.24382860800094663, 0.17140711871915207, 0.5006326041454038, 0.9408116975991101, 0.7739239734124745, 0.790122616980566, 0.9701103050755487, 0.4106048776506287, 0.8098841935066842, 0.16512808143309984, 0.18074648984915714, 0.3268703791115973, 0.28167747744431826, 0.20995838053594057, 0.5823059390134736, 0.4489520120935588, 0.44030859962613983, 0.6419368289264852, 0.5191533895589641, 0.43170678028084863, 0.9237602493794835, 0.5175019655845293, 0.4800004611303587, 0.2587440164596575, 0.020567743652946585, 0.185554...
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。