当前位置:   article > 正文

三种特征选择方法及Spark MLlib调用实例(Scala/Java/python)_spark mllib java demo

spark mllib java demo

来源:http://blog.csdn.net/liulingyuan6/article/details/53413728


VectorSlicer

算法介绍:

        VectorSlicer是一个转换器输入特征向量,输出原始特征向量子集。VectorSlicer接收带有特定索引的向量列,通过对这些索引的值进行筛选得到新的向量集。可接受如下两种索引

1.整数索引,setIndices()

2.字符串索引代表向量中特征的名字,此类要求向量列有AttributeGroup,因为该工具根据Attribute来匹配名字字段。

指定整数或者字符串类型都是可以的。另外,同时使用整数索引和字符串名字也是可以的。不允许使用重复的特征,所以所选的索引或者名字必须是没有独一的。注意如果使用名字特征,当遇到空值的时候将会报错。

    输出将会首先按照所选的数字索引排序(按输入顺序),其次按名字排序(按输入顺序)。

示例:
假设我们有一个DataFrame含有userFeatures列:

userFeatures

------------------

 [0.0, 10.0, 0.5]

userFeatures是一个向量列包含3个用户特征。假设userFeatures的第一列全为0,我们希望删除它并且只选择后两项。我们可以通过索引setIndices(1,2)来选择后两项并产生一个新的features列:

userFeatures     | features

------------------|-----------------------------

 [0.0, 10.0, 0.5] | [10.0, 0.5]

假设我们还有如同["f1","f2", "f3"]的属性,那可以通过名字setNames("f2","f3")的形式来选择:

userFeatures     | features

------------------|-----------------------------

 [0.0, 10.0, 0.5] | [10.0, 0.5]

 ["f1", "f2","f3"] | ["f2", "f3"]

调用示例:

Scala:

[plain]  view plain  copy
  1. import java.util.Arrays  
  2.   
  3. import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}  
  4. import org.apache.spark.ml.feature.VectorSlicer  
  5. import org.apache.spark.ml.linalg.Vectors  
  6. import org.apache.spark.sql.Row  
  7. import org.apache.spark.sql.types.StructType  
  8.   
  9. val data = Arrays.asList(Row(Vectors.dense(-2.0, 2.3, 0.0)))  
  10.   
  11. val defaultAttr = NumericAttribute.defaultAttr  
  12. val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName)  
  13. val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])  
  14.   
  15. val dataset = spark.createDataFrame(data, StructType(Array(attrGroup.toStructField())))  
  16.   
  17. val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features")  
  18.   
  19. slicer.setIndices(Array(1)).setNames(Array("f3"))  
  20. // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3"))  
  21.   
  22. val output = slicer.transform(dataset)  
  23. println(output.select("userFeatures", "features").first())  

Java:

[java]  view plain  copy
  1. import java.util.List;  
  2.   
  3. import com.google.common.collect.Lists;  
  4.   
  5. import org.apache.spark.ml.attribute.Attribute;  
  6. import org.apache.spark.ml.attribute.AttributeGroup;  
  7. import org.apache.spark.ml.attribute.NumericAttribute;  
  8. import org.apache.spark.ml.feature.VectorSlicer;  
  9. import org.apache.spark.ml.linalg.Vectors;  
  10. import org.apache.spark.sql.Dataset;  
  11. import org.apache.spark.sql.Row;  
  12. import org.apache.spark.sql.RowFactory;  
  13. import org.apache.spark.sql.types.*;  
  14.   
  15. Attribute[] attrs = new Attribute[]{  
  16.   NumericAttribute.defaultAttr().withName("f1"),  
  17.   NumericAttribute.defaultAttr().withName("f2"),  
  18.   NumericAttribute.defaultAttr().withName("f3")  
  19. };  
  20. AttributeGroup group = new AttributeGroup("userFeatures", attrs);  
  21.   
  22. List<Row> data = Lists.newArrayList(  
  23.   RowFactory.create(Vectors.sparse(3new int[]{01}, new double[]{-2.02.3})),  
  24.   RowFactory.create(Vectors.dense(-2.02.30.0))  
  25. );  
  26.   
  27. Dataset<Row> dataset =  
  28.   spark.createDataFrame(data, (new StructType()).add(group.toStructField()));  
  29.   
  30. VectorSlicer vectorSlicer = new VectorSlicer()  
  31.   .setInputCol("userFeatures").setOutputCol("features");  
  32.   
  33. vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});  
  34. // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"})  
  35.   
  36. Dataset<Row> output = vectorSlicer.transform(dataset);  
  37.   
  38. System.out.println(output.select("userFeatures""features").first());  

Python:

[python]  view plain  copy
  1. from pyspark.ml.feature import VectorSlicer  
  2. from pyspark.ml.linalg import Vectors  
  3. from pyspark.sql.types import Row  
  4.   
  5. df = spark.createDataFrame([  
  6.     Row(userFeatures=Vectors.sparse(3, {0: -2.012.3}),),  
  7.     Row(userFeatures=Vectors.dense([-2.02.30.0]),)])  
  8.   
  9. slicer = VectorSlicer(inputCol="userFeatures", outputCol="features", indices=[1])  
  10.   
  11. output = slicer.transform(df)  
  12.   
  13. output.select("userFeatures""features").show()  

RFormula

算法介绍:

       RFormula通过R模型公式来选择列。支持R操作中的部分操作,包括‘~’, ‘.’, ‘:’, ‘+’以及‘-‘,基本操作如下:

1. ~分隔目标和对象

2. +合并对象,“+ 0”意味着删除空格

3. :交互(数值相乘,类别二值化)

4. . 除了目标外的全部列

假设a和b为两列:

1. y ~ a + b表示模型y ~ w0 + w1 * a +w2 * b其中w0为截距,w1w2为相关系数。

2. y ~a + b + a:b – 1表示模型y ~ w1* a + w2 * b + w3 * a * b,其中w1w2w3是相关系数。

RFormula产生一个向量特征列以及一个double或者字符串标签列。如果类别列是字符串类型,它将通过StringIndexer转换为double类型。如果标签列不存在,则输出中将通过规定的响应变量创造一个标签列。

示例:

假设我们有一个DataFrame含有id,countryhourclicked四列:

id | country |hour | clicked

---|---------|------|---------

 7 | "US"    | 18  | 1.0

 8 | "CA"    | 12  | 0.0

 9 | "NZ"    | 15  | 0.0

如果我们使用RFormula公式clicked ~ country+ hour,则表明我们希望基于countryhour预测clicked,通过转换我们可以得到如下DataFrame

id | country |hour | clicked | features         | label

---|---------|------|---------|------------------|-------

 7 | "US"    | 18  | 1.0     | [0.0, 0.0, 18.0] | 1.0

 8 | "CA"    | 12  | 0.0     | [0.0, 1.0, 12.0] | 0.0

 9 | "NZ"    | 15  | 0.0     | [1.0, 0.0, 15.0] | 0.0

调用示例:

Scala:

[plain]  view plain  copy
  1. import org.apache.spark.ml.feature.RFormula  
  2.   
  3. val dataset = spark.createDataFrame(Seq(  
  4.   (7, "US", 18, 1.0),  
  5.   (8, "CA", 12, 0.0),  
  6.   (9, "NZ", 15, 0.0)  
  7. )).toDF("id", "country", "hour", "clicked")  
  8. val formula = new RFormula()  
  9.   .setFormula("clicked ~ country + hour")  
  10.   .setFeaturesCol("features")  
  11.   .setLabelCol("label")  
  12. val output = formula.fit(dataset).transform(dataset)  
  13. output.select("features", "label").show()  

Java:

[java]  view plain  copy
  1. import java.util.Arrays;  
  2. import java.util.List;  
  3.   
  4. import org.apache.spark.ml.feature.RFormula;  
  5. import org.apache.spark.sql.Dataset;  
  6. import org.apache.spark.sql.Row;  
  7. import org.apache.spark.sql.RowFactory;  
  8. import org.apache.spark.sql.types.StructField;  
  9. import org.apache.spark.sql.types.StructType;  
  10.   
  11. import static org.apache.spark.sql.types.DataTypes.*;  
  12.   
  13. StructType schema = createStructType(new StructField[]{  
  14.   createStructField("id", IntegerType, false),  
  15.   createStructField("country", StringType, false),  
  16.   createStructField("hour", IntegerType, false),  
  17.   createStructField("clicked", DoubleType, false)  
  18. });  
  19.   
  20. List<Row> data = Arrays.asList(  
  21.   RowFactory.create(7"US"181.0),  
  22.   RowFactory.create(8"CA"120.0),  
  23.   RowFactory.create(9"NZ"150.0)  
  24. );  
  25.   
  26. Dataset<Row> dataset = spark.createDataFrame(data, schema);  
  27. RFormula formula = new RFormula()  
  28.   .setFormula("clicked ~ country + hour")  
  29.   .setFeaturesCol("features")  
  30.   .setLabelCol("label");  
  31. Dataset<Row> output = formula.fit(dataset).transform(dataset);  
  32. output.select("features""label").show();  

Python:

[python]  view plain  copy
  1. from pyspark.ml.feature import RFormula  
  2.   
  3. dataset = spark.createDataFrame(  
  4.     [(7"US"181.0),  
  5.      (8"CA"120.0),  
  6.      (9"NZ"150.0)],  
  7.     ["id""country""hour""clicked"])  
  8. formula = RFormula(  
  9.     formula="clicked ~ country + hour",  
  10.     featuresCol="features",  
  11.     labelCol="label")  
  12. output = formula.fit(dataset).transform(dataset)  
  13. output.select("features""label").show()  

ChiSqSelector

算法介绍:

       ChiSqSelector代表卡方特征选择。它适用于带有类别特征的标签数据。ChiSqSelector根据类别的独立卡方2检验来对特征排序,然后选取类别标签主要依赖的特征。它类似于选取最有预测能力的特征。

示例:

假设我们有一个DataFrame含有id,featuresclicked三列,其中clicked为需要预测的目标:

id | features              | clicked

---|-----------------------|---------

 7 | [0.0, 0.0, 18.0, 1.0] | 1.0

 8 | [0.0, 1.0, 12.0, 0.0] | 0.0

 9 | [1.0, 0.0, 15.0, 0.1] | 0.0

如果我们使用ChiSqSelector并设置numTopFeatures1,根据标签clickedfeatures中最后一列将会是最有用特征:

id | features              | clicked | selectedFeatures

---|-----------------------|---------|------------------

 7 | [0.0, 0.0, 18.0, 1.0] | 1.0     | [1.0]

 8 | [0.0, 1.0, 12.0, 0.0] | 0.0     | [0.0]

 9 | [1.0, 0.0, 15.0, 0.1] | 0.0     | [0.1]

调用示例:

Scala:

[plain]  view plain  copy
  1. import org.apache.spark.ml.feature.ChiSqSelector  
  2. import org.apache.spark.ml.linalg.Vectors  
  3.   
  4. val data = Seq(  
  5.   (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),  
  6.   (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),  
  7.   (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)  
  8. )  
  9.   
  10. val df = spark.createDataset(data).toDF("id", "features", "clicked")  
  11.   
  12. val selector = new ChiSqSelector()  
  13.   .setNumTopFeatures(1)  
  14.   .setFeaturesCol("features")  
  15.   .setLabelCol("clicked")  
  16.   .setOutputCol("selectedFeatures")  
  17.   
  18. val result = selector.fit(df).transform(df)  
  19. result.show()  
Java:

[java]  view plain  copy
  1. import java.util.Arrays;  
  2. import java.util.List;  
  3.   
  4. import org.apache.spark.ml.feature.ChiSqSelector;  
  5. import org.apache.spark.ml.linalg.VectorUDT;  
  6. import org.apache.spark.ml.linalg.Vectors;  
  7. import org.apache.spark.sql.Row;  
  8. import org.apache.spark.sql.RowFactory;  
  9. import org.apache.spark.sql.types.DataTypes;  
  10. import org.apache.spark.sql.types.Metadata;  
  11. import org.apache.spark.sql.types.StructField;  
  12. import org.apache.spark.sql.types.StructType;  
  13.   
  14. List<Row> data = Arrays.asList(  
  15.   RowFactory.create(7, Vectors.dense(0.00.018.01.0), 1.0),  
  16.   RowFactory.create(8, Vectors.dense(0.01.012.00.0), 0.0),  
  17.   RowFactory.create(9, Vectors.dense(1.00.015.00.1), 0.0)  
  18. );  
  19. StructType schema = new StructType(new StructField[]{  
  20.   new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),  
  21.   new StructField("features"new VectorUDT(), false, Metadata.empty()),  
  22.   new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty())  
  23. });  
  24.   
  25. Dataset<Row> df = spark.createDataFrame(data, schema);  
  26.   
  27. ChiSqSelector selector = new ChiSqSelector()  
  28.   .setNumTopFeatures(1)  
  29.   .setFeaturesCol("features")  
  30.   .setLabelCol("clicked")  
  31.   .setOutputCol("selectedFeatures");  
  32.   
  33. Dataset<Row> result = selector.fit(df).transform(df);  
  34. result.show();  
Python:

[python]  view plain  copy
  1. from pyspark.ml.feature import ChiSqSelector  
  2. from pyspark.ml.linalg import Vectors  
  3.   
  4. df = spark.createDataFrame([  
  5.     (7, Vectors.dense([0.00.018.01.0]), 1.0,),  
  6.     (8, Vectors.dense([0.01.012.00.0]), 0.0,),  
  7.     (9, Vectors.dense([1.00.015.00.1]), 0.0,)], ["id""features""clicked"])  
  8.   
  9. selector = ChiSqSelector(numTopFeatures=1, featuresCol="features",  
  10.                          outputCol="selectedFeatures", labelCol="clicked")  
  11.   
  12. result = selector.fit(df).transform(df)  
  13. result.show()  

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/997183
推荐阅读
相关标签
  

闽ICP备14008679号