赞
踩
逻辑斯谛回归模型其实是一种分类模型,这里实现的是参考李航的《统计机器学习》以及周志华的《机器学习》两本教材来整理实现的。
假定我们的输入为 x x x, x x x 可以是多个维度的,我们想要根据 x x x 去预测 y y y, y ∈ { 0 , 1 } y\in \{0,1\} y∈{0,1}。逻辑斯谛的模型如下:
p ( Y = 1 ∣ x ) = e x p ( w ⋅ x ) 1 + e x p ( w ⋅ x ) (1) p(Y=1|x)=\frac{exp(w\cdot x)}{1+exp(w\cdot x)}\tag{1} p(Y=1∣x)=1+exp(w⋅x)exp(w⋅x)(1)
其中的参数 w w w就是我们要进行学习的,注意:它是包含了权重系数和偏置(bias)b的。在书写程序时,这样表示更加简洁。
参数 w w w是我们需要学习的,我们采用极大似然法估计模型参数。
设:
P ( Y = 1 ∣ x ) = π ( x ) , P ( Y = 0 ∣ x ) = 1 − π ( x ) (2) P(Y=1|x)=\pi(x),\quad P(Y=0|x)=1-\pi(x)\tag{2} P(Y=1∣x)=π(x),P(Y=0∣x)=1−π(x)(2)
似然函数为:
∏ i = 1 N [ π ( x i ) ] y i [ 1 − π ( x i ) ] 1 − y i (3) \prod_{i=1}^N[\pi(x_i)]^{y_i}[1-\pi(x_i)]^{1-y_i} \tag{3} i=1∏N[π(xi)]yi[1−π(xi)]1−yi(3)
因为这种指数的形式不利于求导我们需要将它们转化为对数的形式,如下:
L
(
w
)
=
∑
i
=
1
N
[
y
i
l
o
g
π
(
x
i
)
+
(
1
−
y
i
)
l
o
g
(
1
−
π
(
x
i
)
)
]
=
∑
i
=
1
N
[
y
i
l
o
g
(
π
(
x
i
)
1
−
π
(
x
i
)
)
+
l
o
g
(
1
−
π
(
x
i
)
)
]
=
∑
i
=
1
N
[
y
i
(
w
⋅
x
i
)
−
l
o
g
(
1
+
e
x
p
(
w
⋅
x
i
)
)
]
(4)
对 L ( w ) L(w) L(w)求极大值,得到 w w w的估计值。
梯度下降法是求极小值的,而我们想要得到的是 L ( w ) L(w) L(w)的最大值,因此,我们取 L ( w ) L(w) L(w)的相反数,即:
arg min w − L ( w ) (5) \argmin_{w}-L(w) \tag{5} wargmin−L(w)(5)
对 L ( w ) L(w) L(w)关于 w w w求导,如下:
(
−
L
(
w
)
)
′
=
−
∑
i
=
1
N
[
(
y
i
⋅
x
i
)
−
e
x
p
(
w
⋅
x
i
)
1
+
e
x
p
(
w
⋅
x
)
⋅
x
i
]
=
−
∑
i
=
1
N
[
(
y
i
−
e
x
p
(
w
⋅
x
i
)
1
+
e
x
p
(
w
⋅
x
)
)
⋅
x
i
]
=
∑
i
=
1
N
[
(
e
x
p
(
w
⋅
x
i
)
1
+
e
x
p
(
w
⋅
x
)
−
y
i
)
⋅
x
i
]
(6)
然后我们就得到了参数 w w w的更新公式,如下:
w
′
=
w
−
l
r
⋅
(
−
L
(
w
)
′
)
=
w
−
l
r
⋅
(
∑
i
=
1
N
[
(
e
x
p
(
w
⋅
x
i
)
1
+
e
x
p
(
w
⋅
x
)
−
y
i
)
⋅
x
i
]
)
(7)
关于优化方法的选择,最开始是选择西瓜书上提供的牛顿法来实现的,牛顿法的好处是,可以获得较快的收敛速度,但是坏处是,当海森矩阵为奇异矩阵时,会出现无法求解的情况。
因此,可以采用拟牛顿法进行优化,在解决这个问题的同时,也可以很快的收敛。
但是,自己对拟牛顿法并不熟悉,而梯度下降法虽然收敛可能较慢,但是实现起来较为简单,因此这里采用了梯度下降法来优化似然函数。
package weka.classifiers.myf; import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; import weka.core.matrix.Matrix; import weka.filters.Filter; import weka.filters.unsupervised.attribute.NominalToBinary; import weka.filters.unsupervised.attribute.Standardize; import java.util.Arrays; /** * @author YFMan * @Description 自定义的 Logistic 回归分类器 * @Date 2023/6/13 11:02 */ public class myLogistic extends Classifier { // 用于存储 线性回归 系数 的数组 private double[] m_Coefficients; // 类别索引 private int m_ClassIndex; // 牛顿法的迭代次数 private int m_MaxIterations = 1000; // 属性数量 private int m_numAttributes; // 系数数量 private int m_numCoefficients; // 梯度下降步长 private double m_lr = 1e-4; // 标准化数据的过滤器 public static final int FILTER_STANDARDIZE = 1; // 用于标准化数据的过滤器 protected Filter m_StandardizeFilter = null; // 用于将 normal 转为 binary 的过滤器 protected Filter m_NormalToBinaryFilter = null; /* * @Author YFMan * @Description 采用牛顿法来训练 logistic 回归模型 * @Date 2023/5/9 22:08 * @Param [data] 训练数据 * @return void **/ public void buildClassifier(Instances data) throws Exception { // 设置类别索引 m_ClassIndex = data.classIndex(); // 设置属性数量 m_numAttributes = data.numAttributes(); // 系数数量 = 输入属性数量 + 1(截距参数b) m_numCoefficients = m_numAttributes; // 初始化 系数数组 m_Coefficients = new double[m_numCoefficients]; Arrays.fill(m_Coefficients, 0); // 将输入数据进行标准化 m_StandardizeFilter = new Standardize(); m_StandardizeFilter.setInputFormat(data); data = Filter.useFilter(data, m_StandardizeFilter); // 将类别属性转为二值属性 m_NormalToBinaryFilter = new NominalToBinary(); m_NormalToBinaryFilter.setInputFormat(data); data = Filter.useFilter(data, m_NormalToBinaryFilter); // 梯度下降法 for(int curPerformIteration = 0; curPerformIteration < m_MaxIterations;curPerformIteration++){ double[] deltaM_Coefficients = new double[m_numCoefficients]; // 计算 l(w) 的一阶导数 for(int i = 0;i<data.numInstances();i++){ double yi = data.instance(i).value(m_ClassIndex); double wxi = 0; int column = 0; for(int j=0;j<m_numAttributes;j++){ if(j!=m_ClassIndex){ wxi += m_Coefficients[column] * data.instance(i).value(j); column++; } } // 加上截距参数 b wxi += m_Coefficients[column]; double pi1 = Math.exp(wxi) / (1 + Math.exp(wxi)); for(int k=0;k<m_numCoefficients - 1;k++){ deltaM_Coefficients[k] += m_lr * (pi1 - yi) * data.instance(i).value(k); } // 这里计算 bias b 对应的更新量 deltaM_Coefficients[m_numCoefficients - 1] += m_lr * (pi1 - yi); } // 进行参数更新 for(int k=0;k<m_numCoefficients;k++){ m_Coefficients[k] -= deltaM_Coefficients[k]; } // 如果参数更新量小于阈值,则停止迭代 double delta = 0; for(int k=0;k<m_numCoefficients;k++){ delta += deltaM_Coefficients[k] * deltaM_Coefficients[k]; } if(delta < 1e-6){ break; } } } /* * @Author YFMan * @Description // 分类实例 * @Date 2023/6/16 11:17 * @Param [instance] * @return double[] **/ public double[] distributionForInstance(Instance instance) throws Exception { // 将输入数据进行标准化 m_StandardizeFilter.input(instance); instance = m_StandardizeFilter.output(); // 将输入属性二值化 m_NormalToBinaryFilter.input(instance); instance = m_NormalToBinaryFilter.output(); double[] result = new double[2]; result[0] = 0; result[1] = 0; int column = 0; for(int i=0;i<m_numAttributes;i++){ if(m_ClassIndex != i){ result[0] += instance.value(i) * m_Coefficients[column]; column++; } } result[0] += m_Coefficients[column]; result[0] = 1 / (1 + Math.exp(result[0])); result[1] = 1 - result[0]; return result; } /* * @Author YFMan * @Description 主函数 生成一个线性回归函数预测器 * @Date 2023/5/9 22:35 * @Param [argv] * @return void **/ public static void main(String[] argv) { runClassifier(new myLogistic(), argv); } }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。