赞
踩
最近使用到了skearn做数据集的处理,记录一下这两个个比较重要的函数:KFold与StratifiedKFold,作用是在机器学习中进行交叉验证来使用
这两个函数都是sklearn模块中的,在应用之前应该导入:
from sklearn.model_selection import StratifiedKFold,KFold
两者的区别:StratifiedKFold函数采用分层划分的方法(分层随机抽样思想),验证集中不同类别占比与原始样本的比例保持一致,故StratifiedKFold在做划分的时候需要传入标签特征。
下面分别对这两个函数进行举例子说明:
1、KFold函数
参数说明:
n_splits: 默认为3,表示将数据划分为多少份,即k折交叉验证中的k;
shuffle: 默认为False,表示是否需要打乱顺序,这个参数在很多的函数中都会涉及,如果设置为True,则会先打乱顺序再做划分,如果为False,会直接按照顺序做划分;
random_state: 默认为None,表示随机数的种子,只有当shuffle设置为True的时候才会生效。
代码:
import numpy as np from sklearn.model_selection import KFold,StratifiedKFold X = np.array([[1, 2], [3, 4], [1, 2], [3, 4],[5,9],[1,5],[3,9],[5,8],[1,1],[1,4]]) y = np.array([0, 1, 1, 1, 0, 0, 1, 0, 0, 0]) print('X:',X) print('y:',y) seed = 7 np.random.seed(seed) kf = KFold(n_splits=3, shuffle=False) print(kf) #做split时只需传入数据,不需要传入标签 for train_index, test_index in kf.split(X): print("TRAIN:", train_index, "TEST:", test_index) X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index]
结果输出:
X: [[1 2] [3 4] [1 2] [3 4] [5 9] [1 5] [3 9] [5 8] [1 1] [1 4]] y: [0 1 1 1 0 0 1 0 0 0] KFold(n_splits=3, random_state=None, shuffle=False) TRAIN: [4 5 6 7 8 9] TEST: [0 1 2 3] TRAIN: [0 1 2 3 7 8 9] TEST: [4 5 6] TRAIN: [0 1 2 3 4 5 6] TEST: [7 8 9]
输出说明:
大家注意到,输出中每个Train和Test都对应三个结果,是因为我们在调用函数是,参数n_splits=3,即交叉验证三次。其中的数字只是对应索引,并不是真正的数据,比如第一行TEST: [0 1 2 3]代表着:测试集选取了X[0,1,2,3]即对应:
[1 2]
[3 4]
[1 2]
[3 4]
其他同理。
1、StratifiedKFold函数
StratifiedKFold函数的参数与KFold相同。
import numpy as np from sklearn.model_selection import KFold,StratifiedKFold X = np.array([[1, 2], [3, 4], [1, 2], [3, 4],[5,9],[1,5],[3,9],[5,8],[1,1],[1,4]]) y = np.array([0, 1, 1, 1, 0, 0, 1, 0, 0, 0]) print('X:',X) print('y:',y) skf = StratifiedKFold(n_splits=4) print(skf) #做划分是需要同时传入数据集和标签 for train_index, test_index in skf.split(X, y): print('TRAIN:', train_index, "TEST:", test_index) X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index]
结果输出:
X: [[1 2] [3 4] [1 2] [3 4] [5 9] [1 5] [3 9] [5 8] [1 1] [1 4]] y: [0 1 1 1 0 0 1 0 0 0] StratifiedKFold(n_splits=4, random_state=None, shuffle=False) TRAIN: [2 3 5 6 7 8 9] TEST: [0 1 4] TRAIN: [0 1 3 4 6 8 9] TEST: [2 5 7] TRAIN: [0 1 2 4 5 6 7 9] TEST: [3 8] TRAIN: [0 1 2 3 4 5 7 8] TEST: [6 9]
参考文章:https://zhuanlan.zhihu.com/p/150446294
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。