当前位置:   article > 正文

python版本, lightgbm使用示例_lightgbm 示例

lightgbm 示例

1、安装lightgbm包,

pip install lightgbm -i https://pypi.tuna.tsinghua.edu.cn/simple --default-timeout=100

2、lightgbm原理:

https://www.cnblogs.com/jiangxinyang/p/9337094.html

 

3、lightgbm使用示例:

  1. def train(x_train, y_train, q_train, model_save_path):
  2. '''
  3. 模型的训练和保存
  4. :param x_train:
  5. :param y_train:
  6. :param q_train:
  7. :param model_save_path:
  8. :return:
  9. '''
  10. train_data = lgb.Dataset(x_train, label=y_train, group=q_train)
  11. params = {
  12. 'task': 'train', # 执行的任务类型
  13. 'boosting_type': 'gbrt', # 基学习器
  14. 'objective': 'lambdarank', # 排序任务(目标函数)
  15. 'metric': 'ndcg', # 度量的指标(评估函数)
  16. 'max_position': 10, # @NDCG 位置优化
  17. 'metric_freq': 1, # 每隔多少次输出一次度量结果
  18. 'train_metric': True, # 训练时就输出度量结果
  19. 'ndcg_at': [10],
  20. 'max_bin': 255, # 一个整数,表示最大的桶的数量。默认值为 255。lightgbm 会根据它来自动压缩内存。如max_bin=255 时,则lightgbm 将使用uint8 来表示特征的每一个值。
  21. 'num_iterations': 200, # 迭代次数,即生成的树的棵数
  22. 'learning_rate': 0.01, # 学习率
  23. 'num_leaves': 31, # 叶子数
  24. # 'max_depth':6,
  25. 'tree_learner': 'serial', # 用于并行学习,‘serial’: 单台机器的tree learner
  26. 'min_data_in_leaf': 30, # 一个叶子节点上包含的最少样本数量
  27. 'verbose': 2 # 显示训练时的信息
  28. }
  29. gbm = lgb.train(params, train_data, valid_sets=[train_data])
  30. gbm.save_model(model_save_path)
  31. def predict(x_test, comments, model_input_path):
  32. '''
  33. 预测得分并排序
  34. :param x_test:
  35. :param comments:
  36. :param model_input_path:
  37. :return:
  38. '''
  39. gbm = lgb.Booster(model_file=model_input_path) # 加载model
  40. ypred = gbm.predict(x_test)
  41. predicted_sorted_indexes = np.argsort(ypred)[::-1] # 返回从大到小的索引
  42. t_results = comments[predicted_sorted_indexes] # 返回对应的comments,从大到小的排序
  43. return t_results
  44. def test_data_ndcg(model_path, test_path):
  45. '''
  46. 评估测试数据的ndcg
  47. :param model_path:
  48. :param test_path:
  49. :return:
  50. '''
  51. with open(test_path, 'r', encoding='utf-8') as testfile:
  52. test_X, test_y, test_qids, comments = read_dataset(testfile)
  53. gbm = lgb.Booster(model_file=model_path)
  54. test_predict = gbm.predict(test_X)
  55. average_ndcg, _ = validate(test_qids, test_y, test_predict, 60)
  56. # 所有qid的平均ndcg
  57. print("all qid average ndcg: ", average_ndcg)
  58. print("job done!")

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/258062
推荐阅读
相关标签
  

闽ICP备14008679号