当前位置:   article > 正文

机器学习之逻辑回归模型

逻辑回归模型

1 逻辑回归模型介绍

        逻辑回归(Logistic Regression, LR)又称为逻辑回归分析,是一种机器学习算法,属于分类和预测算法中的一种,主要用于解决二分类问题。逻辑回归通过历史数据的表现对未来结果发生的概率进行预测。例如,我们可以将购买的概率设置为因变量,将用户的特征属性,例如性别,年龄,注册时间等设置为自变量。根据特征属性预测购买的概率。

        逻辑回归它通过建立一个逻辑回归模型来预测输入样本属于某个类别的概率。逻辑回归模型的核心思想是使用一个称为sigmoid函数(或者称为逻辑函数)的函数来建模概率。sigmoid函数的公式为:

3f7cad7875574855a07fd6d7384381ad.png

2 逻辑回归的应用场景

        逻辑回归是一种简单而高效的机器学习算法,它具有多个优势。首先,逻辑回归模型易于理解和实现,计算效率高,特别适用于大规模数据集。其次,逻辑回归提供了对结果的解释和推断能力,模型的系数可以揭示哪些特征对分类结果的影响较大或较小。此外,逻辑回归适用于高维数据,能够处理具有大量特征的问题,并捕捉到不同特征之间的关系。另外,逻辑回归能够输出概率预测,而不仅仅是分类结果,对于需要概率估计或不确定性分析的任务非常有用。最后,逻辑回归对于数据中的噪声和缺失值具有一定的鲁棒性,能够适应现实世界中的不完美数据。综上所述,逻辑回归是一种强大而实用的分类算法,在许多实际应用中被广泛采用。以下是逻辑回归常见的应用场景。

  • 金融领域:逻辑回归可用于信用评分、欺诈检测、客户流失预测等金融风险管理任务。
  • 医学领域:逻辑回归可以用于疾病诊断、患者预后评估、药物反应预测等医学决策支持任务。

  • 市场营销:逻辑回归可用于客户分类、用户行为分析、广告点击率预测等市场营销领域的任务。

  • 自然语言处理:逻辑回归可用于文本分类、情感分析、垃圾邮件过滤等自然语言处理任务。

  • 图像识别:逻辑回归可以应用于图像分类、目标检测中的二分类问题。

        逻辑回归的简单性和可解释性使其在许多实际应用中得到广泛应用。然而,对于复杂的非线性问题,逻辑回归可能不适用,此时可以考虑使用其他更复杂的模型或者结合特征工程技术来改进性能。

 3 基于pytorch实现银行欺诈人员的二分类判别

(1)数据集

在一个银行欺诈数据集上,通过15个特征,得到二分类的判别结果:是否为欺诈失信人员。建的模型依旧是线性模型。输出的值通过sigmoid进行转换,变成0~1的概率。一般认为大于0.5就是1,小于0.5就是0。

0

56.75

12.25

0

0

6

0

1.25

0

0

4

0

0

200

0

-1

0

31.67

16.165

0

0

1

0

3

0

0

9

1

0

250

730

-1

1

23.42

0.79

1

1

8

0

1.5

0

0

2

0

0

80

400

-1

1

20.42

0.835

0

0

8

0

1.585

0

0

1

1

0

0

0

-1

0

26.67

4.25

0

0

2

0

4.29

0

0

1

0

0

120

0

-1

0

34.17

1.54

0

0

2

0

1.54

0

0

1

0

0

520

50000

-1

1

36

1

0

0

0

0

2

0

0

11

1

0

0

456

-1

0

25.5

0.375

0

0

6

0

0.25

0

0

3

1

0

260

15108

-1

0

19.42

6.5

0

0

9

1

1.46

0

0

7

1

0

80

2954

-1

0

35.17

25.125

0

0

10

1

1.625

0

0

1

0

0

515

500

-1

0

32.33

7.5

0

0

11

2

1.585

0

1

0

0

2

420

0

1

1

38.58

5

0

0

2

0

13.5

0

1

0

0

0

980

0

1

最后一列为-1是失信欺诈人员,为1不是失信欺诈人员

(2)pytorch完整代码

  1. import torch
  2. import pandas as pd
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from torch import nn
  6. from torch.utils.data import TensorDataset, DataLoader
  7. from sklearn.model_selection import train_test_split
  8. def accuracy(y_pred,y_true):
  9. y_pred = (y_pred>0.5).type(torch.int32)
  10. acc = (y_pred == y_true).float().mean()
  11. return acc
  12. loss_fn = nn.BCELoss()
  13. epochs = 1000
  14. batch = 16
  15. lr = 0.0001
  16. data = pd.read_csv("credit.csv",header=None)
  17. X = data.iloc[:,:-1]
  18. Y = data.iloc[:,-1].replace(-1,0)
  19. X = torch.from_numpy(X.values).type(torch.float32)
  20. Y = torch.from_numpy(Y.values).type(torch.float32)
  21. train_x,test_x,train_y,test_y = train_test_split(X,Y)
  22. train_ds = TensorDataset(train_x,train_y)
  23. train_dl = DataLoader(train_ds,batch_size=batch,shuffle=True)
  24. test_ds = TensorDataset(test_x,test_y)
  25. test_dl = DataLoader(test_ds,batch_size=batch)
  26. model = nn.Sequential(
  27. nn.Linear(15,1),
  28. nn.Sigmoid()
  29. )
  30. optim = torch.optim.Adam(model.parameters(),lr=lr)
  31. accuracy_rate = []
  32. for epoch in range(epochs):
  33. for x,y in train_dl:
  34. y_pred = model(x)
  35. y_pred = y_pred.squeeze()
  36. loss = loss_fn(y_pred,y)
  37. optim.zero_grad()
  38. loss.backward()
  39. optim.step()
  40. with torch.no_grad():
  41. # 训练集的准确率和loss
  42. y_pred = model(train_x)
  43. y_pred = y_pred.squeeze()
  44. epoch_accuracy = accuracy(y_pred,train_y)
  45. epoch_loss = loss_fn(y_pred,train_y).data
  46. accuracy_rate.append(epoch_accuracy*100)
  47. # 测试集的准确率和loss
  48. y_pred = model(test_x)
  49. y_pred = y_pred.squeeze()
  50. epoch_test_accuracy = accuracy(y_pred,test_y)
  51. epoch_test_loss = loss_fn(y_pred,test_y).data
  52. print('epoch:',epoch,
  53. 'train_loss:',round(epoch_loss.item(),3),
  54. "train_accuracy",round(epoch_accuracy.item(),3),
  55. 'test_loss:',round(epoch_test_loss.item(),3),
  56. "test_accuracy",round(epoch_test_accuracy.item(),3)
  57. )
  58. accuracy_rate = np.array(accuracy_rate)
  59. times = np.linspace(1, epochs, epochs)
  60. plt.xlabel('epochs')
  61. plt.ylabel('accuracy rate')
  62. plt.plot(times, accuracy_rate)
  63. plt.show()

(3)输出结果

  1. epoch: 951 train_loss: 0.334 train_accuracy 0.869 test_loss: 0.346 test_accuracy 0.866
  2. epoch: 952 train_loss: 0.334 train_accuracy 0.863 test_loss: 0.348 test_accuracy 0.866
  3. epoch: 953 train_loss: 0.337 train_accuracy 0.867 test_loss: 0.358 test_accuracy 0.86
  4. epoch: 954 train_loss: 0.334 train_accuracy 0.867 test_loss: 0.35 test_accuracy 0.866
  5. epoch: 955 train_loss: 0.334 train_accuracy 0.871 test_loss: 0.346 test_accuracy 0.866
  6. epoch: 956 train_loss: 0.333 train_accuracy 0.865 test_loss: 0.348 test_accuracy 0.872
  7. epoch: 957 train_loss: 0.333 train_accuracy 0.871 test_loss: 0.349 test_accuracy 0.866
  8. epoch: 958 train_loss: 0.333 train_accuracy 0.867 test_loss: 0.347 test_accuracy 0.866
  9. epoch: 959 train_loss: 0.334 train_accuracy 0.863 test_loss: 0.352 test_accuracy 0.866
  10. epoch: 960 train_loss: 0.333 train_accuracy 0.867 test_loss: 0.35 test_accuracy 0.878
  11. epoch: 961 train_loss: 0.334 train_accuracy 0.873 test_loss: 0.346 test_accuracy 0.866
  12. epoch: 962 train_loss: 0.334 train_accuracy 0.865 test_loss: 0.353 test_accuracy 0.866
  13. epoch: 963 train_loss: 0.333 train_accuracy 0.873 test_loss: 0.35 test_accuracy 0.866
  14. epoch: 964 train_loss: 0.334 train_accuracy 0.863 test_loss: 0.345 test_accuracy 0.872
  15. epoch: 965 train_loss: 0.333 train_accuracy 0.861 test_loss: 0.351 test_accuracy 0.866
  16. epoch: 966 train_loss: 0.333 train_accuracy 0.873 test_loss: 0.348 test_accuracy 0.866
  17. epoch: 967 train_loss: 0.333 train_accuracy 0.863 test_loss: 0.348 test_accuracy 0.866
  18. epoch: 968 train_loss: 0.333 train_accuracy 0.867 test_loss: 0.351 test_accuracy 0.866
  19. epoch: 969 train_loss: 0.334 train_accuracy 0.869 test_loss: 0.345 test_accuracy 0.878
  20. epoch: 970 train_loss: 0.333 train_accuracy 0.869 test_loss: 0.348 test_accuracy 0.872
  21. epoch: 971 train_loss: 0.335 train_accuracy 0.865 test_loss: 0.344 test_accuracy 0.86
  22. epoch: 972 train_loss: 0.333 train_accuracy 0.867 test_loss: 0.35 test_accuracy 0.86
  23. epoch: 973 train_loss: 0.334 train_accuracy 0.871 test_loss: 0.345 test_accuracy 0.872
  24. epoch: 974 train_loss: 0.333 train_accuracy 0.865 test_loss: 0.351 test_accuracy 0.866
  25. epoch: 975 train_loss: 0.333 train_accuracy 0.873 test_loss: 0.351 test_accuracy 0.86
  26. epoch: 976 train_loss: 0.333 train_accuracy 0.869 test_loss: 0.346 test_accuracy 0.878
  27. epoch: 977 train_loss: 0.333 train_accuracy 0.863 test_loss: 0.351 test_accuracy 0.866
  28. epoch: 978 train_loss: 0.332 train_accuracy 0.865 test_loss: 0.351 test_accuracy 0.866
  29. epoch: 979 train_loss: 0.332 train_accuracy 0.871 test_loss: 0.349 test_accuracy 0.866
  30. epoch: 980 train_loss: 0.333 train_accuracy 0.865 test_loss: 0.345 test_accuracy 0.872
  31. epoch: 981 train_loss: 0.332 train_accuracy 0.867 test_loss: 0.348 test_accuracy 0.872
  32. epoch: 982 train_loss: 0.332 train_accuracy 0.863 test_loss: 0.349 test_accuracy 0.872
  33. epoch: 983 train_loss: 0.333 train_accuracy 0.865 test_loss: 0.353 test_accuracy 0.866
  34. epoch: 984 train_loss: 0.332 train_accuracy 0.865 test_loss: 0.35 test_accuracy 0.872
  35. epoch: 985 train_loss: 0.333 train_accuracy 0.867 test_loss: 0.353 test_accuracy 0.86
  36. epoch: 986 train_loss: 0.333 train_accuracy 0.871 test_loss: 0.345 test_accuracy 0.866
  37. epoch: 987 train_loss: 0.331 train_accuracy 0.865 test_loss: 0.349 test_accuracy 0.872
  38. epoch: 988 train_loss: 0.332 train_accuracy 0.869 test_loss: 0.345 test_accuracy 0.872
  39. epoch: 989 train_loss: 0.332 train_accuracy 0.865 test_loss: 0.353 test_accuracy 0.866
  40. epoch: 990 train_loss: 0.331 train_accuracy 0.865 test_loss: 0.348 test_accuracy 0.872
  41. epoch: 991 train_loss: 0.333 train_accuracy 0.875 test_loss: 0.344 test_accuracy 0.86
  42. epoch: 992 train_loss: 0.332 train_accuracy 0.865 test_loss: 0.351 test_accuracy 0.866
  43. epoch: 993 train_loss: 0.331 train_accuracy 0.869 test_loss: 0.348 test_accuracy 0.872
  44. epoch: 994 train_loss: 0.331 train_accuracy 0.871 test_loss: 0.348 test_accuracy 0.872
  45. epoch: 995 train_loss: 0.331 train_accuracy 0.865 test_loss: 0.347 test_accuracy 0.872
  46. epoch: 996 train_loss: 0.331 train_accuracy 0.865 test_loss: 0.347 test_accuracy 0.872
  47. epoch: 997 train_loss: 0.331 train_accuracy 0.867 test_loss: 0.35 test_accuracy 0.872
  48. epoch: 998 train_loss: 0.331 train_accuracy 0.867 test_loss: 0.349 test_accuracy 0.872
  49. epoch: 999 train_loss: 0.331 train_accuracy 0.865 test_loss: 0.348 test_accuracy 0.872

bb1e0726dc7943cea90943a70ff470cd.png

 4 完整数据集及代码下载

完整代码及数据集:代码和数据集

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/709704
推荐阅读
相关标签
  

闽ICP备14008679号