
PySpark MLlib: Logistic Regression Model Training Workflow (Training, Evaluation, Encoding/Decoding, Saving, Loading)



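No train/test split is made here: the model is trained on the full dataset and evaluated on that same full dataset, so the metrics reported below are training-set metrics.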
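Metrics measured on the data a model was fitted to tend to be optimistic, so a held-out split is the usual next step. In PySpark that is typically `train_data.randomSplit([0.8, 0.2], seed=42)`, which samples each row independently, so the split sizes are only approximately 80/20. The pure-Python sketch below only illustrates the idea of a seeded holdout split; `holdout_split` is a hypothetical helper and does not reproduce Spark's sampling.

```python
import random


def holdout_split(rows, test_fraction=0.2, seed=42):
    """Shuffle a copy of the rows with a fixed seed and return
    (train, test), where test holds roughly test_fraction of the rows."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be strictly between 0 and 1")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # same seed -> same split every run
    n_test = int(round(len(shuffled) * test_fraction))
    cut = len(shuffled) - n_test
    return shuffled[:cut], shuffled[cut:]


train_rows, test_rows = holdout_split(range(10), test_fraction=0.2)
print(len(train_rows), len(test_rows))  # 8 2
```

Fitting on the first part and calling the evaluation function on the second gives a less biased estimate than evaluating on the training rows.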

1. Set up the Spark environment

from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

2. Define the evaluation function

# Evaluation: compute several metrics on a DataFrame that already contains predictions
import pandas as pd
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

def check(train_eval):
    def score(metric):
        # 'logLoss' is computed from the 'probability' column; the other metrics use 'prediction'
        evaluator = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='Type_idx', metricName=metric)
        return evaluator.evaluate(train_eval)
    return pd.DataFrame({
        'F1': [score('f1')],
        'Recall': [score('weightedRecall')],
        'Precision': [score('weightedPrecision')],
        'Accuracy': [score('accuracy')],
        'Loss': [score('logLoss')],
    })
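The metric names passed to the evaluator are weighted averages over classes. A pure-Python sketch of those definitions, assuming they follow Spark's `MulticlassMetrics`: per-class precision, recall and F1 are weighted by each class's share of the true labels, and log loss is the mean negative log-probability of the true class, clipped at `eps` (whose default in `MulticlassClassificationEvaluator` is 1e-15). `multiclass_metrics` is a hypothetical helper, not part of PySpark.

```python
import math
from collections import Counter


def multiclass_metrics(labels, preds, probs, eps=1e-15):
    """Weighted multiclass metrics in the style of Spark's MulticlassMetrics.

    labels, preds: class indices per row.
    probs: per-row probability lists indexed by class (used only for log loss).
    """
    n = len(labels)
    label_counts = Counter(labels)
    pred_counts = Counter(preds)
    tp = Counter(l for l, p in zip(labels, preds) if l == p)

    w_precision = w_recall = w_f1 = 0.0
    for label, count in label_counts.items():
        weight = count / n  # share of this class among the true labels
        precision = tp[label] / pred_counts[label] if pred_counts[label] else 0.0
        recall = tp[label] / count
        f1 = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
        w_precision += weight * precision
        w_recall += weight * recall
        w_f1 += weight * f1

    accuracy = sum(tp.values()) / n
    # Clip at eps so a zero probability for the true class does not give infinity.
    log_loss = -sum(math.log(max(pr[l], eps)) for l, pr in zip(labels, probs)) / n
    return {'F1': w_f1, 'Recall': w_recall, 'Precision': w_precision,
            'Accuracy': accuracy, 'Loss': log_loss}


m = multiclass_metrics(
    labels=[0, 0, 1, 1],
    preds=[0, 1, 1, 1],
    probs=[[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.3, 0.7]],
)
print(m)
```

Under these definitions weighted recall always equals accuracy (the class weights cancel the per-class denominators), so identical Recall and Accuracy columns in the output table are expected rather than a bug.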

3. Load and clean the data

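# Read the data with pandas, fix problematic column names, and select the training columns
import pandas as pd
df = spark.createDataFrame(pd.read_excel('data.xlsx', sheet_name='training dataset'))
# Column names containing '.' cause errors later (Spark reads the dot as nested-field access), so rename them
df = df.withColumnRenamed('DBE.C', 'DBEC').withColumnRenamed('DBE.O', 'DBEO')
# Select the columns used for training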
train_df = df.select(['C', 'H', 'O', 'N', 'S', 'group', 'AImod', 'DBE', 'MZ', 'OC', 'HC', 'SC', 'NC', 'NOSC', 'DBEC', 'DBEO', 'location', 'sample', 'Type'])
train_df
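Spark parses a dot in a column reference as access to a field of a nested struct, so a column literally named `DBE.C` has to be quoted with backticks (for example ``col('`DBE.C`')``) or renamed. When several columns are affected, renaming them all at once is less error-prone than one `withColumnRenamed` per column. A sketch with a hypothetical helper that also guards against two names collapsing into the same one:

```python
def sanitize_column_names(columns, replacement=''):
    """Replace '.' in every column name; refuse renamings that would
    produce duplicate column names."""
    cleaned = [name.replace('.', replacement) for name in columns]
    duplicates = {name for name in cleaned if cleaned.count(name) > 1}
    if duplicates:
        raise ValueError(f'renaming would create duplicate columns: {sorted(duplicates)}')
    return cleaned


print(sanitize_column_names(['C', 'DBE.C', 'DBE.O']))  # ['C', 'DBEC', 'DBEO']
```

In PySpark the result can be applied in one step with `df = df.toDF(*sanitize_column_names(df.columns))`, which renames every column positionally.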


4. Encode and assemble columns

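# Encode and assemble columns
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import IndexToString
from pyspark.ml import PipelineModel
from pyspark.ml.feature import VectorAssembler
# Convert string columns, and NOSC (which contains negative values; NaiveBayes below rejects negative features), into numeric indices.
# handleInvalid='keep' maps values not seen during fitting to an extra index instead of raising an error,
# which matters when the validation data contains new values (likely for a continuous column such as NOSC).
indexer = StringIndexer(inputCols = ['group', 'NOSC', 'location', 'sample', 'Type'], outputCols = ['group_idx', 'NOSC_idx', 'location_idx', 'sample_idx', 'Type_idx'], handleInvalid = 'keep')
encoder = indexer.fit(train_df)
# Maps predicted indices back to the original 'Type' strings (labelsArray[4] is the label list of the 5th input column, 'Type')
decoder = IndexToString(inputCol = 'prediction', outputCol = 'result', labels = encoder.labelsArray[4])
# Merge the feature columns into a single vector column
assembler = VectorAssembler(inputCols = ['C', 'H', 'O', 'N', 'S', 'group_idx', 'AImod', 'DBE', 'MZ',
                                         'OC', 'HC', 'SC', 'NC', 'NOSC_idx', 'DBEC', 'DBEO', 'location_idx', 'sample_idx']
                            , outputCol = 'features')
train_data = assembler.transform(encoder.transform(train_df)).select('features', 'Type_idx')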
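What the indexer and decoder do to the values can be pictured in plain Python. With the default `stringOrderType='frequencyDesc'`, the most frequent label gets index 0.0 (ties are broken alphabetically in Spark 3.x), numeric inputs are indexed by their string form, and `IndexToString` simply looks each index up in the same label list, which is why `decoder` receives `encoder.labelsArray[4]`. The helpers below are hypothetical and only model the mapping, not Spark's distributed implementation.

```python
from collections import Counter


def fit_string_indexer(values):
    """Labels ordered by descending frequency, ties broken alphabetically."""
    counts = Counter(str(v) for v in values)  # numbers are indexed via their string form
    return sorted(counts, key=lambda label: (-counts[label], label))


def index_strings(values, labels, handle_invalid='error'):
    """Map values to label indices; unseen values raise, or with
    handle_invalid='keep' go to the extra index len(labels)."""
    lookup = {label: float(i) for i, label in enumerate(labels)}
    out = []
    for v in values:
        key = str(v)
        if key in lookup:
            out.append(lookup[key])
        elif handle_invalid == 'keep':
            out.append(float(len(labels)))
        else:
            raise ValueError(f'unseen label: {key!r}')
    return out


def index_to_string(indices, labels):
    """Inverse mapping, as used for the 'prediction' column."""
    return [labels[int(i)] for i in indices]


labels = fit_string_indexer(['b', 'a', 'b', 'c', 'a', 'b'])
print(labels)                                     # ['b', 'a', 'c']
print(index_strings(['a', 'z'], labels, 'keep'))  # [1.0, 3.0]
print(index_to_string([0.0, 2.0], labels))        # ['b', 'c']
```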

5. Train the models

Logistic regression

  • Nothing was tuned, and the metrics are poor.
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(featuresCol = 'features', labelCol = 'Type_idx')
model = lr.fit(train_data)
# Check the metrics (on the same data the model was trained on)
check(model.transform(train_data))
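When the label has more than two classes, `LogisticRegression`'s default `family='auto'` fits a multinomial (softmax) model. For each row, `transform` computes one linear score per class (`rawPrediction`), turns the scores into `probability` with softmax, and takes the most probable class as `prediction`. The sketch below mirrors that computation for a single row; the coefficient values are made up (in a fitted model they would come from `model.coefficientMatrix` and `model.interceptVector`), and the binomial case, which uses a single margin and a sigmoid, is not covered.

```python
import math


def softmax(scores):
    """Numerically stable softmax: subtract the max score before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_multinomial_lr(x, coefficients, intercepts):
    """coefficients: one weight list per class; intercepts: one bias per class.
    Returns (raw scores, probabilities, predicted class index as a float)."""
    raw = [sum(w * xi for w, xi in zip(row, x)) + b
           for row, b in zip(coefficients, intercepts)]
    probs = softmax(raw)
    prediction = float(max(range(len(probs)), key=probs.__getitem__))
    return raw, probs, prediction


raw, probs, prediction = predict_multinomial_lr(
    x=[1.0, 2.0],
    coefficients=[[0.5, -0.25], [0.0, 0.5], [-0.5, 0.0]],  # hypothetical values
    intercepts=[0.0, 0.0, 0.25],
)
print(raw, prediction)  # [0.0, 1.0, -0.25] 1.0
```

Because the scores are linear in the features, regularization (`regParam`, `elasticNetParam`) and `maxIter` are the usual first parameters to tune when the untuned metrics are poor.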


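Naive Bayes

  • These metrics are even worse; I didn't bother tuning this one.
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes(featuresCol = 'features', labelCol = 'Type_idx')
# Use a separate name so `model` still refers to the logistic regression model
nb_model = nb.fit(train_data)
# Check the metrics (on the same data the model was trained on)
check(nb_model.transform(train_data))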
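Spark's `NaiveBayes` defaults to `modelType='multinomial'` with `smoothing=1.0`. Its score for a class is a log prior plus each feature value times a per-class log probability, a model designed for nonnegative, count-like features; Spark rejects negative feature values, and most columns here are continuous measurements rather than counts, which may partly explain the poor metrics. The sketch below reproduces the multinomial formulas in plain Python; `fit_multinomial_nb` and `predict_nb` are hypothetical helpers.

```python
import math
from collections import defaultdict


def fit_multinomial_nb(rows, labels, smoothing=1.0):
    """Multinomial Naive Bayes with additive (Laplace) smoothing.
    Returns per-class log priors (pi) and log feature probabilities (theta)."""
    for row in rows:
        if any(v < 0 for v in row):
            raise ValueError('multinomial Naive Bayes requires nonnegative features')
    classes = sorted(set(labels))
    n_features = len(rows[0])
    class_count = defaultdict(int)
    feature_sum = {c: [0.0] * n_features for c in classes}
    for row, c in zip(rows, labels):
        class_count[c] += 1
        for j, v in enumerate(row):
            feature_sum[c][j] += v
    prior_denom = math.log(len(rows) + smoothing * len(classes))
    pi = {c: math.log(class_count[c] + smoothing) - prior_denom for c in classes}
    theta = {}
    for c in classes:
        denom = math.log(sum(feature_sum[c]) + smoothing * n_features)
        theta[c] = [math.log(s + smoothing) - denom for s in feature_sum[c]]
    return pi, theta


def predict_nb(x, pi, theta):
    """Pick the class maximising log prior + sum_j x_j * log theta_cj."""
    scores = {c: pi[c] + sum(v * t for v, t in zip(x, theta[c])) for c in pi}
    return max(scores, key=scores.get)


rows = [[3.0, 0.0], [2.0, 1.0], [0.0, 4.0], [1.0, 3.0]]
pi, theta = fit_multinomial_nb(rows, labels=[0, 0, 1, 1])
print(predict_nb([4.0, 0.0], pi, theta), predict_nb([0.0, 5.0], pi, theta))  # 0 1
```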


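6. Save the model

  • A PipelineModel calls transform on its stages one by one in list order, passing each stage's output to the next.
# Save the whole pipeline: index strings -> assemble features -> classifier -> decode predictions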
pipeline = PipelineModel(stages = [encoder, assembler, model, decoder])
pipeline.write().overwrite().save('./output/model')
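Conceptually, `PipelineModel.transform` folds the data through its stages in order. A toy pure-Python version, with hypothetical stages on a list of dict rows standing in for `encoder`, `assembler`, the classifier and `decoder` (the classifier is a made-up threshold rule, not a trained model):

```python
class SimplePipelineModel:
    """Toy stand-in for PipelineModel: apply fitted stages in list order,
    feeding each stage's output into the next one."""

    def __init__(self, stages):
        self.stages = list(stages)

    def transform(self, rows):
        for stage in self.stages:
            rows = stage(rows)
        return rows


TYPE_LABELS = ['A', 'B']  # hypothetical label list, playing the role of encoder.labelsArray[4]


def encode(rows):
    """Like StringIndexerModel: add an index column for a string column."""
    return [{**r, 'group_idx': 0.0 if r['group'] == 'x' else 1.0} for r in rows]


def assemble(rows):
    """Like VectorAssembler: collect feature columns into one list."""
    return [{**r, 'features': [r['C'], r['group_idx']]} for r in rows]


def classify(rows):
    """Stand-in for the fitted classifier: a made-up threshold rule."""
    return [{**r, 'prediction': 1.0 if r['features'][0] > 10 else 0.0} for r in rows]


def decode(rows):
    """Like IndexToString: turn the predicted index back into a label."""
    return [{**r, 'result': TYPE_LABELS[int(r['prediction'])]} for r in rows]


pipeline = SimplePipelineModel([encode, assemble, classify, decode])
out = pipeline.transform([{'C': 12.0, 'group': 'x'}, {'C': 3.0, 'group': 'y'}])
print([r['result'] for r in out])  # ['B', 'A']
```

Swapping `encode` and `assemble` would fail because `group_idx` would not exist yet; the same ordering constraint is why the encoder has to come before the assembler in the saved pipeline.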

7. Load the model and predict on the validation data

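# Load the validation data and apply the same column renaming as for training
import pandas as pd
df = spark.createDataFrame(pd.read_excel('data.xlsx', sheet_name='validation dataset'))
df = df.withColumnRenamed('DBE.C', 'DBEC').withColumnRenamed('DBE.O', 'DBEO')
# 'Type' is not selected; the fitted StringIndexerModel skips input columns that are absent at transform time
test_df = df.select(['C', 'H', 'O', 'N', 'S', 'group', 'AImod', 'DBE', 'MZ', 'OC', 'HC', 'SC', 'NC', 'NOSC', 'DBEC', 'DBEO', 'location', 'sample'])

# Load the saved pipeline and predict; the decoded labels are in the 'result' column
from pyspark.ml import PipelineModel
loaded_model = PipelineModel.load('./output/model')
test_res = loaded_model.transform(test_df)
test_res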
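The loaded pipeline neither renames nor validates its input columns, so the validation sheet must go through the same renaming as the training sheet; a missing column such as `DBEC` otherwise only surfaces as an error when the assembler stage looks it up. A quick pre-flight check, sketched as a hypothetical helper whose column list is copied from the training code:

```python
def check_required_columns(available, required):
    """Return the required input columns missing from `available`, in order."""
    present = set(available)
    return [c for c in required if c not in present]


# Raw input columns the saved pipeline reads; 'Type' is absent on purpose,
# since it is only needed to build the training labels.
REQUIRED = ['C', 'H', 'O', 'N', 'S', 'group', 'AImod', 'DBE', 'MZ', 'OC', 'HC',
            'SC', 'NC', 'NOSC', 'DBEC', 'DBEO', 'location', 'sample']

# Simulate a validation sheet where the 'DBE.C' column was not renamed.
validation_columns = [c if c != 'DBEC' else 'DBE.C' for c in REQUIRED]
print(check_required_columns(validation_columns, REQUIRED))  # ['DBEC']
```

In PySpark it would be called as `check_required_columns(test_df.columns, REQUIRED)` before `transform`, failing fast with a readable list instead of a stack trace from inside the pipeline.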


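Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog (wpsshop博客); copyright belongs to the original author, and the site assumes no legal responsibility. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/weixin_40725706/article/detail/92730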